PyTorch implementation of SimSiam: Exploring Simple Siamese Representation Learning

Overview

This is a PyTorch implementation of the SimSiam paper:

@Article{chen2020simsiam,
  author  = {Xinlei Chen and Kaiming He},
  title   = {Exploring Simple Siamese Representation Learning},
  journal = {arXiv preprint arXiv:2011.10566},
  year    = {2020},
}

Preparation

Install PyTorch and download the ImageNet dataset following the official PyTorch ImageNet training code. As with MoCo, this release makes only minimal modifications to that code for both unsupervised pre-training and linear classification.

In addition, install apex for the LARS implementation needed for linear classification.

Unsupervised Pre-Training

Only multi-gpu, DistributedDataParallel training is supported; single-gpu or DataParallel training is not supported.

To do unsupervised pre-training of a ResNet-50 model on ImageNet in an 8-gpu machine, run:

python main_simsiam.py \
  -a resnet50 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --fix-pred-lr \
  [your imagenet-folder with train and val folders]

The script uses the default hyper-parameters described in the paper and the default augmentation recipe from MoCo v2.
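
For reference, a minimal sketch of that MoCo v2 recipe (simsiam.loader in this repo provides GaussianBlur and a TwoCropsTransform that applies the pipeline twice to produce the two views; this is an illustration, not the verbatim repo code):

import simsiam.loader
import torchvision.transforms as transforms

normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                 std=[0.229, 0.224, 0.225])
# MoCo v2's augmentation, similar to SimCLR
augmentation = [
    transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
    transforms.RandomApply([transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)], p=0.8),
    transforms.RandomGrayscale(p=0.2),
    transforms.RandomApply([simsiam.loader.GaussianBlur([0.1, 2.0])], p=0.5),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    normalize,
]
# TwoCropsTransform applies the composed pipeline twice per image
two_crop = simsiam.loader.TwoCropsTransform(transforms.Compose(augmentation))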

The above command performs pre-training with a non-decaying predictor learning rate for 100 epochs, corresponding to the last row of Table 1 in the paper.
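
For orientation, the core of SimSiam is a symmetrized negative cosine similarity with a stop-gradient on the target branch. A minimal, self-contained sketch of the objective (random tensors stand in for the predictor outputs p1, p2 and projector outputs z1, z2 of the real model):

import torch
import torch.nn as nn

criterion = nn.CosineSimilarity(dim=1)

def simsiam_loss(p1, p2, z1, z2):
    # detach() is the stop-gradient of the paper: D(p, stopgrad(z))
    return -(criterion(p1, z2.detach()).mean() +
             criterion(p2, z1.detach()).mean()) * 0.5

# toy usage: a batch of 8 feature vectors of dimension 2048
p1, p2, z1, z2 = (torch.randn(8, 2048) for _ in range(4))
print(simsiam_loss(p1, p2, z1, z2))  # training drives this toward -1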

Linear Classification

With a pre-trained model, to train a supervised linear classifier on frozen features/weights in an 8-gpu machine, run:

python main_lincls.py \
  -a resnet50 \
  --dist-url 'tcp://localhost:10001' --multiprocessing-distributed --world-size 1 --rank 0 \
  --pretrained [your checkpoint path]/checkpoint_0099.pth.tar \
  --lars \
  [your imagenet-folder with train and val folders]

The above command uses the LARS optimizer and a default batch size of 4096.
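
A sketch of how the --lars flag is typically realized with apex: LARC wraps a base SGD optimizer. The model and learning-rate value below are placeholders, not the repo's exact settings:

import torch
import torchvision.models as models
from apex.parallel.LARC import LARC  # requires NVIDIA apex

model = models.resnet50()
init_lr = 0.1  # placeholder; scale by batch size / 256 per the linear scaling rule
base_optimizer = torch.optim.SGD(model.parameters(), lr=init_lr,
                                 momentum=0.9, weight_decay=0.0)
optimizer = LARC(optimizer=base_optimizer, trust_coefficient=0.001, clip=False)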

Models and Logs

Our pre-trained ResNet-50 models and logs:

pre-train epochs | batch size | pre-train ckpt | pre-train log | linear cls. ckpt | linear cls. log | top-1 acc.
100              | 512        | link           | link          | link             | link            | 68.1
100              | 256        | link           | link          | link             | link            | 68.3

Settings for the above: 8 NVIDIA V100 GPUs, CUDA 10.1/CuDNN 7.6.5, PyTorch 1.7.0.

Transferring to Object Detection

The object detection transfer is the same as for MoCo; please see moco/detection.

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Comments
  • Checkpoint with 200 epochs?

    Hello, can you provide the checkpoint with 200 epochs and batch size 256 on ImageNet-1k? It would be better if the checkpoint included the last fc and predictor in addition to the backbone.

    opened by lzyhha 7
  • Failing on Reproducing ImageNet Linear Classification Results with SGD

    Hello,

    I'm trying to reproduce the ImageNet results with SGD in a DDP setting with 8 GPUs: batch size 256, learning rate 30, weight decay 0, momentum 0.9, 100 epochs. According to the paper, linear classification with SGD yields about 1% lower accuracy, which would be 68.1 - 1 ≈ 67.1%. However, with the provided bs-256 pre-training checkpoint I can only get to Acc@1 65.080 Acc@5 86.696. Any idea what I could do to match the described performance? Hints on how to set the hyperparameters? I am using the original code from this repo.

    Here's the output of the last validation batches:

    Epoch: [99][4980/5005]  Time  0.401 ( 0.217)    Data  0.362 ( 0.049)    Loss 1.6190e+00 (1.5670e+00)    Acc@1  56.25 ( 63.85)   Acc@5  87.50 ( 84.66)
    Epoch: [99][4990/5005]  Time  0.060 ( 0.217)    Data  0.012 ( 0.049)    Loss 1.8884e+00 (1.5670e+00)    Acc@1  56.25 ( 63.85)   Acc@5  87.50 ( 84.66)
    Epoch: [99][5000/5005]  Time  0.057 ( 0.217)    Data  0.017 ( 0.050)    Loss 1.1361e+00 (1.5666e+00)    Acc@1  68.75 ( 63.86)   Acc@5  90.62 ( 84.66)
    Test: [  0/196] Time 12.220 (12.220)    Loss 8.8439e-01 (8.8439e-01)    Acc@1  77.34 ( 77.34)   Acc@5  94.92 ( 94.92)
    Test: [ 10/196] Time  0.276 ( 2.376)    Loss 1.3778e+00 (1.0927e+00)    Acc@1  64.06 ( 72.16)   Acc@5  89.84 ( 91.73)
    Test: [ 20/196] Time  0.302 ( 1.952)    Loss 1.1767e+00 (1.0888e+00)    Acc@1  78.12 ( 73.05)   Acc@5  87.50 ( 91.15)
    Test: [ 30/196] Time  0.274 ( 1.867)    Loss 1.2236e+00 (1.0760e+00)    Acc@1  67.97 ( 73.44)   Acc@5  91.80 ( 91.33)
    Test: [ 40/196] Time  0.280 ( 1.672)    Loss 1.2726e+00 (1.1910e+00)    Acc@1  67.58 ( 69.84)   Acc@5  92.97 ( 90.62)
    Test: [ 50/196] Time  0.274 ( 1.728)    Loss 8.7860e-01 (1.1958e+00)    Acc@1  77.34 ( 69.55)   Acc@5  94.92 ( 90.83)
    Test: [ 60/196] Time  0.275 ( 1.664)    Loss 1.4648e+00 (1.1892e+00)    Acc@1  64.84 ( 69.67)   Acc@5  88.28 ( 91.12)
    Test: [ 70/196] Time  0.312 ( 1.662)    Loss 1.0541e+00 (1.1579e+00)    Acc@1  73.05 ( 70.44)   Acc@5  91.80 ( 91.43)
    Test: [ 80/196] Time  0.274 ( 1.663)    Loss 1.9151e+00 (1.1739e+00)    Acc@1  51.17 ( 70.04)   Acc@5  80.47 ( 91.08)
    Test: [ 90/196] Time  0.285 ( 1.715)    Loss 2.3642e+00 (1.2366e+00)    Acc@1  47.27 ( 68.99)   Acc@5  74.22 ( 90.19)
    Test: [100/196] Time  0.279 ( 1.648)    Loss 2.1153e+00 (1.2950e+00)    Acc@1  51.17 ( 67.87)   Acc@5  75.78 ( 89.36)
    Test: [110/196] Time  0.274 ( 1.647)    Loss 1.2629e+00 (1.3154e+00)    Acc@1  70.70 ( 67.53)   Acc@5  88.67 ( 89.03)
    Test: [120/196] Time  0.286 ( 1.624)    Loss 1.7559e+00 (1.3332e+00)    Acc@1  64.06 ( 67.34)   Acc@5  81.64 ( 88.72)
    Test: [130/196] Time  0.279 ( 1.634)    Loss 1.2278e+00 (1.3690e+00)    Acc@1  70.70 ( 66.59)   Acc@5  90.23 ( 88.21)
    Test: [140/196] Time  0.282 ( 1.584)    Loss 1.6165e+00 (1.3958e+00)    Acc@1  63.67 ( 66.11)   Acc@5  86.72 ( 87.85)
    Test: [150/196] Time  0.273 ( 1.574)    Loss 1.6683e+00 (1.4218e+00)    Acc@1  68.36 ( 65.76)   Acc@5  80.86 ( 87.37)
    Test: [160/196] Time  0.275 ( 1.537)    Loss 1.2704e+00 (1.4396e+00)    Acc@1  70.70 ( 65.53)   Acc@5  88.28 ( 87.07)
    Test: [170/196] Time  0.277 ( 1.514)    Loss 1.0841e+00 (1.4603e+00)    Acc@1  74.22 ( 65.04)   Acc@5  93.36 ( 86.76)
    Test: [180/196] Time  0.275 ( 1.472)    Loss 1.5453e+00 (1.4728e+00)    Acc@1  63.28 ( 64.80)   Acc@5  88.28 ( 86.58)
    Test: [190/196] Time  0.274 ( 1.446)    Loss 1.3819e+00 (1.4701e+00)    Acc@1  64.45 ( 64.85)   Acc@5  91.80 ( 86.61)
     * Acc@1 65.080 Acc@5 86.696
    

    Thanks!

    opened by ferreirafabio 6
  • SyncBatchNorm

    Hi~ when I run the code, an error occurs in the bn1 part.

    z1 = self.encoder(x1)  # NxC
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torchvision/models/resnet.py", line 204, in forward
        x = self.fc(x)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/nn/modules/container.py", line 92, in forward
        input = module(input)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/nn/modules/module.py", line 493, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 440, in forward
        self._check_input_dim(input)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/nn/modules/batchnorm.py", line 425, in _check_input_dim
        .format(input.dim()))
    ValueError: expected at least 3D input (got 2D input)
    

    I found some possible solutions suggesting to reshape the input from (N, C) to (N, 1, C). I tried this, but another error occurred:

    Traceback (most recent call last):
      File "main_simsiam.py", line 370, in <module>
        main()
      File "main_simsiam.py", line 122, in main
        mp.spawn(main_worker, nprocs=ngpus_per_node, args=(ngpus_per_node, args))
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 167, in spawn
        while not spawn_context.join():
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 114, in join
        raise Exception(msg)
    Exception:
    
    -- Process 2 terminated with the following error:
    Traceback (most recent call last):
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/multiprocessing/spawn.py", line 19, in _wrap
        fn(i, *args)
      File "/ssd2/lixingjian/dengandong/InterpSSL/SimSiam/main_simsiam.py", line 260, in main_worker
        train(train_loader, model, criterion, optimizer, epoch, args)
      File "/ssd2/lixingjian/dengandong/InterpSSL/SimSiam/main_simsiam.py", line 301, in train
        loss.backward()
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/tensor.py", line 107, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/home/work/anaconda3/envs/pt1.1_py3.6_cuda9.0/lib/python3.6/site-packages/torch/autograd/__init__.py", line 93, in backward
        allow_unreachable=True)  # allow_unreachable flag
    RuntimeError: Function SyncBatchNormBackward returned an invalid gradient at index 1 - got [1] but expected shape compatible with [512]
    

    I'd appreciate it if you could provide some suggestions or possible solutions. Thank you!
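
    A hedged note: the traceback paths show PyTorch 1.1, whose SyncBatchNorm rejects the 2-D (N, C) inputs produced by the BatchNorm layers in the projection MLP (the "expected at least 3D input" error above); the versions listed under Settings (PyTorch 1.7.0) accept 2-D inputs, so upgrading is the likely fix rather than reshaping. A minimal sketch of the conversion the training script performs before DDP wrapping:

    import torch
    import torchvision.models as models

    model = models.resnet50()
    # convert every BatchNorm (including the 1-D ones in the MLP head) to SyncBatchNorm
    model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)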

    opened by dengandong 5
  • loss question

    I use a different data augmentation. The loss goes down to -1 in the first epoch and then slowly goes back up in the following epochs. Have you met the same problem when trying different data augmentations?

    opened by daeing 4
  • [Problem] problem occurred when trained on custom dataset

    Hi, thanks for your excellent work! I want to train on my own dataset, which consists of many different sub-directories, so I wrote a PyTorch Dataset that takes a train.txt (a list of image paths from different sub-directories) as input, as below:

    class DatasetFromTxtList(Dataset):
        def __init__(self, txt_path):
            """
            Read data path from a TXT list file
            """
            if not os.path.isfile(txt_path):
                print("[Err]: invalid txt file path.")
                exit(-1)
    
            self.img_paths = []
            with open(txt_path, "r", encoding="utf-8") as f:
                for line in f.readlines():
                    img_path = line.strip()
                    self.img_paths.append(img_path)
            print("Total {:d} images found.".format(len(self.img_paths)))
    
            ## Define transformations
            self.normalize = transforms.Normalize(mean=[0.485, 0.456, 0.406],
                                                  std=[0.229, 0.224, 0.225])
    
            # MoCo v2's aug: similar to SimCLR https://arxiv.org/abs/2002.05709
            augmentations = [
                transforms.RandomResizedCrop(224, scale=(0.2, 1.0)),
                transforms.RandomApply([
                    transforms.ColorJitter(0.4, 0.4, 0.4, 0.1)  # not strengthened
                ], p=0.8),
                transforms.RandomGrayscale(p=0.2),
                transforms.RandomApply([GaussianBlur([0.1, 2.0])], p=0.5),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                self.normalize
            ]
    
            self.T = transforms.Compose(augmentations)
    
        def __getitem__(self, idx):
            """
            """
            img_path = self.img_paths[idx]
            x = Image.open(img_path)
    
            q = self.T(x)
            k = self.T(x)
    
            return [q, k]
    
        def __len__(self):
            """
            """
            return len(self.img_paths)
    
    

    And I replace the train_dataset definition with:

        ## ----- Using customized dataset: reading sample from a txt list file...
        train_dataset = DatasetFromTxtList(args.train_txt)
    

    instead of:

        train_dataset = datasets.ImageFolder(
            train_dir,
            simsiam.loader.TwoCropsTransform(transforms.Compose(augmentation)))
    

    The error is as follows:

    Total 502335 images found.
    Traceback (most recent call last):
      File "/mnt/diskb/even/SimSiam/my_simsiam.py", line 468, in <module>
        main()
      File "/mnt/diskb/even/SimSiam/my_simsiam.py", line 203, in main
        mp.spawn(main_worker, nprocs=n_gpus_per_node, args=(n_gpus_per_node, args))
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 230, in spawn
        return start_processes(fn, args, nprocs, join, daemon, start_method='spawn')
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 188, in start_processes
        while not context.join():
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 150, in join
        raise ProcessRaisedException(msg, error_index, failed_process.pid)
    torch.multiprocessing.spawn.ProcessRaisedException: 
    
    -- Process 3 terminated with the following error:
    Traceback (most recent call last):
      File "/usr/local/lib/python3.7/dist-packages/torch/multiprocessing/spawn.py", line 59, in _wrap
        fn(i, *args)
      File "/mnt/diskb/even/SimSiam/my_simsiam.py", line 354, in main_worker
        train(train_loader, model, criterion, optimizer, epoch, args)
      File "/mnt/diskb/even/SimSiam/my_simsiam.py", line 391, in train
        p1, p2, z1, z2 = model(x1=images[0], x2=images[1])
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/parallel/distributed.py", line 799, in forward
        output = self.module(*inputs[0], **kwargs[0])
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/mnt/diskb/even/SimSiam/simsiam/builder.py", line 55, in forward
        z1 = self.encoder(x1) # NxC
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py", line 249, in forward
        return self._forward_impl(x)
      File "/usr/local/lib/python3.7/dist-packages/torchvision/models/resnet.py", line 232, in _forward_impl
        x = self.conv1(x)
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 1051, in _call_impl
        return forward_call(*input, **kwargs)
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py", line 443, in forward
        return self._conv_forward(input, self.weight, self.bias)
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py", line 440, in _conv_forward
        self.padding, self.dilation, self.groups)
    RuntimeError: Expected 4-dimensional input for 4-dimensional weight [64, 3, 7, 7], but got 3-dimensional input of size [3, 224, 224] instead
    
    

    How can I solve this?
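
    A hedged guess: conv1 expects a 4-D batch [N, 3, 224, 224], but it received a single unbatched sample [3, 224, 224], which suggests the samples are not being collated through a batching DataLoader; it is also safer to force RGB when opening images. A sketch of an adjusted __getitem__ (hypothetical fix, reusing the names from the snippet above):

    def __getitem__(self, idx):
        # force 3-channel RGB; a DataLoader with batch_size > 0 then collates
        # samples into the 4-D [N, 3, 224, 224] batches that conv1 expects
        x = Image.open(self.img_paths[idx]).convert("RGB")
        return [self.T(x), self.T(x)]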

    opened by CaptainEven 3
  • Single GPU Training

    In the README, it is stated that:

    Only multi-gpu, DistributedDataParallel training is supported; single-gpu or DataParallel training is not supported.

    But looking at the code, I think we can choose not to use DistributedDataParallel. Will using a single GPU affect performance?

    opened by nixczhou 3
  • What's the point if we do not gather all outputs in different GPUs to compute contrastive loss

    Hi,

    this is really great work. However, I have a general question about the contrastive loss.

    In your code, you use 8 GPUs for a total batch size of 256, which means 32 samples per GPU. You first compute the contrastive loss of these 32 samples on the same GPU, then gather the losses from the different GPUs to compute the final gradient.

    However, it makes little sense to me to increase the batch size this way. One challenge for the contrastive loss is finding hard negatives. Normally we increase the batch size on a single GPU to handle this problem, since a larger batch size offers more chances to find hard negatives. But if we use DDP, this kind of larger total batch size is not useful.

    For example, if I use 16 GPUs for a total batch size of 512, this results in the same number of samples (32) per GPU as above. Would it be better to gather all of the output embeddings from the different GPUs onto one GPU to compute the contrastive loss?

    In Table 2 of your paper, how do you change the batch size? By increasing the number of samples on a single GPU with a fixed number of GPUs, or by increasing the number of GPUs with a fixed number of samples per GPU? The result is a little weird to me: the total batch size of 4096 is the worst.

    opened by BaohaoLiao 3
  • AssertionError: assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}

    My pretraining phase went well, but now when I try loading the checkpoint to train the classifier it breaks. The following is what I am doing, which is based on the code in this repo:

    model = torchvision.models.__dict__['resnet18']()
    for name, param in model.named_parameters():
        if name not in ['fc.weight', 'fc.bias']:
            param.requires_grad = False
    # init the fc layer
    model.fc.weight.data.normal_(mean=0.0, std=0.01)
    model.fc.bias.data.zero_()
    
    ### Load checkpoint
    checkpoint = torch.load('./simsiam_malaria/resnet18_checkpoint.pth.tar', map_location="cpu")
    state_dict = checkpoint['state_dict']  # the pre-training script saves weights under 'state_dict'
    
    for k in list(state_dict.keys()):
        print(k)
    
    encoder.bn1.weight
    encoder.bn1.bias
    encoder.bn1.running_mean
    encoder.bn1.running_var
    encoder.bn1.num_batches_tracked
    encoder.layer1.0.conv1.weight
    encoder.layer1.0.bn1.weight
    encoder.layer1.0.bn1.bias
    encoder.layer1.0.bn1.running_mean
    encoder.layer1.0.bn1.running_var
    encoder.layer1.0.bn1.num_batches_tracked
    encoder.layer1.0.conv2.weight
    encoder.layer1.0.bn2.weight
    encoder.layer1.0.bn2.bias
    encoder.layer1.0.bn2.running_mean
    encoder.layer1.0.bn2.running_var
    encoder.layer1.0.bn2.num_batches_tracked
    encoder.layer1.1.conv1.weight
    encoder.layer1.1.bn1.weight
    encoder.layer1.1.bn1.bias
    encoder.layer1.1.bn1.running_mean
    encoder.layer1.1.bn1.running_var
    encoder.layer1.1.bn1.num_batches_tracked
    encoder.layer1.1.conv2.weight
    encoder.layer1.1.bn2.weight
    encoder.layer1.1.bn2.bias
    encoder.layer1.1.bn2.running_mean
    encoder.layer1.1.bn2.running_var
    encoder.layer1.1.bn2.num_batches_tracked
    encoder.layer2.0.conv1.weight
    encoder.layer2.0.bn1.weight
    encoder.layer2.0.bn1.bias
    encoder.layer2.0.bn1.running_mean
    encoder.layer2.0.bn1.running_var
    encoder.layer2.0.bn1.num_batches_tracked
    encoder.layer2.0.conv2.weight
    encoder.layer2.0.bn2.weight
    encoder.layer2.0.bn2.bias
    encoder.layer2.0.bn2.running_mean
    encoder.layer2.0.bn2.running_var
    encoder.layer2.0.bn2.num_batches_tracked
    encoder.layer2.0.downsample.0.weight
    encoder.layer2.0.downsample.1.weight
    encoder.layer2.0.downsample.1.bias
    encoder.layer2.0.downsample.1.running_mean
    encoder.layer2.0.downsample.1.running_var
    encoder.layer2.0.downsample.1.num_batches_tracked
    encoder.layer2.1.conv1.weight
    encoder.layer2.1.bn1.weight
    encoder.layer2.1.bn1.bias
    encoder.layer2.1.bn1.running_mean
    encoder.layer2.1.bn1.running_var
    encoder.layer2.1.bn1.num_batches_tracked
    encoder.layer2.1.conv2.weight
    encoder.layer2.1.bn2.weight
    encoder.layer2.1.bn2.bias
    encoder.layer2.1.bn2.running_mean
    encoder.layer2.1.bn2.running_var
    encoder.layer2.1.bn2.num_batches_tracked
    encoder.layer3.0.conv1.weight
    encoder.layer3.0.bn1.weight
    encoder.layer3.0.bn1.bias
    encoder.layer3.0.bn1.running_mean
    encoder.layer3.0.bn1.running_var
    encoder.layer3.0.bn1.num_batches_tracked
    encoder.layer3.0.conv2.weight
    encoder.layer3.0.bn2.weight
    encoder.layer3.0.bn2.bias
    encoder.layer3.0.bn2.running_mean
    encoder.layer3.0.bn2.running_var
    encoder.layer3.0.bn2.num_batches_tracked
    encoder.layer3.0.downsample.0.weight
    encoder.layer3.0.downsample.1.weight
    encoder.layer3.0.downsample.1.bias
    encoder.layer3.0.downsample.1.running_mean
    encoder.layer3.0.downsample.1.running_var
    encoder.layer3.0.downsample.1.num_batches_tracked
    encoder.layer3.1.conv1.weight
    encoder.layer3.1.bn1.weight
    encoder.layer3.1.bn1.bias
    encoder.layer3.1.bn1.running_mean
    encoder.layer3.1.bn1.running_var
    encoder.layer3.1.bn1.num_batches_tracked
    encoder.layer3.1.conv2.weight
    encoder.layer3.1.bn2.weight
    encoder.layer3.1.bn2.bias
    encoder.layer3.1.bn2.running_mean
    encoder.layer3.1.bn2.running_var
    encoder.layer3.1.bn2.num_batches_tracked
    encoder.layer4.0.conv1.weight
    encoder.layer4.0.bn1.weight
    encoder.layer4.0.bn1.bias
    encoder.layer4.0.bn1.running_mean
    encoder.layer4.0.bn1.running_var
    encoder.layer4.0.bn1.num_batches_tracked
    encoder.layer4.0.conv2.weight
    encoder.layer4.0.bn2.weight
    encoder.layer4.0.bn2.bias
    encoder.layer4.0.bn2.running_mean
    encoder.layer4.0.bn2.running_var
    encoder.layer4.0.bn2.num_batches_tracked
    encoder.layer4.0.downsample.0.weight
    encoder.layer4.0.downsample.1.weight
    encoder.layer4.0.downsample.1.bias
    encoder.layer4.0.downsample.1.running_mean
    encoder.layer4.0.downsample.1.running_var
    encoder.layer4.0.downsample.1.num_batches_tracked
    encoder.layer4.1.conv1.weight
    encoder.layer4.1.bn1.weight
    encoder.layer4.1.bn1.bias
    encoder.layer4.1.bn1.running_mean
    encoder.layer4.1.bn1.running_var
    encoder.layer4.1.bn1.num_batches_tracked
    encoder.layer4.1.conv2.weight
    encoder.layer4.1.bn2.weight
    encoder.layer4.1.bn2.bias
    encoder.layer4.1.bn2.running_mean
    encoder.layer4.1.bn2.running_var
    encoder.layer4.1.bn2.num_batches_tracked
    encoder.fc.0.weight
    encoder.fc.1.weight
    encoder.fc.1.bias
    encoder.fc.1.running_mean
    encoder.fc.1.running_var
    encoder.fc.1.num_batches_tracked
    encoder.fc.3.weight
    encoder.fc.4.weight
    encoder.fc.4.bias
    encoder.fc.4.running_mean
    encoder.fc.4.running_var
    encoder.fc.4.num_batches_tracked
    encoder.fc.6.weight
    encoder.fc.6.bias
    encoder.fc.7.running_mean
    encoder.fc.7.running_var
    encoder.fc.7.num_batches_tracked
    predictor.0.weight
    predictor.1.weight
    predictor.1.bias
    predictor.1.running_mean
    predictor.1.running_var
    predictor.1.num_batches_tracked
    predictor.3.weight
    predictor.3.bias
    

    Now when I run:

    for k in list(state_dict.keys()):
        # retain only encoder up to before the embedding layer
        if k.startswith('module.encoder') and not k.startswith('module.encoder.fc'):
            # remove prefix
            state_dict[k[len("module.encoder."):]] = state_dict[k]
        # delete renamed or unused k
        del state_dict[k]
    msg = model.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
    

    I get the output:

    AssertionError                            Traceback (most recent call last)
    <ipython-input-7-3429f0c9d366> in <module>
          8     del state_dict[k]
          9 msg = model.load_state_dict(state_dict, strict=False)
    ---> 10 assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}
    
    AssertionError: 
    

    The del state_dict[k] line deletes every single key, so load_state_dict receives an empty dict and every parameter lands in missing_keys.
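
    A hedged observation: the keys printed above begin with encoder. rather than module.encoder. (the checkpoint was apparently saved without the DDP module. wrapper), so the startswith('module.encoder') filter never matches and the loop only deletes. A sketch of the adjusted loop for such a checkpoint:

    for k in list(state_dict.keys()):
        # keys start with 'encoder.' here, so filter on that prefix instead
        if k.startswith('encoder.') and not k.startswith('encoder.fc'):
            state_dict[k[len("encoder."):]] = state_dict[k]
        del state_dict[k]
    msg = model.load_state_dict(state_dict, strict=False)
    assert set(msg.missing_keys) == {"fc.weight", "fc.bias"}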

    opened by etetteh 3
  • A question about projection and prediction MLP

    Hi Xinlei, great work and thanks for sharing. May I ask some questions about the configuration of the projection and prediction MLPs in this implementation?

    1. For the projection MLP, I saw the same dimension used for all three FC layers (2048 when using ResNet-50). Does the dimension have to be the same for all layers here? I mean, if my input_dim is 512, does the hidden_dim need to be set to 512 as well?

    2. For the prediction MLP, a point you mentioned in the paper is that "the prediction MLP's hidden layer dimension is always 1/4 of the output dimension. We find that this bottleneck structure is more robust. If we set the hidden dimension to be equal to the output dimension, the training can be less stable or fail in some variants of our exploration." I was wondering: is it necessary to keep the hidden_dim at 1/4 of the output dimension even with a small output dimension (such as 512 instead of 2048)?

    Looking forward to your reply!
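
    For context, a minimal sketch of the bottleneck prediction MLP described above (the dim/pred_dim pattern of simsiam/builder.py, e.g. 2048 and 512 for ResNet-50):

    import torch.nn as nn

    dim, pred_dim = 2048, 512  # pred_dim = dim / 4, the paper's bottleneck
    predictor = nn.Sequential(
        nn.Linear(dim, pred_dim, bias=False),
        nn.BatchNorm1d(pred_dim),
        nn.ReLU(inplace=True),   # hidden layer
        nn.Linear(pred_dim, dim) # output layer (no BN/ReLU after it)
    )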

    opened by CaptainPrice12 3
  • Open-source improved implementations of BYOL, SWAV, etc.?

    Thank you so much for open-sourcing! The code looks extremely clean and nice. It is a great service to the community!

    Would you also open-source the improved implementations of BYOL, SwAV, SimCLR and MoCo v2? The devil is in the details, so it would be great to reproduce the improved baseline results in the paper as well.

    Thanks again for your wonderful work!

    opened by LinxiFan 3
  • SyncBatchnorm usage in main_lincls.py

    Hello, thanks for your great work. I have a short question.

    Is there a reason why you use SyncBatchNorm only in main_simsiam.py? I can't find any use of it in main_lincls.py. Doesn't that matter?

    Thanks!

    opened by myoh97 2
  • The value of loss function nn.CosineSimilarity is negative

    @endernewton, Hi Dr. Chen, thank you for a high-quality work. I encountered a case where the value of the nn.CosineSimilarity loss is negative; it happens when the backbone is ResNet-12 and the dataset is CIFAR-FS. Can you help me solve this issue?

    opened by zhang1hongliang 1
  • Pre-trained weights for SimCLR, MoCov2, BYOL, SwaV

    Hi!

    Thanks so much for releasing the code and pre-trained models!

    It would be really wonderful if you could release the weights for the improved SimCLR, MoCo v2, BYOL, and SwAV baselines reported in the paper. Is there any chance you could do that?

    opened by ninatu 1
  • About the shape of image[0] and image[1]

    Hello, thanks for your project. I just want to ask a question about image[0] and image[1]. Do they have 3 RGB channels and a shape of [x, y, z]?

    opened by Richard-Lu-badbird 1
  • Loss collapse during training

    I am trying to pretrain the SimSiam model on the MS COCO dataset, but the loss collapses to -1 very quickly. What are the possible reasons, and do you have any suggestions for solving this?

    opened by aisagarw 1
  • About the projection and prediction head dimension configs

    Hi,

    Thank you for the amazing work; I am inspired a lot by it!

    One thing I want to ask is how you designed the structure of your projection and prediction heads, i.e., how did you decide the number of layers and the hidden and output dimensions? P.S. I notice you have already mentioned that the bottleneck structure of the prediction MLP is helpful.

    Thank you for your help and time!

    opened by Dylan-H-Wang 0
  • Slow convergence with SGD linear evaluation

    Hi!

    I am running a linear evaluation right now on a SimSiam network I've just trained, in a different repository. In contrast to the evaluation protocol you've written, I use another one preferred by a few other papers: batch size 256, 100 epochs, SGD with momentum, lr 0.3, weight decay 0.

    My first intuition was that my code had a bug, because even when I used the weights you shared in this repository, my evaluation started off with 5% accuracy after the first epoch, which is fairly close to random-weight performance. Now that a few epochs have passed I see some progress; maybe I will reach 30%+ after 10 epochs. However, other self-supervised methods kick off this evaluation with 60% right after the first epoch.

    Do you have any guesses why I experience slow convergence with SimSiam?

    Thank you.

    opened by gergopool 0