Distributionally robust neural networks for group shifts

Overview

Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

This code implements the group DRO algorithm from the following paper:

Shiori Sagawa*, Pang Wei Koh*, Tatsunori Hashimoto, and Percy Liang

Distributionally Robust Neural Networks for Group Shifts: On the Importance of Regularization for Worst-Case Generalization

The experiments use the following datasets:

  • CelebA
  • Waterbirds (constructed from the CUB and Places datasets, as described below)
  • MultiNLI with annotated negations

For an executable, Dockerized version of the experiments in this paper, please see our CodaLab worksheet.

Abstract

Overparameterized neural networks can be highly accurate on average on an i.i.d. test set yet consistently fail on atypical groups of the data (e.g., by learning spurious correlations that hold on average but not in such groups). Distributionally robust optimization (DRO) allows us to learn models that instead minimize the worst-case training loss over a set of pre-defined groups. However, we find that naively applying group DRO to overparameterized neural networks fails: these models can perfectly fit the training data, and any model with vanishing average training loss also already has vanishing worst-case training loss. Instead, their poor worst-case performance arises from poor generalization on some groups. By coupling group DRO models with increased regularization---stronger-than-typical L2 regularization or early stopping---we achieve substantially higher worst-group accuracies, with 10-40 percentage point improvements on a natural language inference task and two image tasks, while maintaining high average accuracies. Our results suggest that regularization is critical for worst-group generalization in the overparameterized regime, even if it is not needed for average generalization. Finally, we introduce and give convergence guarantees for a stochastic optimizer for the group DRO setting, underpinning the empirical study above.

Prerequisites

  • python 3.6.8
  • matplotlib 3.0.3
  • numpy 1.16.2
  • pandas 0.24.2
  • pillow 5.4.1
  • pytorch 1.1.0
  • pytorch_transformers 1.2.0
  • torchvision 0.5.0a0+19315e3
  • tqdm 4.32.2

Datasets and code

To run our code, you will need to change the root_dir variable in data/data.py. The main point of entry to the code is run_expt.py. Below, we provide sample commands for each dataset.
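For example, a minimal sketch of that change (the path below is a placeholder, not a real location):

```python
# data/data.py (sketch): root_dir should point at the directory containing the
# celebA/, cub/, and multinli/ dataset folders described below.
root_dir = '/path/to/datasets'  # placeholder path
```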

CelebA

Our code expects the following files/folders in the [root_dir]/celebA directory:

  • data/list_eval_partition.csv
  • data/list_attr_celeba.csv
  • data/img_align_celeba/

You can download these dataset files from this Kaggle link. The original dataset, due to Liu et al. (2015), can be found here. The version of the CelebA dataset that we use in the paper (with the (hair, gender) groups) can also be accessed through the WILDS package, which will automatically download the dataset.
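For instance, a minimal sketch of fetching CelebA through WILDS (this assumes the wilds package is installed; the root_dir path is a placeholder):

```python
# Sketch: download the CelebA dataset (with the (hair, gender) groups) via WILDS.
from wilds import get_dataset

dataset = get_dataset(dataset="celebA", download=True, root_dir="/path/to/wilds_data")
```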

A sample command to run group DRO on CelebA is:

    python run_expt.py -s confounder -d CelebA -t Blond_Hair -c Male --lr 0.0001 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 50 --reweight_groups --robust --gamma 0.1 --generalization_adjustment 0

Waterbirds

The Waterbirds dataset is constructed by cropping out birds from photos in the Caltech-UCSD Birds-200-2011 (CUB) dataset (Wah et al., 2011) and transferring them onto backgrounds from the Places dataset (Zhou et al., 2017).

Our code expects the following files/folders in the [root_dir]/cub directory:

  • data/waterbird_complete95_forest2water2/

You can download a tarball of this dataset here. The Waterbirds dataset can also be accessed through the WILDS package, which will automatically download the dataset.
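As with CelebA, a one-line sketch of fetching Waterbirds through WILDS (the root_dir path is a placeholder; the directory layout WILDS produces may differ from the waterbird_complete95_forest2water2/ folder this repository expects, so you may need to move or symlink the data):

```python
from wilds import get_dataset

dataset = get_dataset(dataset="waterbirds", download=True, root_dir="/path/to/wilds_data")
```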

A sample command to run group DRO on Waterbirds is:

    python run_expt.py -s confounder -d CUB -t waterbird_complete95 -c forest2water2 --lr 0.001 --batch_size 128 --weight_decay 0.0001 --model resnet50 --n_epochs 300 --reweight_groups --robust --gamma 0.1 --generalization_adjustment 0

Note that compared to the training set, the validation and test sets are constructed with different proportions of each group. We describe this in more detail in Appendix C.1 of our paper, which we reproduce here for convenience:

We use the official train-test split of the CUB dataset, randomly choosing 20% of the training data to serve as a validation set. For the validation and test sets, we distribute landbirds and waterbirds equally to land and water backgrounds (i.e., there are the same number of landbirds on land vs. water backgrounds, and separately, the same number of waterbirds on land vs. water backgrounds). This allows us to more accurately measure the performance of the rare groups, and it is particularly important for the Waterbirds dataset because of its relatively small size; otherwise, the smaller groups (waterbirds on land and landbirds on water) would have too few samples to accurately estimate performance on. We note that we can only do this for the Waterbirds dataset because we control the generation process; for the other datasets, we cannot generate more samples from the rare groups.

In a typical application, the validation set might be constructed by randomly dividing up the available training data. We emphasize that this is not the case here: the training set is skewed, whereas the validation set is more balanced. We followed this construction so that we could better compare ERM vs. reweighting vs. group DRO techniques using a stable set of hyperparameters. In practice, if the validation set were also skewed, we might expect hyperparameter tuning based on worst-group accuracy to be more challenging and noisy.

Due to the above procedure, when reporting average test accuracy in our experiments, we calculate the average test accuracy over each group and then report a weighted average, with weights corresponding to the relative proportion of each group in the (skewed) training dataset.
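Concretely, a minimal sketch of that weighted average; the group accuracies and training proportions below are placeholder values, not the actual Waterbirds numbers:

```python
import numpy as np

# Per-group test accuracies for (landbird, land), (landbird, water),
# (waterbird, land), (waterbird, water). Placeholder values.
group_test_acc = np.array([0.95, 0.70, 0.75, 0.92])

# Relative proportion of each group in the (skewed) training set. Placeholder values.
train_group_frac = np.array([0.73, 0.04, 0.01, 0.22])
train_group_frac = train_group_frac / train_group_frac.sum()

# Reported average test accuracy: per-group accuracies weighted by training proportions.
avg_test_acc = float(group_test_acc @ train_group_frac)
print(avg_test_acc)
```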

If you'd like to generate variants of this dataset, we have included the script we used to generate this dataset (from the CUB and Places datasets) in dataset_scripts/generate_waterbirds.py. Note that running this script will not create the exact dataset we provide above, due to random seed differences. You will need to download the CUB dataset as well as the Places dataset. We use the high-resolution training images (MD5: 67e186b496a84c929568076ed01a8aa1) from Places. Once you have downloaded and extracted these datasets, edit the corresponding paths in generate_waterbirds.py.
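The edits amount to pointing the script at your local copies of the source datasets; a sketch with hypothetical variable names (check dataset_scripts/generate_waterbirds.py for the actual ones):

```python
# Hypothetical names for the paths to edit in generate_waterbirds.py.
cub_dir = '/path/to/CUB_200_2011'     # extracted CUB dataset
places_dir = '/path/to/places365'     # extracted high-resolution Places training images
output_dir = '/path/to/output'        # where the generated Waterbirds variant will be written
```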

MultiNLI with annotated negations

Our code expects the following files/folders in the [root_dir]/multinli directory:

  • data/metadata_random.csv
  • glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli
  • glue_data/MNLI/cached_dev_bert-base-uncased_128_mnli-mm
  • glue_data/MNLI/cached_train_bert-base-uncased_128_mnli

We have included the metadata file in dataset_metadata/multinli in this repository. The metadata file records whether each example belongs to the train/val/test dataset as well as whether it contains a negation word.
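As an illustration, a minimal sketch of inspecting the metadata with pandas. The label and confounder column names follow the -t and -c arguments of the sample command below, while the split column name is an assumption, so check the file header if it differs:

```python
import pandas as pd

# Placeholder path; point this at [root_dir]/multinli/data/metadata_random.csv.
metadata = pd.read_csv("/path/to/multinli/data/metadata_random.csv")

print(metadata.columns)  # inspect the actual column names

# Number of examples in each (label, negation) group.
print(metadata.groupby(["gold_label_random", "sentence2_has_negation"]).size())
```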

The glue_data/MNLI files are generated by the huggingface Transformers library and can be downloaded here.

A sample command to run group DRO on MultiNLI is:

    python run_expt.py -s confounder -d MultiNLI -t gold_label_random -c sentence2_has_negation --lr 2e-05 --batch_size 32 --weight_decay 0 --model bert --n_epochs 3 --reweight_groups --robust --generalization_adjustment 0

We created our own train/val/test split of the MultiNLI dataset, as described in Appendix C.1 of our paper:

The standard MultiNLI train-test split allocates most examples (approximately 90%) to the training set, with another 5% as a publicly-available development set and the last 5% as a held-out test set that is only accessible through online competition leaderboards (Williams et al., 2018). To accurately estimate performance on rare groups in the validation and test sets, we combine the training set and development set and then randomly resplit them into a 50-20-30 train-val-test split that allocates more examples to the validation and test sets than the standard split.
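For illustration, a minimal sketch of such a resplit (the example count and random seed are placeholders):

```python
import numpy as np

rng = np.random.default_rng(0)   # placeholder seed
n_examples = 100_000             # placeholder count
perm = rng.permutation(n_examples)

n_train = int(0.5 * n_examples)
n_val = int(0.2 * n_examples)

# Assign each example to train (50%), val (20%), or test (30%).
split = np.empty(n_examples, dtype=object)
split[perm[:n_train]] = "train"
split[perm[n_train:n_train + n_val]] = "val"
split[perm[n_train + n_val:]] = "test"
```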

If you'd like to modify the metadata file (e.g., to consider confounders other than the presence of negation words), we have included the script we used to generate the metadata file in dataset_scripts/generate_multinli.py. Note that running this script will not create the exact dataset we provide above, due to random seed differences. You will need to download the MultiNLI dataset and edit the paths in that script accordingly.

Comments
  • Where is the random group chosen?

    First, thanks for the nice work! In your paper you show the following: [screenshot of the algorithm from the paper]

    I am having difficulty finding the part of your code that corresponds to randomly picking a group, g ~ Uniform(1, ..., m). Could you tell me where I can find it?

    opened by chanshing 3
  • not being able to reproduce your results on CUB and CelebA

    With the commands in the repo, I am not able to reproduce your results. Could you send me the commands that reproduce your best results on CUB and CelebA, especially the runs with large weight decay and early stopping? Thanks a lot.

    opened by linyongver 2
  • not being able to reproduce your results on MNLI

    Hi, with the command in the repo for MNLI, I am not able to reproduce your results. Could you send me the command that reproduces your best results on MNLI? It seems that some arguments are missing from that command. Thanks.

    opened by ghost 2
  • Codalab worksheet link is broken

    The link to the Codalab worksheet is broken. The page shows "Not found: '/worksheets/0x621811fe446b49bb818293bae2ef88c0'." Could you please update it? Thank you!

    opened by zhihengli-UR 1
  • ERM should also save best model based on worst-group accuracy

    Hi @kohpangwei and @ssagawa ,

    Your paper mentioned that "All benchmark models were evaluated at the best early stopping epoch (as measured by robust validation accuracy)." However, your code https://github.com/kohpangwei/group_DRO/blob/ca58872bd5a7b5fe90c35d4c39504babe76e3532/train.py#L205-L209 indicates that (i) for ERM, the best model is determined by the average validation accuracy, and (ii) for reweighting and group DRO, the best model is determined by the worst-group accuracy. I think that for a fair comparison, the model selection rule in (i) should be changed to be identical to (ii); what do you think?

    Did you use the model selection rule (i) for ERM in your paper's experiments (e.g., Table 3)? I'm trying to reproduce your results, but I'm not sure if your results on ERM are from (i) or not.

    opened by Haoxiang-Wang 1
  • Algorithm in paper and random groups per batch

    Following up on the previous question (https://github.com/kohpangwei/group_DRO/issues/7): could you clarify whether we need to sample only one group at each iteration, or whether it is OK to have multiple groups in a batch? The algorithm in the paper seems to say that only one group is sampled per iteration, but this doesn't seem to be the case in the code.

    [screenshot of the algorithm from the paper]

    Additionally, could you comment on the following remark from this paper: https://arxiv.org/abs/2010.12230?

    [screenshot of the remark from that paper]

    Is this remark justified?

    opened by chanshing 1
  • Invalid link

    Hi,

    I cannot open the link that you provide (https://nlp.stanford.edu/data/dro/waterbird_complete95_forest2water2.tar.gz) for the Waterbirds dataset. Could you create a new link to access the dataset? Thanks a lot!

    opened by Newbeeer 1
  • Why is reweight_groups flag set for DRO algorithm? Possible unfair comparison to ERM?

    Dear authors, thank you for sharing a well-polished codebase!

    For the results in Table 1, I have noticed that the DRO method is always run with the "reweight_groups" flag set to "True", whereas the same flag is "False" for the ERM algorithm [1]. As per the code, the "reweight_groups" flag performs weighted random sampling so that each group is equally represented in any given batch. On the other hand, ERM receives fewer minority-group samples, as there is no weighted random sampling. Such a difference in implementation between ERM and GroupDRO suggests an unfair comparison between the two methods.

    Surely, as pointed out in the comment [2], the loss function could be considered unaffected by the "reweight_groups" flag, since the DRO method uses the mean of the per-group losses. However, the empirical estimate of these means in a given batch would be highly noisy when the sample count of the minority group is very small. This makes me wonder (and I hope it's okay for me to ask) whether the gains reported in the paper should be attributed solely to the weighted random sampling procedure rather than to the DRO update rule. Could you please clarify?

    Do you have any comparisons of the DRO algorithm with the "reweight_groups" flag set to "False"? How does ERM with "reweight_groups=True" compare to ERM with "reweight_groups=False"?

    Thank you

    [1] https://worksheets.codalab.org/worksheets/0x621811fe446b49bb818293bae2ef88c0 [2] https://github.com/kohpangwei/group_DRO/blob/f7eae929bf4f9b3c381fae6b1b53ab4c6c911a0e/data/dro_dataset.py#L56

    opened by lokhande-vishnu 1
  • generating MNLI glue files

    Hi, could you provide the command to regenerate your cached MNLI files? Also, is it possible to run the code from the raw MNLI text data? Thanks.

    opened by ghost 1
  • Could you provide the csv files?

    Hi,

    Could you provide these files?

    • list_attr_celeba.csv
    • list_eval_partition.csv

    It would be easier to just use what you already have, instead of converting them on my own.

    Edit: For now, I will be using these: https://raw.githubusercontent.com/togheppi/cDCGAN/master/list_attr_celeba.csv https://raw.githubusercontent.com/Golbstein/keras-face-recognition/master/list_eval_partition.csv

    You could just close this issue.

    opened by erobic 1
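Several of the comments above ask how groups are sampled and weighted during training (g ~ Uniform(1, ..., m) in the paper's algorithm vs. mixed-group minibatches in the code). As a rough illustration only, here is a minimal sketch of a group DRO weighting step on a minibatch that may contain several groups: per-group average losses, an exponentiated-gradient update of the group weights q, and a q-weighted loss to backpropagate. The function and variable names are ours, and this is not the repository's exact implementation.

```python
import torch
import torch.nn.functional as F

def group_dro_loss(losses, group_idx, q, eta_q):
    """Sketch of one group DRO step on a minibatch with (possibly) multiple groups.

    losses:    per-example losses, shape (batch,), differentiable
    group_idx: integer group id for each example, shape (batch,)
    q:         current weights over the n_groups groups, shape (n_groups,), sums to 1
    eta_q:     step size for the group-weight update
    """
    n_groups = q.numel()

    # Average loss of each group that appears in this batch.
    group_mask = F.one_hot(group_idx, n_groups).float()           # (batch, n_groups)
    group_count = group_mask.sum(dim=0)                           # examples per group
    group_loss = group_mask.t() @ losses / group_count.clamp(min=1)

    # Exponentiated-gradient update: higher-loss groups get more weight.
    q = q * torch.exp(eta_q * group_loss.detach())
    q = q / q.sum()

    # Loss to backpropagate: per-group losses weighted by q.
    robust_loss = q @ group_loss
    return robust_loss, q
```

In a training loop you would call this at every step, backpropagate through robust_loss, and pass the returned q back in at the next step; sampling batches with the reweight_groups option makes it more likely that every group is represented in each batch.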