Official code for "Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer" (ICCV 2021).

Overview

Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer. ICCV2021.

Introduction

We propose a novel model training paradigm for few-shot semantic segmentation. Instead of meta-learning the whole, complex segmentation model, we focus on its simplest part, the classifier, to make new-class adaptation more tractable. We also introduce a novel meta-learning algorithm that uses a Classifier Weight Transformer (CWT) to dynamically adapt the classifier weights to every query sample, eliminating the impact of intra-class discrepancy.
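The weight-adaptation idea can be illustrated with a minimal NumPy sketch. It is a simplified, single-head reading of the description above, not the repository's module: the classifier weights act as attention queries, the query image's features act as keys and values, and the attended result is added back to the weights as a residual update. The function names (adapt_classifier, segment), the projection matrices Wq/Wk/Wv and all shapes are illustrative assumptions; in the real model the projections are learned.

```python
import numpy as np


def softmax(x, axis=-1):
    # Subtract the row maximum for numerical stability
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def adapt_classifier(w, feats, Wq, Wk, Wv):
    """Adapt classifier weights to one query image via single-head attention.

    w:      (n_cls, d)  classifier weights obtained from the support set
    feats:  (hw, d)     flattened feature map of the query image
    Wq, Wk: (d, d_a)    query/key projections (learned in practice)
    Wv:     (d, d)      value projection (learned in practice)
    """
    q = w @ Wq                    # (n_cls, d_a): classifier weights as queries
    k = feats @ Wk                # (hw, d_a):    query-image features as keys
    v = feats @ Wv                # (hw, d):      query-image features as values
    attn = softmax(q @ k.T / np.sqrt(k.shape[1]))  # (n_cls, hw), rows sum to 1
    return w + attn @ v           # residual update keeps the support-set weights


def segment(w, feats):
    # A linear classifier applied to every pixel, like a 1x1 convolution: (hw, n_cls)
    return feats @ w.T


if __name__ == "__main__":
    rng = np.random.default_rng(0)
    d, d_a, hw = 8, 4, 16
    w = rng.normal(size=(2, d))            # background / foreground weights
    feats = rng.normal(size=(hw, d))
    Wq, Wk = rng.normal(size=(d, d_a)), rng.normal(size=(d, d_a))
    Wv = rng.normal(size=(d, d))
    w_adapted = adapt_classifier(w, feats, Wq, Wk, Wv)
    mask = segment(w_adapted, feats).argmax(axis=1)  # per-pixel class index
    print(mask.reshape(4, 4))
```

Because only the small weight matrix is updated per query, the feature extractor itself is left untouched, which is the tractability argument made above.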

Architecture

Environment

Other configurations can also work, but the results may be slightly different.

  • torch==1.6.0
  • numpy==1.19.1
  • cv2==4.4.0
  • pyyaml==5.3.1
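To see how a local setup differs from the versions listed above, a small standard-library sketch can import each module and compare its __version__. The helper name compare_versions is hypothetical, and the import names are an assumption about the usual packaging (cv2 typically comes from the opencv-python package, yaml from PyYAML).

```python
import importlib

# Versions listed in the Environment section, keyed by *import* name.
# cv2 is usually installed via opencv-python, yaml via PyYAML.
TESTED_VERSIONS = {
    "torch": "1.6.0",
    "numpy": "1.19.1",
    "cv2": "4.4.0",
    "yaml": "5.3.1",
}


def compare_versions(expected=TESTED_VERSIONS):
    """Return {module: (tested_version, found_version)} for every module that differs.

    found_version is None when the module cannot be imported or has no __version__.
    """
    mismatches = {}
    for name, want in expected.items():
        try:
            have = getattr(importlib.import_module(name), "__version__", None)
        except ImportError:
            have = None
        # PyTorch wheels may carry a local build tag such as "1.6.0+cu101"; ignore it
        if have is None or have.split("+")[0] != want:
            mismatches[name] = (want, have)
    return mismatches


if __name__ == "__main__":
    for name, (want, have) in compare_versions().items():
        print(f"{name}: tested with {want}, found {have or 'not available'}")
```

A mismatch is not necessarily a problem, since other configurations can also work; the report only tells you where results might drift.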

Dataset

We follow the same procedure as https://github.com/Jia-Research-Lab/PFENet to download and process the datasets. After processing, update "data_root", "train_list" and "val_list" in the config files accordingly.
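If you prefer to script the config update, the sketch below rewrites keys in a YAML config file. Because the exact nesting of the repository's config files is not shown here, it does not assume where data_root, train_list and val_list sit; it searches the loaded dict recursively. The function names are hypothetical.

```python
def set_config_value(cfg, key, value):
    """Set `key` to `value` wherever it appears in a (possibly nested) config dict.

    Returns how many occurrences were updated, so a typo in `key` can be detected.
    """
    count = 0
    for k, v in cfg.items():
        if k == key:
            cfg[k] = value  # replacing an existing key is safe while iterating
            count += 1
        elif isinstance(v, dict):
            count += set_config_value(v, key, value)
    return count


def update_config_file(path, updates):
    """Apply {key: value} updates to a YAML config file in place."""
    import yaml  # PyYAML, listed in the Environment section

    with open(path) as f:
        cfg = yaml.safe_load(f)
    for key, value in updates.items():
        if set_config_value(cfg, key, value) == 0:
            raise KeyError(f"{key!r} not found in {path}")
    with open(path, "w") as f:
        # Note: comments in the original file are not preserved by safe_dump
        yaml.safe_dump(cfg, f, default_flow_style=False, sort_keys=False)
```

For example, update_config_file("config_files/pascal.yaml", {"data_root": "/path/to/pascal", "train_list": "lists/pascal/train.txt", "val_list": "lists/pascal/val.txt"}), where the config path and data_root value are placeholders and the list paths are the defaults that appear in the training logs further down.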

Pre-trained models in the first stage

For convenience, we provide the models pre-trained on the base classes for each split. Download them here: https://drive.google.com/file/d/1yHUNI1iTwF5U_HqCQ4kF6ti8lepcrBBY/view?usp=sharing, and set "resume_weights" in the config files to the downloaded folder.
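The default resume_weights value, /pretrained_models/, shows up in the training logs in the comments together with "=> no weight found at '/pretrained_models/'", so a quick check that the folder really contains checkpoints can save a wasted run. This sketch only lists .pth files; it makes no assumption about the folder's internal layout, and the function name and the script name check_weights.py are hypothetical.

```python
import sys
from pathlib import Path


def list_checkpoints(resume_weights):
    """Return the .pth files found (recursively) under the resume_weights folder."""
    root = Path(resume_weights)
    if not root.is_dir():
        raise FileNotFoundError(f"resume_weights folder not found: {root}")
    return sorted(str(p.relative_to(root)) for p in root.rglob("*.pth"))


if __name__ == "__main__":
    # Usage: python check_weights.py /path/to/pretrained_models
    for ckpt in list_checkpoints(sys.argv[1]):
        print(ckpt)
```

An empty list means the path exists but holds no checkpoints, which is worth fixing before training starts.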

Episodic training and inference

  • The general training script
sh scripts/train.sh {data} {split} {[gpu_ids]} {layers} {shots}
  • Example: 1-shot, ResNet-50, split 0 on PASCAL, using GPU 0.
sh scripts/train.sh pascal 0 [0] 50 1
  • The inference script (note that its argument order differs from train.sh)
sh scripts/test.sh {data} {shot} {[gpu_ids]} {layers} {split}

Contact

Please open an issue or contact me at zhihe.lu [at] surrey.ac.uk if you have any questions.

Citation

If you find this work helpful, please cite it. This entry will be updated once the paper is officially published at ICCV.

@misc{lu2021simpler,
      title={Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer}, 
      author={Zhihe Lu and Sen He and Xiatian Zhu and Li Zhang and Yi-Zhe Song and Tao Xiang},
      year={2021},
      eprint={2108.03032},
      archivePrefix={arXiv},
      primaryClass={cs.CV}
}

Acknowledgments

Thanks to the code contributors. Some parts of the code are borrowed from https://github.com/Jia-Research-Lab/PFENet and https://github.com/mboudiaf/RePRI-for-Few-Shot-Segmentation.

Comments
  • Script for Segmentation Map Visualization

    To visualize the segmentation results at the end of each test iteration, use the following.

    Insert the following into the main test.py script (below classes.append()):

                    logits_q[i] = pred_q.detach()
                    gt_q[i, 0] = q_label
                    classes.append([class_.item() for class_ in subcls])
                    # Visualization routine: needs `import os, cv2` and `import numpy as np` in test.py,
                    # plus a `visualize` option in args; intended for 1-shot episodes
                    if args.visualize:
                        output = {'query': {}, 'support': {}}
                        # Upsample the query logits to the label resolution before overlaying
                        pred_q_up = F.interpolate(pred_q, size=q_label.size()[1:], mode='bilinear', align_corners=True)
                        output['query']['gt'], output['query']['pred'] = vis_res(
                            qry_oris[0][0], qry_oris[1], pred_q_up.squeeze().detach().cpu().numpy())
                        spprt_label = torch.cat(spprt_oris[1], 0)
                        output['support']['gt'], output['support']['pred'] = vis_res(
                            spprt_oris[0][0][0], spprt_label, output_support.squeeze().detach().cpu().numpy())

                        # Side by side: support GT | query GT | query prediction
                        save_image = np.concatenate(
                            (output['support']['gt'], output['query']['gt'], output['query']['pred']), 1)
                        os.makedirs('./analysis', exist_ok=True)
                        cv2.imwrite(os.path.join('./analysis', os.path.basename(qry_oris[0][0])), save_image)
    

    The main visualization function, vis_res, and its helper resize_image_label are as follows:

    import cv2
    import numpy as np


    def resize_image_label(image, label, size=473):
        """Resize so the longer side is `size` (keeping the h/w ratio), then pad to (size, size)."""

        def find_new_hw(ori_h, ori_w, test_size):
            if ori_h >= ori_w:
                ratio = test_size / ori_h
                new_h, new_w = test_size, int(ori_w * ratio)
            else:
                ratio = test_size / ori_w
                new_h, new_w = int(ori_h * ratio), test_size
            # Round both sides down to a multiple of 8
            return (new_h // 8) * 8, (new_w // 8) * 8

        # Step 1: resize the image; the longer side (height or width) becomes `size`,
        #         the other is scaled accordingly
        new_h, new_w = find_new_hw(image.shape[0], image.shape[1], size)
        image_crop = cv2.resize(image, dsize=(new_w, new_h), interpolation=cv2.INTER_LINEAR)

        # Step 2: zero-pad the image to (size, size)
        padded_image = np.zeros((size, size, 3), dtype=np.float32)
        padded_image[:new_h, :new_w, :] = image_crop

        # Step 3: same for the label, using nearest-neighbour resizing and 255 as padding
        new_h, new_w = find_new_hw(label.shape[0], label.shape[1], size)
        label_crop = cv2.resize(label.astype(np.float32), dsize=(new_w, new_h),
                                interpolation=cv2.INTER_NEAREST)
        padded_label = np.full((size, size), 255, dtype=np.float32)
        padded_label[:new_h, :new_w] = label_crop

        return padded_image, padded_label


    def vis_res(image_path, label, pred):
        """Blend the ground truth (red) and the prediction (blue) over the image.

        label: torch tensor with the binary mask; pred: numpy array of logits, shape (2, H, W).
        Returns two BGR float32 images that can be passed to cv2.imwrite.
        """

        def read_image(path):
            image = cv2.imread(path, cv2.IMREAD_COLOR)
            if image is None:  # cv2.imread returns None instead of raising
                raise FileNotFoundError(path)
            return np.float32(cv2.cvtColor(image, cv2.COLOR_BGR2RGB))

        def mask_to_rgb(mask, channel):
            # Paint a binary (H, W) mask into one RGB channel (0 = red, 2 = blue)
            overlay = np.zeros(mask.shape + (3,), dtype=np.float32)
            overlay[..., channel] = np.float32(mask) * 255.
            return overlay

        def blend_image_label(image, overlay):
            result = np.float32(0.5 * image + 0.5 * overlay)
            # Back to BGR channel order, which cv2.imwrite expects
            return cv2.cvtColor(result, cv2.COLOR_RGB2BGR)

        image = read_image(image_path)
        label = label.squeeze().detach().cpu().numpy()
        image, label = resize_image_label(image, label)
        # Only foreground pixels (label == 1) are painted; background and padding (255) are not
        out_image_gt = blend_image_label(image, mask_to_rgb(label == 1, channel=0))

        pred = np.argmax(pred, 0)  # (H, W) map: 0 = background, 1 = foreground
        out_image_pred = blend_image_label(image, mask_to_rgb(pred, channel=2))

        return out_image_gt, out_image_pred
    
    opened by kilickaya 3
  • Error run training model

    I ran the model for training but got this warning: [ WARN ] global /io/opencv/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('COCO/train/COCO_train2014_000000494112.png'): can't open/read file: check file path/integrity

    Do you know how to solve it? Please help.

    opened by viethoang303 1
  • pre-trained model

    Thanks for releasing your code. However, neither the paper nor the code describes in detail how the pre-trained model is trained. Did you choose classification or segmentation as the pre-training task? If segmentation, how did you construct the image-label pairs?

    opened by LIUYUANWEI98 1
  • how to generate SegmentationClassAug?

    I have downloaded the PASCAL VOC 2012 and SBD datasets from their respective official websites, but I don't know how to use them in the code. Could you show me your dataset folder structure or the method used to generate SegmentationClassAug? Thanks, bubble from hfut!

    opened by WHL182 1
  • Pascal SegmentationClassAug dataset

    Thank you for your work! I have a question about the PASCAL dataset. I downloaded the official PASCAL dataset, but I couldn't find the SegmentationClassAug data used in the code. How can I get it? Thanks!

    opened by tt622 1
  • The pre-trained model

    Hi,

    Thanks for your nice work! I used your default parameters to train and test on the val set of voc2007, but I only got 15% mIoU. From reading your code, it seems this is caused by the pre-trained model (e.g. resnet50, resnet101, or the transformer model). In the code, it is set as "resume_weights: /pretrained_models/".

    Would you please give me some suggestions?

    Best wishes to you

    opened by zwy1996 0
  • Can not achieve the results in the article

    When I finished training, I found that I could not reproduce the results in the paper, especially on the COCO dataset. Can you share more details about the training?

    opened by 7M7L 0
  • COCO dataset!!! Can not load coco dataset for mask image

    When I read the mask images, I found that the categories in the masks do not correspond to the categories of the COCO dataset. How can I solve this problem? Thanks

    opened by chenhao-zju 0
  • Difference between transformer_resnet50 and pspnet_resnet50

    Hi, I am trying to reproduce the results in your paper, but I have run into some difficulties.

    I downloaded the COCO dataset and set data_root in coco.yaml to its folder. I launched a training run, which produced a transformer_resnet50 model. The test results were poor: about 0.13 mIoU with 5 shots on split 0.

    I downloaded the pre-trained models. The folder contains pspnet_resnet50 models. I changed resume_weights in coco.yaml accordingly. When I run the test with these models, I get the following error:

    sh scripts/test.sh coco 5 [2] 50 0
    ==> Running DDP checkpoint example on rank 0.
    => no weight found at '/home/bacchin/CWT/CWT_venv/CWT-for-FSS/model_ckpt/coco/split=0/model/shot_5/transformer_resnet50/best.pth'
    Traceback (most recent call last):
      File "/usr/lib/python3.6/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/usr/lib/python3.6/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/home/bacchin/CWT/CWT_venv/CWT-for-FSS/src/test.py", line 302, in <module>
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
        while not context.join():
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 150, in join
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
    torch.multiprocessing.spawn.ProcessRaisedException: 
    
    -- Process 0 terminated with the following error:
    Traceback (most recent call last):
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
        fn(i, *args)
      File "/home/bacchin/CWT/CWT_venv/CWT-for-FSS/src/test.py", line 93, in main_worker
        assert os.path.isfile(filepath), filepath
    AssertionError: model_ckpt/coco/split=0/model/shot_5/transformer_resnet50/best.pth
    

    I understood that model_dir in coco.yaml must point to a transformer_resnet50 model, so I set model_dir to the transformer_resnet50 obtained from my training. It worked, but the results are still below the performance reported in the paper (about 0.3 with 5 shots on split 0).

    Why are no pre-trained transformer_resnet50 files provided? And why are there two pointers to models, namely model_dir and resume_weights?

    Am I missing something?

    Thank you for your help!

    opened by bach05 1
  • RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)

    When I was training on a single machine with multiple GPUs, I encountered the following error. What is the reason?

    (base) lixiang@vs008:~/CWT-for-FSS$ sh scripts/train.sh pascal 0 [0,1] 50 1
      0%|                                                                                       | 0/5953 [00:00<?, ?it/s]==> Running process rank 0.
    FB_param_noise: 0
    adapt_iter: 200
    arch: resnet
    augmentations: ['hor_flip', 'vert_flip', 'resize']
    backbone_dim: 2048
    batch_size: 2
    batch_size_val: 2
    bins: [1, 2, 3, 6]
    bottleneck_dim: 512
    ckpt_path: checkpoints/
    ckpt_used: best
    cls_lr: 0.1
    data_root: pascal/
    debug: False
    distributed: True
    dropout: 0.1
    episodic: True
    epochs: 20
    gamma: 0.1
    gpus: [0, 1]
    heads: 4
    image_size: 473
    iter_per_epoch: 6000
    layers: 50
    log_freq: 50
    lr_stepsize: 30
    m_scale: False
    main_optim: SGD
    manual_seed: 2021
    mean: [0.485, 0.456, 0.406]
    milestones: [40, 70]
    mixup: False
    model_dir: model_ckpt
    momentum: 0.9
    n_runs: 1
    nesterov: True
    norm_feat: True
    num_classes_tr: 2
    num_classes_val: 5
    padding_label: 255
    port: 53765
    pretrained: False
    random_shot: False
    resume_weights: /pretrained_models/
    rot_max: 10
    rot_min: -10
    save_models: True
    save_oracle: False
    scale_lr: 1.0
    scale_max: 2.0
    scale_min: 0.5
    scheduler: cosine
    shot: 1
    smoothing: True
    std: [0.229, 0.224, 0.225]
    test_name: default
    test_num: 1000
    test_split: default
    train_list: lists/pascal/train.txt
    train_name: pascal
    train_split: 0
    trans_lr: 0.001
    use_split_coco: False
    val_list: lists/pascal/val.txt
    weight_decay: 0.0001
    workers: 2
    => no weight found at '/pretrained_models/'
    Processing data for [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
      0%|                                                                                       | 0/5953 [00:00<?, ?it/s]==> Running process rank 1.
    => no weight found at '/pretrained_models/'
    Processing data for [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
    100%|███████████████████████████████████████████████████████████████████████████| 5953/5953 [00:08<00:00, 681.75it/s]
    100%|███████████████████████████████████████████████████████████████████████████| 5953/5953 [00:09<00:00, 609.93it/s]
      0%|                                                                                       | 0/1449 [00:00<?, ?it/s]INFO: pascal -> pascal
    INFO: 0 -> 0
    >> Start Filtering classes 
    >> Removed classes = [] 
    >> Kept classes = ['airplane', 'bicycle', 'bird', 'boat', 'bottle'] 
    Processing data for [1, 2, 3, 4, 5]
      0%|                                                                                       | 0/1449 [00:00<?, ?it/s]INFO: pascal -> pascal
    INFO: 0 -> 0
    >> Start Filtering classes 
    >> Removed classes = [] 
    >> Kept classes = ['airplane', 'bicycle', 'bird', 'boat', 'bottle'] 
    Processing data for [1, 2, 3, 4, 5]
    100%|███████████████████████████████████████████████████████████████████████████| 1449/1449 [00:06<00:00, 229.57it/s]
    100%|███████████████████████████████████████████████████████████████████████████| 1449/1449 [00:05<00:00, 241.58it/s]
    Traceback (most recent call last):
      File "/home/lixiang/anaconda3/lib/python3.8/runpy.py", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/lixiang/anaconda3/lib/python3.8/runpy.py", line 87, in _run_code
        exec(code, run_globals)
      File "/home/lixiang/CWT-for-FSS/src/train.py", line 360, in <module>
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
        while not context.join():
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 150, in join
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
    torch.multiprocessing.spawn.ProcessRaisedException: 
    
    -- Process 1 terminated with the following error:
    Traceback (most recent call last):
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
        fn(i, *args)
      File "/home/lixiang/CWT-for-FSS/src/train.py", line 134, in main_worker
        _, _ = do_epoch(
      File "/home/lixiang/CWT-for-FSS/src/train.py", line 266, in do_epoch
        output_support = binary_cls(f_s)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 446, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 442, in _conv_forward
        return F.conv2d(input, weight, bias, self.stride,
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)
    
    
    opened by lixiang007666 1
  • Error while training the resnet model using pascal dataset

    Error while training the resnet model using pascal dataset

    Traceback (most recent call last):
      File "/usr/lib/python3.7/runpy.py", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/usr/lib/python3.7/runpy.py", line 85, in _run_code
        exec(code, run_globals)
      File "/content/CWT-for-FSS/src/train.py", line 360, in <module>
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
        while not context.join():
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 150, in join
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
    torch.multiprocessing.spawn.ProcessRaisedException: 
    
    -- Process 0 terminated with the following error:
    multiprocessing.pool.RemoteTraceback: 
    """
    Traceback (most recent call last):
      File "/usr/lib/python3.7/multiprocessing/pool.py", line 121, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.7/multiprocessing/pool.py", line 44, in mapstar
        return list(map(*args))
      File "/content/CWT-for-FSS/src/dataset/utils.py", line 91, in process_image
        assert label_class_ in list(range(1, 81)), label_class_
    AssertionError: 147
    """
    
    The above exception was the direct cause of the following exception:
    
    Traceback (most recent call last):
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
        fn(i, *args)
      File "/content/CWT-for-FSS/src/train.py", line 117, in main_worker
        train_loader, train_sampler = get_train_loader(args)
      File "/content/CWT-for-FSS/src/dataset/dataset.py", line 44, in get_train_loader
        mode_train=True, transform=train_transform, class_list=class_list, args=args
      File "/content/CWT-for-FSS/src/dataset/dataset.py", line 114, in __init__
        self.data_list, self.sub_class_file_list = make_dataset(args.data_root, args.train_list, self.class_list)
      File "/content/CWT-for-FSS/src/dataset/utils.py", line 55, in make_dataset
        for sublist, subdict in mmap_(process_partial, tqdm(list_read)):
      File "/content/CWT-for-FSS/src/dataset/utils.py", line 17, in mmap_
        return Pool().map(fn, iter)
      File "/usr/lib/python3.7/multiprocessing/pool.py", line 268, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.7/multiprocessing/pool.py", line 657, in get
        raise self._value
    AssertionError: 147
    
    opened by Hemanth-Gattu 11
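Several of the reports above come down to label files: the OpenCV "can't open/read file" warning, the COCO masks whose categories do not match, and the AssertionError: 147 raised by the range(1, 81) check in dataset/utils.py. The following is a hedged diagnostic sketch, not part of the repository: it flags unreadable label files and label values outside an allowed set (for PASCAL's 20 classes, range(1, 21)), and, assuming your masks store raw COCO category ids (1 to 90 with gaps), remaps them to contiguous indices 1 to 80. Whether masks prepared following PFENet already use contiguous indices is not stated here, so inspect the values before remapping. All function names and the "<image path> <label path>" list-line format are assumptions.

```python
import os

import numpy as np

# The 80 category ids used by the COCO 2014/2017 annotations: 1..90 with ten gaps.
COCO_IDS = (list(range(1, 12)) + list(range(13, 26)) + [27, 28] + list(range(31, 45))
            + list(range(46, 66)) + [67, 70] + list(range(72, 83)) + list(range(84, 91)))


def unexpected_labels(mask, allowed, ignore=(0, 255)):
    """Sorted label values in `mask` that are neither allowed classes nor ignored values."""
    values = set(np.unique(mask).tolist())
    return sorted(values - set(allowed) - set(ignore))


def remap_coco_mask(mask):
    """Map raw COCO category ids (uint8 mask) to contiguous indices 1..80.

    Background (0) and the ignore label (255) are kept; ids not in COCO_IDS become 0.
    """
    lut = np.zeros(256, dtype=np.uint8)
    lut[255] = 255
    lut[COCO_IDS] = np.arange(1, len(COCO_IDS) + 1)
    return lut[mask]


def scan_list(data_root, list_file, allowed):
    """Report unreadable label files and unexpected label values for a data list."""
    import cv2  # imported here so the helpers above work without OpenCV

    problems = {}
    with open(list_file) as f:
        for line in f:
            parts = line.split()
            if len(parts) < 2:
                continue
            label_path = os.path.join(data_root, parts[1])
            mask = cv2.imread(label_path, cv2.IMREAD_GRAYSCALE)
            if mask is None:  # missing or corrupt file
                problems[label_path] = "unreadable"
            else:
                bad = unexpected_labels(mask, allowed)
                if bad:
                    problems[label_path] = bad
    return problems
```

For example, scan_list("pascal/", "lists/pascal/train.txt", range(1, 21)) would list every PASCAL label file that either cannot be read or contains values such as 147.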
Owner
Lucas
A PhD student on Computer Vision.