Barlow Twins and HSIC

Overview

Unofficial Pytorch implementation for Barlow Twins and HSIC_SSL on small datasets (CIFAR10, STL10, and Tiny ImageNet).

Correspondence to:

Technical Report

A Note on Connecting Barlow Twins with Negative-Samples-Free Contrastive Learning
Yao-Hung Hubert Tsai, Shaojie Bai, Louis-Philippe Morency, and Ruslan Salakhutdinov

I hope this work will be useful for your research 🥰

Usage

Disclaimer

A large portion of the code is from this repo, which is a great resource for academic development. Note that we do not perform an extensive hyper-parameter grid search, so you may expect a performance boost after tuning some hyper-parameters (e.g., the learning rate).

The official implementation of Barlow Twins can be found here. We have also tried HSIC_SSL in that official repo and find similar performance between HSIC_SSL and Barlow Twins' method (we tried ImageNet-1K and CIFAR10).

Requirements

  • PyTorch
conda install pytorch torchvision cudatoolkit=10.0 -c pytorch
  • thop
pip install thop

Supported Dataset

CIFAR10, STL10, and Tiny_ImageNet.

Train and Linear Evaluation using Barlow Twins

python main.py --lmbda 0.0078125 --corr_zero --batch_size 128 --feature_dim 128 --dataset cifar10
python linear.py --dataset cifar10 --model_path results/0.0078125_128_128_cifar10_model.pth

Train and Linear Evaluation using HSIC

python main.py --lmbda 0.0078125 --corr_neg_one --batch_size 128 --feature_dim 128 --dataset cifar10
python linear.py --dataset cifar10 --model_path results/neg_corr_0.0078125_128_128_cifar10_model.pth
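
The two training commands differ only in the loss flag: --corr_zero keeps the original Barlow Twins objective, which pushes the off-diagonal entries of the cross-correlation matrix toward 0, while --corr_neg_one uses the HSIC-inspired objective from the technical report, which pushes them toward -1. The snippet below is a minimal sketch of that difference for illustration only, not the repo's exact code: the function name barlow_hsic_loss is made up here, and it normalizes with unbiased=False (see the Bessel's-correction discussion in the Comments below), whereas main.py uses torch.std with its default setting.

import torch

def barlow_hsic_loss(z1, z2, lmbda=0.0078125, corr_neg_one=False):
    # Hypothetical sketch: z1, z2 are [N, D] projector outputs for two augmented views.
    n, d = z1.shape
    # normalize each embedding dimension over the batch (population std, unbiased=False)
    z1 = (z1 - z1.mean(dim=0)) / z1.std(dim=0, unbiased=False)
    z2 = (z2 - z2.mean(dim=0)) / z2.std(dim=0, unbiased=False)
    c = (z1.T @ z2) / n                                  # D x D cross-correlation matrix

    on_diag = (torch.diagonal(c) - 1).pow(2).sum()       # pull diagonal entries toward 1
    off = c.flatten()[:-1].view(d - 1, d + 1)[:, 1:]     # all off-diagonal entries
    if corr_neg_one:                                     # --corr_neg_one: HSIC-style target of -1
        off_diag = (off + 1).pow(2).sum()
    else:                                                # --corr_zero: Barlow Twins target of 0
        off_diag = off.pow(2).sum()
    return on_diag + lmbda * off_diag
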
Comments
  • Speed Up Model by Using cudnn benchmark

    Not really an issue, but adding the following made training significantly faster (+29% on a Titan X Pascal):

    # let cuDNN benchmark and cache the fastest convolution algorithms (input sizes are fixed here)
    if torch.cuda.is_available():
        torch.backends.cudnn.benchmark = True
    

    I thought I would mention it as others might benefit from this as well.

    Thanks for making your code available on GitHub!

    opened by davidrzs 1
  • Inaccuracy in the cross correlation for small batch sizes (<32)

    Hey,

    first, thanks for sharing your research and code!

    TL;DR

    Your code uses torch.std, which applies Bessel's correction by default, so the values on the diagonal cannot reach 1.

    While working with it, I noticed a small inaccuracy in the calculation of the cross-correlation matrix.

    As opposed to the original implementation, which uses BatchNorm1d, you implemented the normalization with:

    https://github.com/yaohungt/Barlow-Twins-HSIC/blob/a30baba4d2d3dcdf85a4ccebe57902f0d827c1ed/main.py#L36-L38

    I implemented a small test with two identical vectors coming from the projection head and was therefore expecting exact ones on the diagonal. But as you can see from my attached code, for batch sizes < 32 (here 4), the values on the diagonal cannot exceed 0.75 (Bessel's correction scales them by (n-1)/n, which is 0.75 for n = 4). I found that torch.std uses Bessel's correction by default; when its unbiased flag is set to False, the numbers match the original implementation.

    I think there is no practical difference for batch sizes of 32 or more, which I believe is also the smallest batch size presented in your paper.

    import torch
    from torch import nn
    
    batch_size = 4
    size_z = 128
    
    torch.manual_seed(1234)
    z1 = torch.randn(batch_size, size_z)
    z2 = z1.clone()
    
    # your implementation
    z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0)
    z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0)
    cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size
    
    print( cross_corr[0:5,0:5] )
    
    # tensor([[ 0.7500, -0.1065, -0.0837,  0.0630,  0.3664],
    #        [-0.1065,  0.7500, -0.2283, -0.3708, -0.5607],
    #        [-0.0837, -0.2283,  0.7500, -0.5013, -0.2554],
    #        [ 0.0630, -0.3708, -0.5013,  0.7500,  0.6334],
    #        [ 0.3664, -0.5607, -0.2554,  0.6334,  0.7500]])
    
    # original implementation
    bn = nn.BatchNorm1d(size_z, affine=False)
    z1_norm = bn(z1)
    z2_norm = bn(z2)
    cross_corr = z1_norm.T @ z2_norm / batch_size
    
    print( cross_corr[0:5,0:5] )
    
    # tensor([[ 1.0000, -0.1420, -0.1116,  0.0840,  0.4885],
    #        [-0.1420,  1.0000, -0.3043, -0.4944, -0.7476],
    #        [-0.1116, -0.3043,  1.0000, -0.6683, -0.3405],
    #        [ 0.0840, -0.4944, -0.6683,  1.0000,  0.8445],
    #        [ 0.4885, -0.7476, -0.3405,  0.8445,  1.0000]])
    
    # corrected code (without Bessel’s correction for the calculation of the standard deviation)
    z1_norm = (z1 - torch.mean(z1, dim=0)) / torch.std(z1, dim=0, unbiased=False)
    z2_norm = (z2 - torch.mean(z2, dim=0)) / torch.std(z2, dim=0, unbiased=False)
    cross_corr = torch.matmul(z1_norm.T, z2_norm) / batch_size
    
    print( cross_corr[0:5,0:5] )
    
    # tensor([[ 1.0000, -0.1421, -0.1116,  0.0840,  0.4885],
    #         [-0.1421,  1.0000, -0.3043, -0.4944, -0.7476],
    #         [-0.1116, -0.3043,  1.0000, -0.6683, -0.3405],
    #         [ 0.0840, -0.4944, -0.6683,  1.0000,  0.8446],
    #         [ 0.4885, -0.7476, -0.3405,  0.8446,  1.0000]])
    
    opened by JohannesK14 0
  • Result about tiny-imagenet

    Hi, have you run the model on Tiny ImageNet? If so, could you share your results? I tried running it myself, but the accuracy seems too low.

    opened by czzerone 0
  • Question re: reproducing Fig 2 from the paper

    Hello --

    I'm interested in trying to reproduce the Barlow Twins curve from Fig 2 in the paper.

    I'm running:

    python main.py --lmbda 0.0078125 --corr_zero --batch_size 128 --feature_dim 128 --dataset cifar10
    

    and getting:

    Test Epoch: [5/1000] Acc@1:47.33% Acc@5:92.49%
    Test Epoch: [10/1000] Acc@1:53.80% Acc@5:94.87%
    Test Epoch: [15/1000] Acc@1:58.44% Acc@5:96.68%
    Test Epoch: [20/1000] Acc@1:63.11% Acc@5:96.86%
    Test Epoch: [25/1000] Acc@1:65.55% Acc@5:97.33%
    Test Epoch: [30/1000] Acc@1:66.59% Acc@5:97.61%
    Test Epoch: [35/1000] Acc@1:68.85% Acc@5:97.87%
    Test Epoch: [40/1000] Acc@1:69.17% Acc@5:97.75%
    Test Epoch: [45/1000] Acc@1:71.24% Acc@5:98.16%
    Test Epoch: [50/1000] Acc@1:72.38% Acc@5:98.26%
    

    In Fig 2, it looks like accuracy after 50 epochs should be ~ 79%, but I'm only getting to ~72%.

    Any ideas why there might be a gap? Perhaps the accuracies reported in Fig 2 come from training a linear classifier (e.g., linear.py) rather than from the weighted kNN evaluation in main.py:train? (A sketch of that kNN monitor appears after this Comments list.)

    Thanks!

    opened by bkj 1
  • Is the feature normalization necessary?

    https://github.com/yaohungt/Barlow-Twins-HSIC/blob/a30baba4d2d3dcdf85a4ccebe57902f0d827c1ed/model.py#L31

    Hi, your code is very helpful and I want to firstly appreciate the code share.

    I have a question about whether this feature normalization is necessary (i.e., to bring the CIFAR10 performance to about 92% accuracy).

    The original Barlow Twins does not contain this step; instead, it defines all linear layers in the projector without bias.

    opened by le4m 0
  • Question Regarding Transform

    The implementation is simple and easy to use; thank you for that. I have one doubt:

    Given a mini-batch with input x of size BxCxHxW, we apply transformations to get y1 = self.transform(x) and y2 = self.transform(x).

    Is this a batch-wise transformation or an image-wise transformation?

    As per the paper, "More specifically, it produces two distorted views for all images of a batch X sampled from a dataset." Since there are only two distorted views, I interpret this as applying the same transformation to every image in the batch for a given distorted view.

    opened by lab176344 0
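
A note on the Fig 2 question above: the accuracies printed during main.py training come from a weighted kNN monitor over a feature bank of training images, not from a trained linear classifier (linear.py). The snippet below is a rough, self-contained sketch of that kind of kNN evaluation; the function name and the defaults k=200 and temperature=0.5 are assumptions for illustration and are not copied from main.py.

import torch

def knn_predict(feature, feature_bank, feature_labels, num_classes, k=200, temperature=0.5):
    # Hypothetical sketch: feature is [B, D] (L2-normalized queries), feature_bank is [D, N],
    # feature_labels is [N] (long tensor) with the class of each bank image.
    sim_matrix = torch.mm(feature, feature_bank)               # [B, N] cosine similarities
    sim_weight, sim_indices = sim_matrix.topk(k=k, dim=-1)     # keep the k nearest neighbours
    sim_labels = torch.gather(feature_labels.expand(feature.size(0), -1), dim=-1, index=sim_indices)
    sim_weight = (sim_weight / temperature).exp()              # turn similarity into a soft vote weight
    one_hot = torch.zeros(feature.size(0) * k, num_classes, device=feature.device)
    one_hot.scatter_(dim=-1, index=sim_labels.view(-1, 1), value=1.0)
    scores = (one_hot.view(feature.size(0), k, num_classes) * sim_weight.unsqueeze(dim=-1)).sum(dim=1)
    return scores.argsort(dim=-1, descending=True)             # per-image class ranking

Acc@1 and Acc@5 in the log above then correspond to whether the true label appears in the top 1 or top 5 of that ranking, so some gap to a linear-probe curve would not be surprising.
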
Owner
Yao-Hung Hubert Tsai
I'm currently a Ph.D. student in the Machine Learning Department at Carnegie Mellon University.