Official Implementation of SWAD (NeurIPS 2021)

Overview

SWAD: Domain Generalization by Seeking Flat Minima (NeurIPS'21)

Official PyTorch implementation of SWAD: Domain Generalization by Seeking Flat Minima.

Junbum Cha, Sanghyuk Chun, Kyungjae Lee, Han-Cheol Cho, Seunghyun Park, Yunsung Lee, Sungrae Park.

Note that this project is built upon DomainBed@3fe9d7.

Preparation

Dependencies

pip install -r requirements.txt

Datasets

python -m domainbed.scripts.download --data_dir=/my/datasets/path

Environments

Environment details used for our study.

Python: 3.8.6
PyTorch: 1.7.0+cu92
Torchvision: 0.8.1+cu92
CUDA: 9.2
CUDNN: 7603
NumPy: 1.19.4
PIL: 8.0.1

How to Run

The train_all.py script conducts leave-one-out cross-validation for every target domain.

python train_all.py exp_name --dataset PACS --data_dir /my/datasets/path
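
Here, leave-one-out means each domain takes a turn as the held-out target while the model trains on the remaining domains. A minimal illustrative sketch (not the repo's code) for PACS:

# Illustrative sketch (not the repo's code): leave-one-out over the PACS domains.
domains = ["art_painting", "cartoon", "photo", "sketch"]

for target in domains:
    sources = [d for d in domains if d != target]
    # train_all.py trains a model on `sources` and reports accuracy on `target`
    print(f"train on {sources} -> evaluate on {target}")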

Experiment results are reported as a table. In the table, the row SWAD indicates out-of-domain accuracy from SWAD. The row SWAD (inD) indicates in-domain validation accuracy.

Example results:

+------------+--------------+---------+---------+---------+---------+
| Selection  | art_painting | cartoon |  photo  |  sketch |   Avg.  |
+------------+--------------+---------+---------+---------+---------+
|   oracle   |   82.245%    | 85.661% | 97.530% | 83.461% | 87.224% |
|    iid     |   87.919%    | 78.891% | 96.482% | 78.435% | 85.432% |
|    last    |   82.306%    | 81.823% | 95.135% | 82.061% | 85.331% |
| last (inD) |   95.807%    | 95.291% | 96.306% | 95.477% | 95.720% |
| iid (inD)  |   97.275%    | 96.619% | 96.696% | 97.253% | 96.961% |
|    SWAD    |   89.750%    | 82.942% | 97.979% | 81.870% | 88.135% |
| SWAD (inD) |   97.713%    | 97.649% | 97.316% | 98.074% | 97.688% |
+------------+--------------+---------+---------+---------+---------+

In this example, the DG performance of SWAD on the PACS dataset is 88.135%.

If you set the indomain_test option to True, the validation set is split into validation and test sets, and the (inD) rows then report in-domain test accuracy.
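
For example, an invocation might look like the following (the exact flag syntax is an assumption; check the argument parsing in train_all.py):

python train_all.py exp_name --dataset PACS --indomain_test True --data_dir /my/datasets/path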

Reproduce the results of the paper

We provide instructions to reproduce the main results of the paper (Tables 1 and 2). Note that differences in environment details or uncontrolled randomness may yield slightly different results from those in the paper.

  • PACS
python train_all.py PACS0 --dataset PACS --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS1 --dataset PACS --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py PACS2 --dataset PACS --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • VLCS
python train_all.py VLCS0 --dataset VLCS --deterministic --trial_seed 0 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS1 --dataset VLCS --deterministic --trial_seed 1 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
python train_all.py VLCS2 --dataset VLCS --deterministic --trial_seed 2 --checkpoint_freq 50 --tolerance_ratio 0.2 --data_dir /my/datasets/path
  • OfficeHome
python train_all.py OH0 --dataset OfficeHome --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH1 --dataset OfficeHome --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py OH2 --dataset OfficeHome --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • TerraIncognita
python train_all.py TR0 --dataset TerraIncognita --deterministic --trial_seed 0 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR1 --dataset TerraIncognita --deterministic --trial_seed 1 --checkpoint_freq 100 --data_dir /my/datasets/path
python train_all.py TR2 --dataset TerraIncognita --deterministic --trial_seed 2 --checkpoint_freq 100 --data_dir /my/datasets/path
  • DomainNet
python train_all.py DN0 --dataset DomainNet --deterministic --trial_seed 0 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN1 --dataset DomainNet --deterministic --trial_seed 1 --checkpoint_freq 500 --data_dir /my/datasets/path
python train_all.py DN2 --dataset DomainNet --deterministic --trial_seed 2 --checkpoint_freq 500 --data_dir /my/datasets/path

Main Results

Citation

The paper will be published at NeurIPS 2021.

@inproceedings{cha2021swad,
  title={SWAD: Domain Generalization by Seeking Flat Minima},
  author={Cha, Junbum and Chun, Sanghyuk and Lee, Kyungjae and Cho, Han-Cheol and Park, Seunghyun and Lee, Yunsung and Park, Sungrae},
  booktitle={Advances in Neural Information Processing Systems (NeurIPS)},
  year={2021}
}

License

This source code is released under the MIT license, included here.

This project includes some code from DomainBed, also MIT licensed.

Comments
  • Question about how hyperparameters change as test domain and dataset changes

    Thank you for your great work.

    I'm a little curious about how the hyperparameters change as the test domain changes in your code.

    I saw the instructions you provided in the README for reproducing your results, but I don't understand how the HPs change as the test domain changes.

    In your instructions (the command lines for reproducing results), no arguments for HPs are given, so I think all the experiments will use the default HPs regardless of test domain and dataset.

    I'm just curious how the HPs change as the test domain and dataset change in your instructions.

    Thank you.


    opened by wlaud1001 6
  • How to draw flatness curve in Figure 3?

    Hi, thank you so much for providing this repo; the work is awesome! How can we reproduce the loss-gap curve in Figure 3 of the paper? How should the perturbation gamma be added to the model parameters, and what is the distance metric on the x-axis? I flattened the model parameter dict into one vector, added a noise vector with norm 1.0, and got a loss gap of about 0.2 on the P domain test, so I must have made a mistake in the Monte-Carlo approximation sampling. Thanks a lot!
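
    For reference, a minimal sketch (not the authors' code) of the Monte-Carlo loss-gap estimate described above, where gamma is the perturbation norm plotted on the x-axis; loss_fn and loader stand in for an ordinary criterion and evaluation DataLoader:

    import copy
    import torch

    def monte_carlo_loss_gap(model, loss_fn, loader, gamma, n_samples=10, device="cpu"):
        # Flatten the current weights into a single vector.
        base = torch.nn.utils.parameters_to_vector(model.parameters()).detach()
        probe = copy.deepcopy(model)

        def avg_loss(m):
            m.eval()
            total, n = 0.0, 0
            with torch.no_grad():
                for x, y in loader:
                    x, y = x.to(device), y.to(device)
                    total += loss_fn(m(x), y).item() * x.size(0)
                    n += x.size(0)
            return total / n

        base_loss = avg_loss(model)
        gaps = []
        for _ in range(n_samples):
            # Sample a random direction and rescale it to norm gamma.
            direction = torch.randn_like(base)
            direction = direction / direction.norm() * gamma
            torch.nn.utils.vector_to_parameters(base + direction, probe.parameters())
            gaps.append(avg_loss(probe) - base_loss)
        # Monte-Carlo estimate of E[L(w + d) - L(w)] over directions with ||d|| = gamma.
        return sum(gaps) / len(gaps)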

    opened by FrankZhangRp 5
  • How to train on multi-GPU?

    Hello, because my hardware is limited, there is not enough memory to train on the DomainNet dataset, so parallel multi-GPU training is required, but it cannot be done with the code algorithm = torch.nn.DataParallel(algorithm, device_ids=range(torch.cuda.device_count())). Do you have any good suggestions? Thanks!
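
    One common workaround, sketched under the assumption that the algorithm object is a DomainBed-style module holding its model at .network (this is not the repo's official solution), is to wrap only the inner network, since DataParallel replicates forward() but hides methods like update():

    import torch
    import torch.nn as nn

    # Stand-in for a DomainBed-style algorithm that keeps its model at .network.
    class ERM(nn.Module):
        def __init__(self):
            super().__init__()
            self.network = nn.Linear(512, 7)

    algorithm = ERM()

    # Wrap only the inner network; the algorithm's own methods stay reachable.
    if torch.cuda.device_count() > 1:
        algorithm.network = nn.DataParallel(
            algorithm.network, device_ids=range(torch.cuda.device_count())
        )
    if torch.cuda.is_available():
        algorithm.cuda()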

    opened by shuangliumax 4
  • Hyperparameter search protocol

    Hi,

    Thanks for providing this great repo! I have one question regarding the hyperparameter search protocol. In section B.2 of the paper (Table 7), you indicate that the search space is constrained compared to the original DomainBed, but there are still random choices of learning rate, weight decay, etc. However, if I understand correctly, the current implementation only uses the default hparams without randomness involved; the only randomness comes from different trials. Is this correct? If so, how were the original experiments in the paper conducted?

    Further, how is this number 396 computed?

    Through the proposed protocol, we find HP for an algorithm under only 396 runs.

    Thanks in advance.

    opened by optharry 4
  • Can't reproduce the result of ERM.

    I trained PACS and OfficeHome once, and the results are:

    OfficeHome: Last: 62.4, SWAD: 70.8, Oracle: 66.3
    PACS: Last: 82.2, SWAD: 88.3, Oracle: 88.1

    Is the Last result ERM? It is not consistent with the results in your paper. So is the ERM result the Oracle one?

    opened by justopit 3
  • Question about the table

    Hi, I want to know if every column indicates the leave-one-out cross-validation accuracy for that target domain. Also, oracle uses a small split of the test domain as the validation set (test-domain validation). Isn't this contradictory to leave-one-out?

    opened by Cassiatora 2
  • Are all algorithms on DomainNet reproduced with 15k iterations?

    Hi, thanks for providing such great work. I noticed that the MIRO and SWAD results on DomainNet increase when going from 5k to 15k iterations. I am curious whether all the other algorithms in this table were also reproduced with 15k iterations.

    opened by Luodian 2
  • Loss surface plots

    Loss surface plots

    Hi,

    Thank you so much for providing this repo, the code looks great.

    How can one reproduce the loss surface plots, e.g., Figure 4 in the paper?

    Did you use this code here?
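
    Such plots are commonly drawn by evaluating the loss along interpolations between flattened weight vectors. A minimal one-dimensional sketch (an assumption about the general method, not the authors' confirmed code):

    import torch

    def losses_on_line(model, loss_fn, loader, w0, w1, steps=21, device="cpu"):
        # w0 and w1 are flattened weight vectors, e.g. obtained with
        # torch.nn.utils.parameters_to_vector on two checkpoints.
        losses = []
        for alpha in torch.linspace(0.0, 1.0, steps):
            torch.nn.utils.vector_to_parameters(
                (1 - alpha) * w0 + alpha * w1, model.parameters()
            )
            model.eval()
            total, n = 0.0, 0
            with torch.no_grad():
                for x, y in loader:
                    x, y = x.to(device), y.to(device)
                    total += loss_fn(model(x), y).item() * x.size(0)
                    n += x.size(0)
            losses.append(total / n)
        return losses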

    opened by JeanKaddour 2
  • about terra_incognita dataset

    Hello, when I downloaded the terra_incognita dataset according to download.py and tried to train on it, the following error occurred: RuntimeError: CUDA error: device-side assert triggered, with ClassNLLCriterion.cu:108: cunn_ClassNLLCriterion_updateOutput_kernel: Assertion t >= 0 && t < n_classes failed repeated for every thread. I checked and the error may be a problem with the labels, but I don't know how to correct it. Do you have any good suggestions? Thank you!
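
    The assertion t >= 0 && t < n_classes fires when a target label falls outside the classifier head's range, so a quick check is to scan the dataset's labels before training. A diagnostic sketch (the folder path and layout here are assumptions; adjust to your setup):

    from torchvision.datasets import ImageFolder

    # Hypothetical domain folder; adjust to your actual dataset layout.
    root = "/my/datasets/path/terra_incognita/location_100"
    dataset = ImageFolder(root)
    labels = [label for _, label in dataset.samples]
    print("num classes:", len(dataset.classes))
    print("label range:", min(labels), "to", max(labels))
    # If max(labels) >= the number of outputs in the model's classifier head,
    # the head was built with too few classes (or the download is incomplete).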

    opened by shuangliumax 2
  • Questions about oracle metric

    Dear authors,

    I'm not familiar with domain generalization. In the reported table, you have also provided the oracle, iid, and last metrics. What do those metrics mean? Could you give some explanation?

    Thanks in advance.

    Regards,

    opened by zhihou7 2
  • download_office_home link error

    Hi,

    The link to download OfficeHome is forbidden. Is the dataset downloaded from the original URL the same as the one downloaded from your provided URL?

    Best regards.

    opened by ShijianXu 1
Owner
Junbum Cha
Official Pytorch implementation of "Unbiased Classification Through Bias-Contrastive and Bias-Balanced Learning" (NeurIPS 2021)

Unbiased Classification Through Bias-Contrastive and Bias-Balanced Learning (NeurIPS 2021) Official Pytorch implementation of Unbiased Classification

Youngkyu 17 Jan 1, 2023
This is an official PyTorch implementation of Task-Adaptive Neural Network Search with Meta-Contrastive Learning (NeurIPS 2021, Spotlight).

NeurIPS 2021 (Spotlight): Task-Adaptive Neural Network Search with Meta-Contrastive Learning This is an official PyTorch implementation of Task-Adapti

Wonyong Jeong 15 Nov 21, 2022
Official implementation of "Open-set Label Noise Can Improve Robustness Against Inherent Label Noise" (NeurIPS 2021)

Open-set Label Noise Can Improve Robustness Against Inherent Label Noise NeurIPS 2021: This repository is the official implementation of ODNL. Require

Hongxin Wei 12 Dec 7, 2022
Official implementation of Generalized Data Weighting via Class-level Gradient Manipulation (NeurIPS 2021).

Generalized Data Weighting via Class-level Gradient Manipulation This repository is the official implementation of Generalized Data Weighting via Clas

null 9 Nov 3, 2021
The official implementation of NeurIPS 2021 paper: Finding Optimal Tangent Points for Reducing Distortions of Hard-label Attacks

The official implementation of NeurIPS 2021 paper: Finding Optimal Tangent Points for Reducing Distortions of Hard-label Attacks

machen 11 Nov 27, 2022
Official implementation of Neural Bellman-Ford Networks (NeurIPS 2021)

NBFNet: Neural Bellman-Ford Networks This is the official codebase of the paper Neural Bellman-Ford Networks: A General Graph Neural Network Framework

MilaGraph 136 Dec 21, 2022
Official Pytorch implementation for Deep Contextual Video Compression, NeurIPS 2021

Introduction Official Pytorch implementation for Deep Contextual Video Compression, NeurIPS 2021 Prerequisites Python 3.8 and conda, get Conda CUDA 11

null 51 Dec 3, 2022
Official implementation of NeurIPS'2021 paper TransformerFusion

TransformerFusion: Monocular RGB Scene Reconstruction using Transformers Project Page | Paper | Video TransformerFusion: Monocular RGB Scene Reconstru

Aljaz Bozic 118 Dec 25, 2022
Official code for On Path Integration of Grid Cells: Group Representation and Isotropic Scaling (NeurIPS 2021)

On Path Integration of Grid Cells: Group Representation and Isotropic Scaling This repo contains the official implementation for the paper On Path Int

Ruiqi Gao 39 Nov 10, 2022
Official implementation of "GS-WGAN: A Gradient-Sanitized Approach for Learning Differentially Private Generators" (NeurIPS 2020)

GS-WGAN This repository contains the implementation for GS-WGAN: A Gradient-Sanitized Approach for Learning Differentially Private Generators (NeurIPS

null 46 Nov 9, 2022
Official implementation for Likelihood Regret: An Out-of-Distribution Detection Score For Variational Auto-encoder at NeurIPS 2020

Likelihood-Regret Official implementation of Likelihood Regret: An Out-of-Distribution Detection Score For Variational Auto-encoder at NeurIPS 2020. T

Xavier 33 Oct 12, 2022
Official Pytorch implementation of 'GOCor: Bringing Globally Optimized Correspondence Volumes into Your Neural Network' (NeurIPS 2020)

Official implementation of GOCor This is the official implementation of our paper : GOCor: Bringing Globally Optimized Correspondence Volumes into You

Prune Truong 71 Nov 18, 2022
Official Implementation of Swapping Autoencoder for Deep Image Manipulation (NeurIPS 2020)

Swapping Autoencoder for Deep Image Manipulation Taesung Park, Jun-Yan Zhu, Oliver Wang, Jingwan Lu, Eli Shechtman, Alexei A. Efros, Richard Zhang UC

null 449 Dec 27, 2022
Pytorch implementation of RED-SDS (NeurIPS 2021).

Recurrent Explicit Duration Switching Dynamical Systems (RED-SDS) This repository contains a reference implementation of RED-SDS, a non-linear state s

Abdul Fatir 10 Dec 2, 2022
The PyTorch implementation of Directed Graph Contrastive Learning (DiGCL), NeurIPS-2021

Directed Graph Contrastive Learning The PyTorch implementation of Directed Graph Contrastive Learning (DiGCL). In this paper, we present the first con

Tong Zekun 28 Jan 8, 2023
PyTorch implementation of NeurIPS 2021 paper: "CoFiNet: Reliable Coarse-to-fine Correspondences for Robust Point Cloud Registration"

PyTorch implementation of NeurIPS 2021 paper: "CoFiNet: Reliable Coarse-to-fine Correspondences for Robust Point Cloud Registration"

null 76 Jan 3, 2023
A tensorflow=1.13 implementation of Deconvolutional Networks on Graph Data (NeurIPS 2021)

GDN A tensorflow=1.13 implementation of Deconvolutional Networks on Graph Data (NeurIPS 2021) Abstract In this paper, we consider an inverse problem i

null 4 Sep 13, 2022
PyTorch implementation for our NeurIPS 2021 Spotlight paper "Long Short-Term Transformer for Online Action Detection".

Long Short-Term Transformer for Online Action Detection Introduction This is a PyTorch implementation for our NeurIPS 2021 Spotlight paper "Long Short

null 77 Dec 16, 2022
[NeurIPS 2020] Official repository for the project "Listening to Sound of Silence for Speech Denoising"

Listening to Sounds of Silence for Speech Denoising Introduction This is the repository of the "Listening to Sounds of Silence for Speech Denoising" p

Henry Xu 40 Dec 20, 2022