PyTorch implementation of Barlow Twins.

Overview

Barlow Twins: Self-Supervised Learning via Redundancy Reduction

[Figure: Barlow Twins method diagram]

@article{zbontar2021barlow,
  title={Barlow Twins: Self-Supervised Learning via Redundancy Reduction},
  author={Zbontar, Jure and Jing, Li and Misra, Ishan and LeCun, Yann and Deny, St{\'e}phane},
  journal={arXiv preprint arXiv:2103.03230},
  year={2021}
}

Pretrained Model

epochs | batch size | acc1  | acc5  | download
1000   | 2048       | 73.3% | 91.0% | model (logs)

The pretrained model is also available on PyTorch Hub.

import torch
model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
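
As a minimal usage sketch (not part of this repository), the hub model can be used as a frozen feature extractor. Note that the fc head shipped with the hub model is randomly initialized (see the classifier-weights issue in the comments below), so it is replaced here; the normalization constants are the standard ImageNet values, and example.jpg is a placeholder path:

import torch
from torch import nn
from torchvision import transforms
from PIL import Image

# Load the Barlow Twins ResNet-50 backbone and drop the untrained fc head.
model = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
model.fc = nn.Identity()
model.eval()

# Standard ImageNet preprocessing.
preprocess = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

img = preprocess(Image.open('example.jpg').convert('RGB')).unsqueeze(0)
with torch.no_grad():
    features = model(img)  # (1, 2048) representation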

Barlow Twins Training

Install PyTorch and download ImageNet by following the instructions in the requirements section of the PyTorch ImageNet training example. The code has been developed for PyTorch version 1.7.1 and torchvision version 0.8.2, but it should work with other versions just as well.

Our best model is obtained by running the following command:

python main.py /path/to/imagenet/ --epochs 1000 --batch-size 2048 --learning-rate 0.2 --lambd 0.0051 --projector 8192-8192-8192 --scale-loss 0.024

Training time is approximately 7 days on 16 V100 GPUs.
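
For reference, the objective controlled by --lambd and --scale-loss can be written as a short function. This is a simplified sketch of the loss described in the paper (feature standardization via mean/std rather than the BatchNorm layer used in main.py), not a verbatim copy of the repository code:

import torch

def barlow_twins_loss(z1, z2, lambd=0.0051, scale_loss=0.024):
    # z1, z2: (batch, dim) projector outputs for the two augmented views.
    n = z1.shape[0]
    z1 = (z1 - z1.mean(0)) / z1.std(0)   # standardize each dimension over the batch
    z2 = (z2 - z2.mean(0)) / z2.std(0)
    c = z1.T @ z2 / n                    # empirical cross-correlation matrix (dim x dim)
    on_diag = (torch.diagonal(c) - 1).pow(2).sum()               # invariance term
    off_diag = (c - torch.diag(torch.diagonal(c))).pow(2).sum()  # redundancy-reduction term
    return scale_loss * (on_diag + lambd * off_diag)

The first term pushes the diagonal of the cross-correlation matrix towards 1 (invariance), while the lambd-weighted term pushes the off-diagonal entries towards 0 (redundancy reduction).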

Evaluation: Linear Classification

Train a linear probe on the representations learned by Barlow Twins. Freeze the weights of the resnet and use the entire ImageNet training set.

python evaluate.py /path/to/imagenet/ /path/to/checkpoint/resnet50.pth --lr-classifier 0.1
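
The essential setup behind this command looks roughly as follows; this is an illustrative sketch rather than evaluate.py itself (data loading and the learning-rate schedule are omitted):

import torch
from torch import nn

# Frozen Barlow Twins backbone plus a trainable linear head.
backbone = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
backbone.fc = nn.Identity()
for p in backbone.parameters():
    p.requires_grad = False
backbone.eval()

classifier = nn.Linear(2048, 1000)  # 1000 ImageNet classes
optimizer = torch.optim.SGD(classifier.parameters(), lr=0.1, momentum=0.9)
criterion = nn.CrossEntropyLoss()

def train_step(images, labels):
    with torch.no_grad():
        features = backbone(images)  # frozen representations
    loss = criterion(classifier(features), labels)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    return loss.item()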

Evaluation: Semi-supervised Learning

Train a linear probe on the representations learned by Barlow Twins. Finetune the weights of the resnet and use a subset of the ImageNet training set.

python evaluate.py /path/to/imagenet/ /path/to/checkpoint/resnet50.pth --weights finetune --train-perc 1 --epochs 20 --lr-backbone 0.002 --lr-classifier 0.5 --weight-decay 0
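
The key difference from the linear probe is that the backbone is also updated, with a much smaller learning rate than the classifier. A minimal sketch of that optimizer setup, mirroring the --lr-backbone and --lr-classifier flags (again an illustration, not evaluate.py):

import torch
from torch import nn

backbone = torch.hub.load('facebookresearch/barlowtwins:main', 'resnet50')
backbone.fc = nn.Identity()
classifier = nn.Linear(2048, 1000)

# Finetune everything, but keep the backbone learning rate much smaller.
optimizer = torch.optim.SGD(
    [
        {'params': backbone.parameters(), 'lr': 0.002},   # --lr-backbone
        {'params': classifier.parameters(), 'lr': 0.5},   # --lr-classifier
    ],
    lr=0.5, momentum=0.9, weight_decay=0.0,
)

The --train-perc flag selects the subset of the ImageNet training set to finetune on; the sketch above covers only the optimizer setup.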

Issues

In order to match the code that was used to develop Barlow Twins, we include an additional parameter, --scale-loss, that multiplies the loss by a constant factor. We are working on a version that will not require this parameter.

License

This project is under the CC-BY-NC 4.0 license. See LICENSE for details.

Comments
  • configs of training with batch size 256

    Can you provide the config for training with batch size 256? How should lambd and the scale loss be set? And is LARS still necessary when training with small batch sizes?

    opened by poodarchu 10
  • Is torch.distributed.all_reduce working as expected?

    This line https://github.com/facebookresearch/barlowtwins/blob/main/main.py#L208 uses torch.distributed.all_reduce to sum the correlation matrices across all GPUs. However, as far as I know, this op is not intended for a forward computation that is followed by a backward pass. For a correctly differentiable distributed all-reduce, the official PyTorch documentation recommends the autograd-enabled primitives in torch.distributed.nn.*: https://pytorch.org/docs/stable/distributed.html#autograd-enabled-communication-primitives

    opened by WarBean 8
  • BT loss value on val set every N training epochs without classifier

    Hi,

    Thanks for making it painless for others to use and build upon your work.

    I'm curious to know what the loss value looks like on the val set during training. In my case, using spectrogram images with a ResNet-18 backbone, the network trains well, but the val loss is very noisy and shows no clear trend. For example, I'm not sure what to make of this:

    [screenshot: noisy validation loss curve]

    In the original experiments, was the BT loss on the val set tracked every N epochs (without any classifier involved)? As far as I can tell, the loss reported in the 'val logs' txt file is that of the classifier trained on top of the embeddings generated from a single training checkpoint.

    The loss curve on the val set should be expected to follow typical behavior, right? Any pointers as to what might be causing a noisy val loss?

    opened by neerajwagh 6
  • Barlow Twins loss on identical vector

    Hello, I really enjoyed reading the paper and thought about the intentions of the loss.

    However, I was wondering if setting the target matrix as identity matrix is eligible.

    As far as I understand, each element of the cross-correlation matrix comes from a matrix multiplication over the feature elements. The Barlow Twins loss aims for a correlation of 1 on the diagonal and 0 (no correlation) on the off-diagonal elements.

    So, if two identical representation vectors were fed to the loss, I thought it should give a loss of zero, but it didn't.

    For the sake of simplicity, let's say we have 2 pairs of representation vectors with identical values (that's 4 vectors). However, when I took two identical 1-d vectors for 2 data points, applied the batch norm, and computed the Barlow Twins loss with them, I got 1 on the diagonal but not 0 on the off-diagonal elements.

    The same thing happens when the batch size is 1 (though batch norm then normalizes the values to zero).

    I'm not sure how the loss can learn invariance and redundancy reduction with an identity-matrix target, especially the redundancy term. Can you please elaborate on how the representation vectors learn through the redundancy term?

    Here's a simple example I tried. (I followed the code implementation)

    [screenshot: example loss computation]

    Thank you!

    opened by ChanLIM 6
  • Training on multiple nodes

    Thanks for your great work. If there are two machines (each with 8 V100 GPUs) connected over Ethernet, without Slurm management, how can we run the code with your stated 16-V100 configuration?

    opened by d-li14 6
  • Question about Fig. 4 in the paper

    Hi. Thanks for the great work! I'm trying to reproduce the results in Fig. 4 of the paper (the effect of the dimensionality of the projector network's last layer on performance). I have two questions about this:

    • Could you tell me which hyperparameters you used in these experiments?
    • Did you run main.py with --projector 8192-8192-16384, rather than --projector 16384-16384-16384?
    opened by yutaro-s 4
  • Applications on one-dimensional signal datasets

    Hi, thank you for your work; it inspires me a lot. I want to apply BT to my dataset of one-dimensional pulse signals instead of images. The signal length is 500; do you think it needs to be resized to 224? I made the following changes: I changed the number of input channels in resnet50 from 3 to 1 and replaced conv2d with conv1d, so the output dimension is still 2048. Good results can be obtained with supervised networks, but the self-supervised training loss does not drop significantly, and the downstream classification results are not as good as the supervised networks. I also raised the projector MLP dimension to 8192, but that didn't help either. What do you think might be the reason, or can you give me some advice? Thanks again!

    opened by TQi-Yang 4
  • A question on the BT loss with Batch Norm layers

    Hi, thanks a lot for the very clear implementation; the paper is so easy to read! I had a quick question on the Barlow Twins loss.

    Since Barlow Twins relies on the statistics of a batch of data, if there are batch norm layers in the encoder network, is it possible that the parameters of the BN layers are updated/affected more than the other parameters in the network when optimizing the BT loss? Did you experiment without any batch norm layers in the encoder to see whether that affects the learned representations?

    opened by HareshKarnan 4
  • Why do we average out correlation matrices from different GPUs? Is this mathematically valid?

    Thanks for this great work!

    I am a bit confused about the computation of the Barlow Twins loss in the multi-GPU setting. If I understand it correctly, each batch is split into smaller minibatches and these are processed on separate GPUs. Each GPU computes the cross-correlation matrix corresponding to its minibatch. The cross-correlations between samples on different GPUs are not computed. It is not clear to me why the different cross-correlation matrices are averaged across GPUs. This creates a mean correlation matrix, which is then used for the loss computation.

    Why not compute the loss for each correlation matrix separately and only average the final losses? Or, even better, why not compute the full cross-correlation matrix (i.e., gather all embedding vectors onto one device and compute the cross-correlation there)?

    I fail to see why summing up correlation matrices is a valid mathematical operation, or is it just an implementation "hack" that makes things easier? I guess that since all cross-correlation matrices ideally converge towards identity matrices (as forced by the loss function), averaging them does not strictly break the convergence; is that the case?

    I am not very experienced with distributed deep learning so there may be technical things I don't understand. Thanks for your help.

    https://github.com/facebookresearch/barlowtwins/blob/a655214c76c97d0150277b85d16e69328ea52fd9/main.py#L206-L223

    opened by radekd91 4
  • use barlowtwins predict but blocked and can not be killed

    Hi, I'm very glad to thank you guys for sharing this amazing project. Training is very easy, but when I use it to test or predict, it blocks and cannot be killed.

    Here is my predictor code:

    import torch
    import argparse
    from PIL import Image
    from torch import nn, optim
    from main import BarlowTwins, Transform

    """ BarlowTwins Predictor """

    def load_model(args):
        args.ngpus_per_node = torch.cuda.device_count()
        args.rank = 0
        args.dist_url = 'tcp://localhost:58472'
        args.world_size = args.ngpus_per_node
        gpu = 0
        model = BarlowTwins(args).cuda(gpu)
        model = nn.SyncBatchNorm.convert_sync_batchnorm(model)
        torch.distributed.init_process_group(
            backend='nccl', init_method=args.dist_url,
            world_size=args.world_size, rank=args.rank)
        model = torch.nn.parallel.DistributedDataParallel(model, device_ids=[gpu])
        best_cp = torch.load("model/barlowtwins/bt_face.pth")
        # model.module.backbone.state_dict(best_cp)
        model.load_state_dict(best_cp)
        model.eval()
        return model

    def predict(t1, t2, model, device):
        model.to(device)
        with torch.no_grad():
            t1 = t1.to(torch.device("cuda:1" if torch.cuda.is_available() else "cpu"))
            t2 = t2.to(torch.device("cuda:2" if torch.cuda.is_available() else "cpu"))
            out_data = model(t1, t2)
        return out_data

    def predict_img(img_url, args):
        img = Image.open(img_url)
        bt_trans = Transform()
        t1, t2 = bt_trans(img)
        t1 = t1.unsqueeze(0)
        t2 = t2.unsqueeze(0)
        model = load_model(args)
        device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
        output = predict(t1, t2, model, device)
        return output

    if __name__ == "__main__":
        parser = argparse.ArgumentParser(description='Barlow Twins Predict')
        parser.add_argument('--workers', default=8, type=int, metavar='N',
                            help='number of data loader workers')
        parser.add_argument('--projector', default='8192-8192-8192', type=str,
                            metavar='MLP', help='projector MLP')
        parser.add_argument("-d", default="data/test/1.jpg")
        args = parser.parse_args()
        img_url = str(args.d)
        result = predict_img(img_url, args)
        print(result)

    Would you please tell me which part is wrong? I'm really hoping for your answer!

    opened by showkeyjar 4
  • Can we ignore the additional parameter --scale-loss?

    Hi, thanks for your contribution. I guess that the additional parameter scale-loss cannot affect the gradient of the loss, so in practice we can ignore it; please tell me whether I am right...

    opened by mitming 4
  • Pre-training model for CIFAR

    Hi,

    I am trying to train a pre-training model on the CIFAR-10 and CIFAR-100 datasets, but I could not achieve high top-1 classification accuracy when evaluating the pre-trained model with KNN. Does anyone have similar experience using this official implementation on the CIFAR datasets?

    Thanks,

    opened by LiYunJamesPhD 0
  • providing the linear ImageNet classifier weights

    Are the learned linear classifier weights for ImageNet classification available on torch hub? When I try to use the barlowtwins model from the hub for classification, the performance is random, so I'm assuming the fc layer in that model still has random weights. Could you provide the linear classifier weights?

    Thank you!

    opened by nikparth 0
  • NaNs introduced during training

    I'm seeing NaNs introduced while training with the default configuration on ImageNet. The loss spiked suddenly, after which the loss was returned as NaN. This has occurred twice in a row. Similar issues exist in other facebookresearch repos (https://github.com/facebookresearch/vissl/issues/543), the common link perhaps being the near-identical implementation of the LARS optimiser. The linked issue suggests configuring the optimiser to exclude bias and norm parameters to avoid NaNs, but this differs from the published Barlow Twins method.

    EDIT: The BarlowTwins paper does in fact state that the bias and batch normalization parameters are excluded from the LARS optimiser and this is the case in the code. Perhaps given the spike, this is an exploding gradients problem?

    Loss graph:

    [screenshot: training loss curve with a sudden spike before the NaNs]

    opened by charliebudd 0
  • Quality of Embeddings

    Hello,

    Thanks for publishing your amazing work. Have you done any investigations into the quality of the embeddings for downstream tasks, rather than the representations? I think the SimCLR paper did some, showing that the representations were better, but since the embeddings are nicely disentangled in your work, I was wondering whether Barlow Twins' embeddings might be better, especially when disentanglement is needed?

    Greetings

    opened by Gnabe 0
  • Add dlib implementation to the community updates

    Hi, I added an implementation of the Barlow Twins loss to dlib, and it has been part of dlib since v19.23.

    Again, great paper! It was a pleasure to read and implement.

    CLA Signed 
    opened by arrufat 2