Official code for "Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer. ICCV2021".



We propose a novel model-training paradigm for few-shot semantic segmentation. Instead of meta-learning the whole, complex segmentation model, we focus on its simplest part, the classifier, to make new-class adaptation more tractable. We also introduce a novel meta-learning algorithm that leverages a Classifier Weight Transformer (CWT) to dynamically adapt the classifier weights to every query sample, in order to eliminate the impact of intra-class discrepancy.
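As a concrete illustration of the idea, here is a simplified, single-head sketch of the kind of update a classifier weight transformer performs. The classifier weights act as attention queries over the query image's features, and the attended result is added back to the weights before the query pixels are classified. It is a NumPy sketch under stated assumptions (one head, no layer normalization or output projection, and plain matrices standing in for the learned projections w_q, w_k and w_v), not this repository's implementation.

```python
import numpy as np


def softmax(x, axis=-1):
    # Numerically stable softmax along `axis`
    x = x - x.max(axis=axis, keepdims=True)
    e = np.exp(x)
    return e / e.sum(axis=axis, keepdims=True)


def adapt_classifier_weights(weights, query_feats, w_q, w_k, w_v):
    """Single-head cross-attention update of classifier weights (simplified sketch).

    weights:       (n_cls, d) classifier weights learned from the support set
    query_feats:   (hw, d) flattened features of one query image
    w_q, w_k, w_v: (d, d) projection matrices (assumed learned during meta-training)
    Returns (n_cls, d) weights adapted to this query through a residual update.
    """
    q = weights @ w_q        # classifier weights act as attention queries
    k = query_feats @ w_k    # query-image features act as keys ...
    v = query_feats @ w_v    # ... and as values
    attn = softmax(q @ k.T / np.sqrt(k.shape[1]), axis=-1)  # (n_cls, hw)
    return weights + attn @ v


def segment(weights, query_feats):
    # Per-pixel logits: every pixel feature is scored against every class weight
    return query_feats @ weights.T  # (hw, n_cls)
```

The residual form means that if the value projection contributes nothing, the support-set classifier is used unchanged, so the transformer only has to learn a per-query correction.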



Other configurations can also work, but the results may be slightly different.

  • torch==1.6.0
  • numpy==1.19.1
  • cv2==4.4.0
  • pyyaml==5.3.1


To download and process the datasets, we follow the same rules as in: After processing, please change "data_root" and "train_list"/"val_list" in the config files accordingly.

Pre-trained models in the first stage

For convenience, we provide the models pre-trained on the base classes for each split. Download them here: and point "resume_weights" to the downloaded folder.
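The config edits described above ("data_root", "train_list", "val_list" and "resume_weights") can also be scripted. The sketch below is a hypothetical helper, not part of this repository. It assumes the config files are YAML (the issues below refer to coco.yaml) and that these keys may sit inside nested sections, so it searches the whole mapping and fails loudly if a key is absent.

```python
import yaml  # pyyaml, listed in the requirements above


def set_config_keys(node, updates):
    """Recursively overwrite any key in `updates` wherever it appears in a
    (possibly nested) YAML structure. Returns the set of keys that were found."""
    found = set()
    if isinstance(node, dict):
        for key, value in node.items():
            if key in updates:
                node[key] = updates[key]  # replaces the value; size of dict unchanged
                found.add(key)
            else:
                found |= set_config_keys(value, updates)
    elif isinstance(node, list):
        for item in node:
            found |= set_config_keys(item, updates)
    return found


def update_config_file(path, updates):
    """Load a YAML config, apply `updates`, and write it back in place."""
    with open(path) as f:
        cfg = yaml.safe_load(f)
    found = set_config_keys(cfg, updates)
    missing = set(updates) - found
    if missing:
        raise KeyError(f"keys not present in {path}: {sorted(missing)}")
    with open(path, "w") as f:
        yaml.safe_dump(cfg, f, sort_keys=False)  # keep the original key order
    return cfg
```

For example, update_config_file(CONFIG_PATH, {"data_root": ..., "train_list": ..., "val_list": ..., "resume_weights": ...}) with CONFIG_PATH being a placeholder for your config file. Note that yaml.safe_dump does not preserve comments, so keep a copy of the original file if its comments matter.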

Episodic training and inference

  • The general training script
sh scripts/ {data} {split} {[gpu_ids]} {layers} {shots}
  • This is an example with 1-shot, ResNet-50, split-0 on PASCAL and GPU device [0].
sh scripts/ pascal 0 [0] 50 1
  • Inference script
sh scripts/ {data} {shot} {[gpu_ids]} {layers} {split}
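The training and inference scripts take the same five arguments, but split and shot swap places. Training expects the split second and the shots last, while inference expects the shot second and the split last. The hypothetical helpers below keep the two orders straight when runs are launched programmatically. They are a sketch, not part of the repository, and since the script file names are not given above, the usage note relies on placeholders.

```python
def format_gpu_ids(gpu_ids):
    """Render GPU ids as a bracketed, space-free list such as [0,1]."""
    return "[" + ",".join(str(int(g)) for g in gpu_ids) + "]"


def train_args(data, split, gpu_ids, layers, shots):
    # Training order: {data} {split} {[gpu_ids]} {layers} {shots}
    return [data, str(split), format_gpu_ids(gpu_ids), str(layers), str(shots)]


def inference_args(data, shot, gpu_ids, layers, split):
    # Inference order: {data} {shot} {[gpu_ids]} {layers} {split}
    return [data, str(shot), format_gpu_ids(gpu_ids), str(layers), str(split)]
```

For example, subprocess.run(["sh", TRAIN_SCRIPT, *train_args("pascal", 0, [0], 50, 1)], check=True) reproduces the 1-shot, ResNet-50, split-0 PASCAL example, with TRAIN_SCRIPT standing in for the training script's path. The GPU list is written without spaces on purpose: an unquoted [0, 1] would be split by the shell into two separate arguments.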


Please open an issue or contact me via [at] if you have any questions.


If you find this work helpful, please cite it. We will update this entry once it is officially published at ICCV.

      title={Simpler is Better: Few-shot Semantic Segmentation with Classifier Weight Transformer}, 
      author={Zhihe Lu and Sen He and Xiatian Zhu and Li Zhang and Yi-Zhe Song and Tao Xiang},


Thanks to the code contributors. Some parts of the code are borrowed from and

  • Script for Segmentation Map Visualization

    If you want to visualize segmentation results at the end of each test iteration, please use the following.

    Insert the following into the main script (below classes.append(...)):

                    logits_q[i] = pred_q.detach()
                    gt_q[i, 0] = q_label
                    classes.append([class_.item() for class_ in subcls])
                    # Insert visualization routine here 
                    if args.visualize:
                        output = {}
                        output['query'], output['support'] = {}, {}
                        pred_q_up = F.interpolate(pred_q, size=q_label.size()[1:], mode='bilinear', align_corners=True)
                        output['query']['gt'], output['query']['pred'] = vis_res(qry_oris[0][0], qry_oris[1], pred_q_up.squeeze().detach().cpu().numpy())
                        spprt_label =[1], 0)
                        output['support']['gt'], output['support']['pred'] = vis_res(spprt_oris[0][0][0],spprt_label, output_support.squeeze().detach().cpu().numpy())
                        save_image = np.concatenate((output['support']['gt'], output['query']['gt'], output['query']['pred']), 1)
                        cv2.imwrite('./analysis/' + qry_oris[0][0].split('/')[-1], save_image)

    The main visualization function, vis_res, and its helper resize_image_label are as follows:

    import cv2
    import numpy as np

    def resize_image_label(image, label, size=473):
        def find_new_hw(ori_h, ori_w, test_size):
            # Scale the longer side to test_size and the shorter side proportionally
            if ori_h >= ori_w:
                ratio = test_size * 1.0 / ori_h
                new_h = test_size
                new_w = int(ori_w * ratio)
            else:
                ratio = test_size * 1.0 / ori_w
                new_h = int(ori_h * ratio)
                new_w = test_size
            # Round both sides down to a multiple of 8
            new_h = (new_h // 8) * 8
            new_w = (new_w // 8) * 8
            return new_h, new_w
        # Step 1: resize while keeping the h/w ratio. The longer side (height or width) is resized to `size`;
        #         the other side is scaled accordingly.
        test_size = size
        new_h, new_w = find_new_hw(image.shape[0], image.shape[1], test_size)
        image_crop = cv2.resize(image, dsize=(int(new_w), int(new_h)),
                                interpolation=cv2.INTER_LINEAR)
        # Step 2: zero-pad the resized image to a (size, size) image
        back_crop = np.zeros((test_size, test_size, 3))
        back_crop[:new_h, :new_w, :] = image_crop
        image = back_crop
        # Step 3: do the same for the label, padding with 255 (the ignore label)
        s_mask = label
        new_h, new_w = find_new_hw(s_mask.shape[0], s_mask.shape[1], test_size)
        # Nearest-neighbour interpolation keeps label values discrete (0, 1, 255)
        s_mask = cv2.resize(s_mask.astype(np.float32), dsize=(int(new_w), int(new_h)),
                            interpolation=cv2.INTER_NEAREST)
        back_crop_s_mask = np.ones((test_size, test_size)) * 255
        back_crop_s_mask[:new_h, :new_w] = s_mask
        label = back_crop_s_mask
        return image, label
    def vis_res(image_path, label, pred):
        import cv2
        def read_image(path):
            image = cv2.imread(path, cv2.IMREAD_COLOR)
            image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
            image = np.float32(image)
            return image
        def label_to_image(label):
            label = label == 1.
            label = np.float32(label) * 255.
            placeholder = np.zeros_like(label)
            label = np.concatenate((label, placeholder), 0)
            label = np.concatenate((label, placeholder), 0)
            label = np.transpose(label, (1,2,0))
            return label
        def blend_image_label(image, label):
            result = 0.5 * image + 0.5 * label
            result = np.float32(result)
            result = cv2.cvtColor(result, cv2.COLOR_BGR2RGB)
            return result
        def pred_to_image(label):
            label = np.float32(label) * 255.
            placeholder = np.zeros_like(label)
            placeholder = np.concatenate((placeholder, placeholder), 0)
            label = np.concatenate((placeholder, label), 0)
            label = np.transpose(label, (1,2,0))
            return label
        image = read_image(image_path)
        label = label.squeeze().detach().cpu().numpy()
        image, label = resize_image_label(image, label)
        label = label_to_image(np.expand_dims(label, 0))
        out_image_gt = blend_image_label(image, label)
        #cv2.imwrite('./analysis/' + image_path.split('/')[-1][:-4] + '_gt.jpg', out_image_gt)
        pred = np.argmax(pred, 0)
        pred = np.expand_dims(pred, 0)
        pred = pred_to_image(pred)
        out_image_pred = blend_image_label(image, pred)
        #cv2.imwrite('./analysis/' + image_path.split('/')[-1][:-4] + '_pred.jpg', out_image_pred)
        return out_image_gt, out_image_pred
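The colour convention of vis_res for a binary task can be made explicit with a NumPy-only sketch. The helper below is hypothetical, not from the repository, and reproduces what label_to_image or pred_to_image followed by blend_image_label compute. The mask is written into one RGB channel (channel 0 for ground truth, channel 2 for the prediction), blended 50/50 with the image, and the channel order is reversed to BGR for cv2.imwrite. In the saved file, ground truth therefore appears red and predictions blue. The snippet above then concatenates support GT, query GT and query prediction side by side.

```python
import numpy as np

GT_CHANNEL = 0    # label_to_image fills channel 0 (red in RGB)
PRED_CHANNEL = 2  # pred_to_image fills channel 2 (blue in RGB)


def overlay_binary_mask(image_rgb, mask, channel):
    """Blend a binary mask into one RGB channel at 50% opacity and return a
    BGR float32 image suitable for cv2.imwrite. Pixels equal to 1 are marked;
    0 and the 255 padding are left unmarked."""
    color = np.zeros(image_rgb.shape, dtype=np.float32)
    color[..., channel] = (mask == 1).astype(np.float32) * 255.0
    blended = 0.5 * image_rgb.astype(np.float32) + 0.5 * color
    # Reversing the channel axis is the same swap cv2.COLOR_BGR2RGB performs
    return np.ascontiguousarray(blended[..., ::-1])
```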
  • Error run training model

    I ran the model for training but got this problem: [ WARN] global /io/opencv/modules/imgcodecs/src/loadsave.cpp (239) findDecoder imread_('COCO/train/COCO_train2014_000000494112.png'): can't open/read file: check file path/integrity

    Do you have a way to solve it? Please help me.

    opened by viethoang303 1
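This OpenCV warning means cv2.imread was given a path it could not open. cv2.imread does not raise in that case; it logs the warning and returns None, which usually points to a wrong data_root or a list file that does not match the files on disk. A quick sanity check before training is to verify every path in the list file. The hypothetical helper below assumes each line of the list holds an image path and a label path separated by whitespace, both relative to data_root. Adjust it if your list files differ.

```python
import os


def find_missing_files(data_root, list_file):
    """Return (line_number, full_path) for every path listed in `list_file`
    that does not exist under `data_root`."""
    missing = []
    with open(list_file) as f:
        for line_no, line in enumerate(f, start=1):
            for rel_path in line.split():  # assumed '<image> <label>' per line
                full_path = os.path.join(data_root, rel_path)
                if not os.path.isfile(full_path):
                    missing.append((line_no, full_path))
    return missing
```

An empty result for both the train and val lists rules out missing files as the cause.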
  • pre-trained model

    Thanks for releasing your code. However, neither the paper nor the code gives a detailed description of how the pre-trained model is trained. Did you use classification or segmentation as the pre-training task? If segmentation, how did you construct the image-label pairs?

    opened by LIUYUANWEI98 1
  • how to generate SegmentationClassAug?

    I have downloaded the PASCAL VOC 2012 dataset and the SBD dataset from their respective official websites, but I don't know how to use them in the code. Could you show me your dataset folder structure or the method to generate SegmentationClassAug? Thanks, bubble from hfut!

    opened by WHL182 1
  • Pascal SegmentationClassAug dataset

    Thank you for your work! I have a question about the PASCAL dataset. I have downloaded the official PASCAL dataset, but I didn't find the SegmentationClassAug dataset used in the code. How can I get it? Thanks!

    opened by tt622 1
  • The pre-trained model

    Thanks for your nice work! I used your default parameters to train and test on the val set of VOC2007, but I only got 15% mIoU. Having read your code, it seems this is caused by the pre-trained model (ResNet-50, ResNet-101, or the transformer model). In the code, it is set as "resume_weights: /pretrained_models/".

    Would you please give me some suggestions?

    Best wishes to you

    opened by zwy1996 0
  • Can not achieve the results in the article

    When I finished training, I found that I could not achieve the results reported in the paper, especially on the COCO dataset. Can you tell me more details about the training?

    opened by 7M7L 0
  • COCO dataset!!! Can not load coco dataset for mask image

    When I read the mask images, I found that the categories in the loaded masks do not correspond to the categories of the COCO dataset. How can I solve this problem? Thanks


    opened by chenhao-zju 0
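One common source of such mismatches, offered here as a hypothesis rather than a confirmed answer, is that the official COCO annotations use category ids between 1 and 90 for only 80 categories. Masks rendered directly from the annotations therefore carry non-contiguous ids, while a loader that expects ids 1 to 80 (like the assertion in the last issue on this page) needs a contiguous mapping. The sketch below is hypothetical and not from this repository. It builds such a mapping and applies it to a uint8 mask with a lookup table.

```python
import numpy as np


def contiguous_category_map(category_ids):
    """Map (possibly non-contiguous) category ids to 1..N in sorted order."""
    return {cat_id: new_id for new_id, cat_id in enumerate(sorted(category_ids), start=1)}


def remap_mask(mask, id_map, ignore_label=255):
    """Relabel a uint8 mask: 0 stays background, ids in `id_map` are remapped,
    and anything else (including ignore_label) becomes ignore_label."""
    lut = np.full(256, ignore_label, dtype=np.uint8)
    lut[0] = 0
    for src, dst in id_map.items():
        lut[src] = dst
    return lut[mask]
```

With pycocotools, the id list would typically come from COCO.getCatIds() on the loaded annotation file.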
  • Difference between transformer_resnet50 and pspnet_resnet50

    Hi, I am trying to reproduce the results in your paper, but I have encountered some difficulties.

    I downloaded the COCO dataset and set data_root in coco.yaml to its folder. I launched a training run, which produced a transformer_resnet50 model. The test results were poor, around 0.13 mIoU with 5 shots on split 0.

    I downloaded the pre-trained models. The folder contains pspnet_resnet50 models. I changed resume_weights in coco.yaml accordingly. If I try to run the test with these models, I get the following error:

    sh scripts/ coco 5 [2] 50 0
    ==> Running DDP checkpoint example on rank 0.
    => no weight found at '/home/bacchin/CWT/CWT_venv/CWT-for-FSS/model_ckpt/coco/split=0/model/shot_5/transformer_resnet50/best.pth'
    Traceback (most recent call last):
      File "/usr/lib/python3.6/", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/usr/lib/python3.6/", line 85, in _run_code
        exec(code, run_globals)
      File "/home/bacchin/CWT/CWT_venv/CWT-for-FSS/src/", line 302, in <module>
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/", line 188, in start_processes
        while not context.join():
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/", line 150, in join
        raise ProcessRaisedException(msg, error_index,
    -- Process 0 terminated with the following error:
    Traceback (most recent call last):
      File "/home/bacchin/CWT/CWT_venv/lib/python3.6/site-packages/torch/multiprocessing/", line 59, in _wrap
        fn(i, *args)
      File "/home/bacchin/CWT/CWT_venv/CWT-for-FSS/src/", line 93, in main_worker
        assert os.path.isfile(filepath), filepath
    AssertionError: model_ckpt/coco/split=0/model/shot_5/transformer_resnet50/best.pth

    I understood that model_dir in coco.yaml must point to a transformer_resnet50 model. I tried pointing model_dir to the transformer_resnet50 model obtained from my training. It worked, but the results are still below the performance reported in the paper (around 0.3 with 5 shots, split 0).

    The transformer_resnet50 models are not provided with the pre-trained files. Why? And why do we need two pointers to models, namely model_dir and resume_weights?

    Am I missing something?

    Thank you for your help!

    opened by bach05 1
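The traceback above shows that, at test time, main_worker asserts os.path.isfile on a transformer checkpoint under model_dir, while resume_weights points at the first-stage weights (the pspnet_resnet50 models). The expected location can be reconstructed from that error message. The hypothetical helper below rebuilds the pattern; it is inferred only from this log (ckpt_used: best also appears in the config dump of the next issue), not from the source code, so other settings may change it.

```python
import os


def expected_checkpoint_path(model_dir, data_name, split, shot, layers, ckpt_used="best"):
    """Rebuild the checkpoint path pattern seen in the traceback above,
    e.g. model_ckpt/coco/split=0/model/shot_5/transformer_resnet50/best.pth."""
    return os.path.join(
        model_dir, data_name, f"split={split}", "model", f"shot_{shot}",
        f"transformer_resnet{layers}", f"{ckpt_used}.pth",
    )
```

Calling os.path.isfile(expected_checkpoint_path("model_ckpt", "coco", 0, 5, 50)) before launching inference shows whether the assertion will pass.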
  • RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)

    When I was training on a single machine with multiple GPUs, I encountered the following error. What is the reason?

    (base) lixiang@vs008:~/CWT-for-FSS$ sh scripts/ pascal 0 [0,1] 50 1
      0%|                                                                                       | 0/5953 [00:00<?, ?it/s]==> Running process rank 0.
    FB_param_noise: 0
    adapt_iter: 200
    arch: resnet
    augmentations: ['hor_flip', 'vert_flip', 'resize']
    backbone_dim: 2048
    batch_size: 2
    batch_size_val: 2
    bins: [1, 2, 3, 6]
    bottleneck_dim: 512
    ckpt_path: checkpoints/
    ckpt_used: best
    cls_lr: 0.1
    data_root: pascal/
    debug: False
    distributed: True
    dropout: 0.1
    episodic: True
    epochs: 20
    gamma: 0.1
    gpus: [0, 1]
    heads: 4
    image_size: 473
    iter_per_epoch: 6000
    layers: 50
    log_freq: 50
    lr_stepsize: 30
    m_scale: False
    main_optim: SGD
    manual_seed: 2021
    mean: [0.485, 0.456, 0.406]
    milestones: [40, 70]
    mixup: False
    model_dir: model_ckpt
    momentum: 0.9
    n_runs: 1
    nesterov: True
    norm_feat: True
    num_classes_tr: 2
    num_classes_val: 5
    padding_label: 255
    port: 53765
    pretrained: False
    random_shot: False
    resume_weights: /pretrained_models/
    rot_max: 10
    rot_min: -10
    save_models: True
    save_oracle: False
    scale_lr: 1.0
    scale_max: 2.0
    scale_min: 0.5
    scheduler: cosine
    shot: 1
    smoothing: True
    std: [0.229, 0.224, 0.225]
    test_name: default
    test_num: 1000
    test_split: default
    train_list: lists/pascal/train.txt
    train_name: pascal
    train_split: 0
    trans_lr: 0.001
    use_split_coco: False
    val_list: lists/pascal/val.txt
    weight_decay: 0.0001
    workers: 2
    => no weight found at '/pretrained_models/'
    Processing data for [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
      0%|                                                                                       | 0/5953 [00:00<?, ?it/s]==> Running process rank 1.
    FB_param_noise: 0
    adapt_iter: 200
    arch: resnet
    augmentations: ['hor_flip', 'vert_flip', 'resize']
    backbone_dim: 2048
    batch_size: 2
    batch_size_val: 2
    bins: [1, 2, 3, 6]
    bottleneck_dim: 512
    ckpt_path: checkpoints/
    ckpt_used: best
    cls_lr: 0.1
    data_root: pascal/
    debug: False
    distributed: True
    dropout: 0.1
    episodic: True
    epochs: 20
    gamma: 0.1
    gpus: [0, 1]
    heads: 4
    image_size: 473
    iter_per_epoch: 6000
    layers: 50
    log_freq: 50
    lr_stepsize: 30
    m_scale: False
    main_optim: SGD
    manual_seed: 2021
    mean: [0.485, 0.456, 0.406]
    milestones: [40, 70]
    mixup: False
    model_dir: model_ckpt
    momentum: 0.9
    n_runs: 1
    nesterov: True
    norm_feat: True
    num_classes_tr: 2
    num_classes_val: 5
    padding_label: 255
    port: 53765
    pretrained: False
    random_shot: False
    resume_weights: /pretrained_models/
    rot_max: 10
    rot_min: -10
    save_models: True
    save_oracle: False
    scale_lr: 1.0
    scale_max: 2.0
    scale_min: 0.5
    scheduler: cosine
    shot: 1
    smoothing: True
    std: [0.229, 0.224, 0.225]
    test_name: default
    test_num: 1000
    test_split: default
    train_list: lists/pascal/train.txt
    train_name: pascal
    train_split: 0
    trans_lr: 0.001
    use_split_coco: False
    val_list: lists/pascal/val.txt
    weight_decay: 0.0001
    workers: 2
    => no weight found at '/pretrained_models/'
    Processing data for [6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18, 19, 20]
    100%|███████████████████████████████████████████████████████████████████████████| 5953/5953 [00:08<00:00, 681.75it/s]
    100%|███████████████████████████████████████████████████████████████████████████| 5953/5953 [00:09<00:00, 609.93it/s]
      0%|                                                                                       | 0/1449 [00:00<?, ?it/s]INFO: pascal -> pascal
    INFO: 0 -> 0
    >> Start Filtering classes 
    >> Removed classes = [] 
    >> Kept classes = ['airplane', 'bicycle', 'bird', 'boat', 'bottle'] 
    Processing data for [1, 2, 3, 4, 5]
      0%|                                                                                       | 0/1449 [00:00<?, ?it/s]INFO: pascal -> pascal
    INFO: 0 -> 0
    >> Start Filtering classes 
    >> Removed classes = [] 
    >> Kept classes = ['airplane', 'bicycle', 'bird', 'boat', 'bottle'] 
    Processing data for [1, 2, 3, 4, 5]
    100%|███████████████████████████████████████████████████████████████████████████| 1449/1449 [00:06<00:00, 229.57it/s]
    100%|███████████████████████████████████████████████████████████████████████████| 1449/1449 [00:05<00:00, 241.58it/s]
    Traceback (most recent call last):
      File "/home/lixiang/anaconda3/lib/python3.8/", line 194, in _run_module_as_main
        return _run_code(code, main_globals, None,
      File "/home/lixiang/anaconda3/lib/python3.8/", line 87, in _run_code
        exec(code, run_globals)
      File "/home/lixiang/CWT-for-FSS/src/", line 360, in <module>
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/", line 188, in start_processes
        while not context.join():
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/", line 150, in join
        raise ProcessRaisedException(msg, error_index,
    -- Process 1 terminated with the following error:
    Traceback (most recent call last):
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/multiprocessing/", line 59, in _wrap
        fn(i, *args)
      File "/home/lixiang/CWT-for-FSS/src/", line 134, in main_worker
        _, _ = do_epoch(
      File "/home/lixiang/CWT-for-FSS/src/", line 266, in do_epoch
        output_support = binary_cls(f_s)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/", line 1102, in _call_impl
        return forward_call(*input, **kwargs)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/", line 446, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/home/lixiang/anaconda3/lib/python3.8/site-packages/torch/nn/modules/", line 442, in _conv_forward
        return F.conv2d(input, weight, bias, self.stride,
    RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:1 and cuda:0! (when checking argument for argument weight in method wrapper__cudnn_convolution)
    opened by lixiang007666 1
  • Error while training the resnet model using pascal dataset

    Traceback (most recent call last):
      File "/usr/lib/python3.7/", line 193, in _run_module_as_main
        "__main__", mod_spec)
      File "/usr/lib/python3.7/", line 85, in _run_code
        exec(code, run_globals)
      File "/content/CWT-for-FSS/src/", line 360, in <module>
        mp.spawn(main_worker, args=(world_size, args), nprocs=world_size, join=True)
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/", line 188, in start_processes
        while not context.join():
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/", line 150, in join
        raise ProcessRaisedException(msg, error_index,
    -- Process 0 terminated with the following error:
    Traceback (most recent call last):
      File "/usr/lib/python3.7/multiprocessing/", line 121, in worker
        result = (True, func(*args, **kwds))
      File "/usr/lib/python3.7/multiprocessing/", line 44, in mapstar
        return list(map(*args))
      File "/content/CWT-for-FSS/src/dataset/", line 91, in process_image
        assert label_class_ in list(range(1, 81)), label_class_
    AssertionError: 147
    The above exception was the direct cause of the following exception:
    Traceback (most recent call last):
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/", line 59, in _wrap
        fn(i, *args)
      File "/content/CWT-for-FSS/src/", line 117, in main_worker
        train_loader, train_sampler = get_train_loader(args)
      File "/content/CWT-for-FSS/src/dataset/", line 44, in get_train_loader
        mode_train=True, transform=train_transform, class_list=class_list, args=args
      File "/content/CWT-for-FSS/src/dataset/", line 114, in __init__
        self.data_list, self.sub_class_file_list = make_dataset(args.data_root, args.train_list, self.class_list)
      File "/content/CWT-for-FSS/src/dataset/", line 55, in make_dataset
        for sublist, subdict in mmap_(process_partial, tqdm(list_read)):
      File "/content/CWT-for-FSS/src/dataset/", line 17, in mmap_
        return Pool().map(fn, iter)
      File "/usr/lib/python3.7/multiprocessing/", line 268, in map
        return self._map_async(func, iterable, mapstar, chunksize).get()
      File "/usr/lib/python3.7/multiprocessing/", line 657, in get
        raise self._value
    AssertionError: 147
    opened by Hemanth-Gattu 11
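The last traceback fails because a label PNG contains the pixel value 147, which lies outside the class ids this assertion accepts (1 to 80). One plausible cause, not confirmed in the thread, is using colour-palette masks (such as PASCAL's SegmentationClass) instead of index masks such as SegmentationClassAug. If the loader reads labels as grayscale (an assumption about this repository), palette colours turn into luminance values. For instance, PASCAL's "person" colour (192, 128, 128) maps to 0.299·192 + 0.587·128 + 0.114·128 ≈ 147 under OpenCV's conversion. The hypothetical helper below lists pixel values in a label that are neither background, a valid class id, nor the ignore label.

```python
import numpy as np


def unexpected_label_values(label, num_classes, ignore_label=255):
    """Return the sorted pixel values in `label` that are not background (0),
    not a class id in 1..num_classes, and not the ignore label."""
    allowed = set(range(num_classes + 1)) | {ignore_label}
    return [int(v) for v in np.unique(label) if int(v) not in allowed]
```

For example, running unexpected_label_values(cv2.imread(path, cv2.IMREAD_GRAYSCALE), num_classes=20) over the PASCAL label files should return an empty list for every file if the masks are index masks.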
