Official repository for CVPR21 paper "Deep Stable Learning for Out-Of-Distribution Generalization".

Overview

StableNet

StableNet is a deep stable learning method for out-of-distribution generalization.

This is the official repo for the CVPR21 paper "Deep Stable Learning for Out-Of-Distribution Generalization". The arXiv version can be found at https://arxiv.org/abs/2104.07876.

Introduction

Approaches based on deep neural networks have achieved striking performance when testing and training data share a similar distribution, but can fail significantly otherwise. Eliminating the impact of distribution shifts between training and testing data is therefore crucial for building deep models with reliable performance. Conventional methods assume either known heterogeneity of the training data (e.g. domain labels) or approximately equal capacities of the different domains. In this paper, we consider a more challenging case where neither assumption holds. We propose to address this problem by removing the dependencies between features via learning weights for training samples, which helps deep models get rid of spurious correlations and, in turn, concentrate more on the true connection between discriminative features and labels. Through extensive experiments on distribution generalization benchmarks including PACS, VLCS, MNIST-M, and NICO, we show the effectiveness of our method compared with state-of-the-art counterparts.
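The idea of learning sample weights that remove dependencies between features can be illustrated in a few lines. The NumPy example below is a minimal sketch under stated assumptions, not the repo's implementation (which uses PyTorch): each original feature is mapped to num_f random Fourier features of the form sqrt(2)·cos(ωx + φ) with ω ~ N(0, 1) and φ ~ U(0, 2π), the weights are a softmax over learnable logits (so they stay positive and sum to 1), and the loss is the sum of squared weighted covariances between random features of different original features, minimized by plain gradient descent with a hand-derived gradient. All function names and hyperparameters are hypothetical.

```python
import numpy as np


def random_fourier_features(x, num_f, rng):
    """Map each column of x (shape n x d) to num_f random Fourier features.

    Returns shape (n, d * num_f); the features of original column j occupy
    columns j * num_f through (j + 1) * num_f - 1.
    """
    n, d = x.shape
    omega = rng.normal(size=(d, num_f))                      # omega ~ N(0, 1)
    phi = rng.uniform(0.0, 2.0 * np.pi, size=(d, num_f))     # phi ~ U(0, 2*pi)
    z = np.sqrt(2.0) * np.cos(x[:, :, None] * omega + phi)   # (n, d, num_f)
    return z.reshape(n, d * num_f)


def cross_feature_mask(d, num_f):
    """1 where an entry pairs RFFs of two different original features, else 0."""
    return 1.0 - np.kron(np.eye(d), np.ones((num_f, num_f)))


def loss_and_grad(theta, z, mask):
    """Weighted cross-feature covariance loss and its gradient w.r.t. theta."""
    w = np.exp(theta - theta.max())
    w /= w.sum()                                     # softmax: w > 0, sum(w) == 1
    mu = w @ z                                       # weighted mean of the RFFs
    cov = (w[:, None] * z).T @ z - np.outer(mu, mu)  # weighted covariance
    masked = mask * cov                              # keep cross-feature entries only
    loss = np.sum(masked ** 2)
    zm = z @ masked
    # dL/dw_i = 2 * (z_i^T M z_i - 2 * z_i^T M mu), with M = masked (symmetric)
    grad_w = 2.0 * (np.sum(zm * z, axis=1) - 2.0 * (zm @ mu))
    grad_theta = w * (grad_w - w @ grad_w)           # chain rule through softmax
    return loss, grad_theta, w


def learn_sample_weights(x, num_f=5, steps=300, lr=0.5, seed=0):
    rng = np.random.default_rng(seed)
    z = random_fourier_features(x, num_f, rng)
    mask = cross_feature_mask(x.shape[1], num_f)
    theta = np.zeros(x.shape[0])                     # start from uniform weights
    history = []
    for _ in range(steps):
        loss, grad, _ = loss_and_grad(theta, z, mask)
        history.append(loss)
        theta -= lr * grad                           # plain gradient descent
    loss, _, w = loss_and_grad(theta, z, mask)
    history.append(loss)
    return w, history


# Toy data: x2 depends on x1, x3 is independent of both.
rng = np.random.default_rng(1)
x1 = rng.normal(size=512)
x2 = 0.5 * x1 + rng.normal(size=512)
x3 = rng.normal(size=512)
X = np.stack([x1, x2, x3], axis=1)

w, history = learn_sample_weights(X)
print(f"loss: {history[0]:.4f} -> {history[-1]:.4f}")
```

Starting from uniform weights, the cross-feature term shrinks as the weights adapt. The softmax parametrization is a design choice that keeps the weights a valid distribution without a separate projection step.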

Installation

Requirements

  • Linux with Python >= 3.6
  • PyTorch >= 1.1.0
  • torchvision >= 0.3.0
  • tensorboard >= 1.14.0

Quick Start

Train StableNet

python main_stablenet.py --gpu 0

You can see more options with

python main_stablenet.py -h

Result files will be saved in results/.

Performance and trained models

setting | dataset | source domains | target domain | network | dataset split | accuracy (%) | trained model
unbalanced (5:1:1) | PACS | A, C, S | photo | ResNet18 | split file | 94.864 | model file
unbalanced (5:1:1) | PACS | C, S, P | art_painting | ResNet18 | split file | 80.344 | model file
unbalanced (5:1:1) | PACS | A, S, P | cartoon | ResNet18 | split file | 74.249 | model file
unbalanced (5:1:1) | PACS | A, C, P | sketch | ResNet18 | split file | 71.046 | model file
unbalanced (5:1:1) | VLCS | L, P, S | caltech | ResNet18 | split file | 88.776 | model file
unbalanced (5:1:1) | VLCS | C, P, S | labelme | ResNet18 | split file | 63.243 | model file
unbalanced (5:1:1) | VLCS | C, L, S | pascal | ResNet18 | split file | 66.383 | model file
unbalanced (5:1:1) | VLCS | C, L, P | sun | ResNet18 | split file | 55.459 | model file
flexible (5:1:1) | PACS | - | - | ResNet18 | split file | 45.964 | model file
flexible (5:1:1) | VLCS | - | - | ResNet18 | split file | 81.157 | model file
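The unbalanced rows report one accuracy per held-out target domain. A common way to summarize such leave-one-domain-out results is the unweighted mean over target domains; the short sketch below computes it from the table values. These means are simple arithmetic on the table, not figures quoted from the paper, and the flexible rows are left out because they already give a single number per dataset.

```python
# Accuracies (%) copied from the "unbalanced (5:1:1)" rows of the table,
# keyed by target domain.
results = {
    "PACS": {"photo": 94.864, "art_painting": 80.344,
             "cartoon": 74.249, "sketch": 71.046},
    "VLCS": {"caltech": 88.776, "labelme": 63.243,
             "pascal": 66.383, "sun": 55.459},
}


def mean_accuracy(per_domain):
    """Unweighted mean over target domains (each domain counts equally)."""
    return sum(per_domain.values()) / len(per_domain)


averages = {name: mean_accuracy(rows) for name, rows in results.items()}
for name, avg in averages.items():
    print(f"{name}: {avg:.3f}")
```

This prints 80.126 for PACS and 68.465 for VLCS.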

Citing StableNet

If you find this repo useful for your research, please consider citing the paper.

@inproceedings{zhang2021deep,
  title={Deep Stable Learning for Out-Of-Distribution Generalization},
  author={Zhang, Xingxuan and Cui, Peng and Xu, Renzhe and Zhou, Linjun and He, Yue and Shen, Zheyan},
  booktitle={Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition},
  pages={5372--5382},
  year={2021}
}
Comments
  • Seems like a mismatch between implementation code & mathematical formulation in the paper

    My understanding:

    • RFF: random_fourier_features_gpu() adds a new dimension of size num_f rather than performing a dimension transformation.
    • lossb_expect(cfeaturec, num_f) computes the squared Frobenius norm of the original feature vector (of size batch_size * feature_dimension) at each of the num_f dimensions; accumulating these and subtracting the trace of the covariance matrix gives the final loss.

    Question:

    • Does cfeaturec refer to A_i and B_i, and is n the batch size, in Eq. (3)?
    • Does num_f refer to n_A and n_B in Eq. (4)?
    opened by Buhua-Liu 1
  • Some Questions about loss_reweighting.py

    Hello, author. In the cov function of loss_reweighting.py, I think there is a small difference from the definition in the paper. Your code: cov = torch.matmul((w * x).t(), x). Your paper: [formula shown in a screenshot]. Why is the code not cov = torch.matmul((w * x).t(), w * x)? Thank you very much.

    opened by believewhat 0
  • Reproduction about NICO dataset

    Thanks for your great work!

    In your paper, you report results on the NICO dataset, but this repo has no dataset split file for it. I tried to split the dataset according to the description in the paper and then tried to reproduce the experiments with the baseline ResNet-18 and StableNet.

    The results show that the best accuracy of the baseline ResNet-18 is 47.71, while the paper reports 51.71; that gap seems small. However, the best accuracy of StableNet is 48.20, while the paper reports 59.76, which is confusing.

    I know there is some variance due to the randomness of the data split and differences in hyperparameter tuning. Could you please provide the dataset split file for NICO and the recommended hyperparameter settings? Thank you!

    opened by Gaohan123 5
  • Problem with reproduction

    Hi~ Sorry to bother you. I would like to ask about a problem I ran into while reproducing the results: FileNotFoundError: [WinError 3] The system cannot find the path specified: '/DATA/DATANAS1/windxrz/dataset/PACS/split_compositional_with_val_sketch\train'

    I downloaded the split file linked in the README, but still couldn't find this path.

    opened by xia-xia-xia 2
  • About the decay factor $\alpha_i$

    Hi, the paper says that $\alpha_i$ is different for each global information group $Z_{G_i}$ being fused (Eq. 10). However, I found that the $\alpha_i$'s are actually all the same in reweighting.py:

            pre_features = pre_features * args.presave_ratio + cfeatures * (1 - args.presave_ratio)
            pre_weight1 = pre_weight1 * args.presave_ratio + weight * (1 - args.presave_ratio)
    

    Have I missed anything? Thanks.

    opened by albertcity 0
  • covariance matrix with w or without w

    Hello, author. In the cov method of loss_reweighting.py, you define the covariance matrix both with and without w, but the calculation with w does not seem to match the definition in your paper. I do not know whether this is an error in the code or in my understanding.

    opened by ScorpioBao 1
  • can't reproduce the result

    Hi~ Thank you for your excellent work. I've been working on this recently, but I can't exactly reproduce the results in the paper with your open-source code. I partitioned the datasets using the split files and kept all parameters at their defaults. Are the results in your paper averaged over the last few epochs, or are they the best results obtained? And are the experimental parameter settings exactly the same for the PACS and VLCS datasets? Looking forward to your reply!

    opened by challow0 2
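Two of the comments above ask why the weighted covariance in loss_reweighting.py is computed as torch.matmul((w * x).t(), x) rather than torch.matmul((w * x).t(), w * x). The NumPy sketch below is not the repo's code; it only shows what each expression computes, assuming the sample weights are normalized to sum to 1 (an assumption here). The function names are hypothetical. With the weights entering once, the result matches NumPy's standard weighted covariance (numpy.cov with aweights); with the weights entering twice, each sample is effectively weighted by w_i². This does not settle which form the paper's equation intends.

```python
import numpy as np


def weighted_cov_linear(x, w):
    """Weights enter once: sum_i w_i x_i x_i^T - mu mu^T, with mu = sum_i w_i x_i."""
    w = w.reshape(-1, 1)
    mu = np.sum(w * x, axis=0).reshape(-1, 1)
    return (w * x).T @ x - mu @ mu.T


def weighted_cov_squared(x, w):
    """Weights enter twice: sum_i w_i^2 x_i x_i^T - mu mu^T."""
    w = w.reshape(-1, 1)
    mu = np.sum(w * x, axis=0).reshape(-1, 1)
    return (w * x).T @ (w * x) - mu @ mu.T


rng = np.random.default_rng(0)
x = rng.normal(size=(1000, 3))
w = rng.uniform(size=1000)
w /= w.sum()                                   # assumption: weights sum to 1

# Reference: NumPy's weighted covariance (biased, i.e. divided by sum(w)).
reference = np.cov(x, rowvar=False, aweights=w, bias=True)
print(np.allclose(weighted_cov_linear(x, w), reference))   # True
print(np.allclose(weighted_cov_squared(x, w), reference))  # False
```

In the linear form each sample contributes in proportion to w_i, as in a weighted average. In the squared form, uniform weights w_i = 1/n would shrink the covariance of roughly zero-mean data by about a factor of n.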
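The comment on the decay factor $\alpha_i$ contrasts Eq. 10, where each saved global group $Z_{G_i}$ has its own $\alpha_i$, with the quoted code, where a single args.presave_ratio is applied to every group. A minimal NumPy sketch of the per-group version follows, using the same convex combination as the quoted update. The function name fuse_global, the (k, n, d) layout of the saved groups, and the alpha values are hypothetical, and this is not the repo's implementation.

```python
import numpy as np


def fuse_global(pre_features, cfeatures, alphas):
    """Blend each saved global group with the current batch features.

    pre_features: array of shape (k, n, d), the k saved groups Z_G_i
    cfeatures:    array of shape (n, d), features of the current batch
    alphas:       k decay factors, one per group (hypothetical values)
    Returns alpha_i * Z_G_i + (1 - alpha_i) * cfeatures for each group i,
    the same convex combination as the quoted presave_ratio update.
    """
    a = np.asarray(alphas, dtype=float).reshape(-1, 1, 1)  # (k, 1, 1) for broadcasting
    if a.shape[0] != pre_features.shape[0]:
        raise ValueError("need one alpha per saved group")
    return a * pre_features + (1.0 - a) * cfeatures[None, :, :]


# Two saved groups fused with different alphas versus one shared ratio.
pre = np.zeros((2, 4, 3))
cur = np.ones((4, 3))
per_group = fuse_global(pre, cur, alphas=[0.9, 0.5])
shared = fuse_global(pre, cur, alphas=[0.8, 0.8])  # what a single ratio amounts to
```

With a single shared ratio, every group decays at the same rate, so the groups differ only through their history; per-group alphas let them track different time scales. The quoted code updates pre_weight1 with the same ratio, so the same choice would apply to the saved weights.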