Learning Causal Semantic Representation for Out-of-Distribution Prediction
This repository is the official implementation of "Learning Causal Semantic Representation for Out-of-Distribution Prediction" (NeurIPS 2021).
Chang Liu, Xinwei Sun, Jindong Wang, Haoyue Tang, Tao Li, Tao Qin, Wei Chen, Tie-Yan Liu.
[Paper & Appendix] [Slides] [Video] [Poster]
Introduction
The work proposes a Causal Semantic Generative model (CSG) for OOD generalization (single-source domain generalization) and domain adaptation. The model is developed following a causal reasoning process, and prediction is made by leveraging the causal invariance principle. Training and prediction algorithms are developed based on variational Bayes with a novel design. Theoretical guarantees on the identifiability of the causal factor and the benefits for OOD prediction are presented.
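The causal invariance principle can be illustrated with a small hypothetical toy example. This is not the CSG model, its variational-Bayes training, or code from this repository. In the toy, a semantic factor s causes the label y, and a variation factor v is correlated with s with a domain-specific strength `rho`. The observation x is generated from (s, v) by a fixed mechanism. A predictor that uses only s keeps its accuracy when the s-v correlation changes across domains, while one that exploits v does not. In CSG, s and v are latent and have to be learned. Here they are recovered exactly by inverting a known linear mechanism, a simplification that keeps the sketch short.

```python
import random


def sample_domain(n, rho, seed=0):
    """Draw n examples from one toy domain (hypothetical setup, loosely CSG-shaped).

    - s (semantic factor) causes y through a mechanism shared by all domains;
    - v (variation factor) is correlated with s with a domain-specific strength rho;
    - x is generated from (s, v) by a fixed mechanism.
    """
    rng = random.Random(seed)
    data = []
    for _ in range(n):
        s = rng.gauss(0.0, 1.0)
        # Corr(s, v) = rho; only this prior over (s, v) changes between domains.
        v = rho * s + (1.0 - rho ** 2) ** 0.5 * rng.gauss(0.0, 1.0)
        x = (s + v, s - v)  # fixed generative mechanism x <- (s, v)
        y = int(s > 0)      # fixed causal mechanism y <- s
        data.append((x, y))
    return data


def predict_with_semantic(x):
    # Invert the known toy mechanism and use only s, the cause of y.
    s = (x[0] + x[1]) / 2
    return int(s > 0)


def predict_with_variation(x):
    # Use only v, which predicts y merely through its domain-specific correlation with s.
    v = (x[0] - x[1]) / 2
    return int(v > 0)


def accuracy(predict, data):
    return sum(predict(x) == y for x, y in data) / len(data)


train = sample_domain(10_000, rho=0.9, seed=0)
test = sample_domain(10_000, rho=-0.9, seed=1)  # the s-v correlation flips at test time
for name, f in [("semantic", predict_with_semantic), ("variation", predict_with_variation)]:
    print(f"{name:9s} train acc {accuracy(f, train):.3f}  test acc {accuracy(f, test):.3f}")
```

With `rho = 0.9`, the sign of v agrees with the sign of s about 86% of the time. The variation-based predictor therefore looks good on the training domain and drops to roughly 14% once the correlation flips. The semantic predictor is unaffected.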
This codebase implements the CSG methods, and implements or integrates various baselines. Most domain adaptation baselines (except BNM) use the dalib package. The experimental setups for the PACS and VLCS datasets are adopted from the domainbed repository. Authorship is stated in each file or module.
Requirements
The code requires Python 3.6 or later and is based on PyTorch. To install the requirements:
pip install -r requirements.txt
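Before running the experiments, a quick pre-flight check can confirm the two stated requirements. The sketch below is a hypothetical helper and not part of the repository. It checks only the Python version and whether a `torch` package is importable. It does not verify the specific package versions pinned in `requirements.txt`.

```python
import importlib.util
import sys


def check_environment(min_python=(3, 6)):
    """Return (python_ok, torch_installed) for the requirements stated in this README."""
    python_ok = sys.version_info[:2] >= tuple(min_python)
    # find_spec returns None when the package cannot be found; it does not import torch.
    torch_installed = importlib.util.find_spec("torch") is not None
    return python_ok, torch_installed


if __name__ == "__main__":
    ok, has_torch = check_environment()
    print(f"Python {sys.version.split()[0]}: {'ok' if ok else 'too old (need >= 3.6)'}")
    print(f"PyTorch: {'found' if has_torch else 'not found (see requirements.txt)'}")
```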
Usage
Folder a-mnist contains scripts to run the experiments on the Shifted-MNIST dataset, a-imageclef on the ImageCLEF-DA dataset, and a-domainbed on the PACS and VLCS datasets (the prefix a- stands for "application").
Go to the respective folder and run the prepare_data.sh or makedata.sh script there to prepare the datasets. Then run the run_ood.sh (for OOD generalization methods) or run_da.sh (for domain adaptation methods) script to train the models. Evaluation results (accuracy on the test domain) are printed and written to disk along with the model and configurations. For customized usage or hyperparameter tuning, see the commands in the script files or run python3 main.py --help.
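The folder layout and workflow above can be scripted. The following is a minimal sketch of a hypothetical driver (`plan` and `run` are illustrative names, not part of the repository). It makes three assumptions: the scripts are invoked as `bash <script>` from inside their folder, each folder contains one of the two data-preparation scripts named in this README, and data preparation is simply rerun before each training run. By default it only prints the commands (`dry_run=True`).

```python
import subprocess
from pathlib import Path

# Folder -> datasets, as listed in this README (the prefix "a-" stands for "application").
APPLICATIONS = {
    "a-mnist": ["Shifted-MNIST"],
    "a-imageclef": ["ImageCLEF-DA"],
    "a-domainbed": ["PACS", "VLCS"],
}
# The README names two data-preparation scripts without saying which folder uses which,
# so the driver takes whichever one exists.
PREP_SCRIPTS = ("prepare_data.sh", "makedata.sh")
TRAIN_SCRIPTS = {"ood": "run_ood.sh", "da": "run_da.sh"}


def plan(repo_root, app, setting):
    """Return the ordered (working directory, command) pairs for one application folder."""
    if app not in APPLICATIONS:
        raise ValueError(f"unknown application folder: {app}")
    if setting not in TRAIN_SCRIPTS:
        raise ValueError("setting must be 'ood' or 'da'")
    folder = Path(repo_root) / app
    prep = [name for name in PREP_SCRIPTS if (folder / name).is_file()]
    if not prep:
        raise FileNotFoundError(f"no data-preparation script found in {folder}")
    return [
        (folder, ["bash", prep[0]]),                  # prepare the datasets
        (folder, ["bash", TRAIN_SCRIPTS[setting]]),   # train and evaluate
    ]


def run(repo_root, app, setting, dry_run=True):
    """Print the planned commands and, unless dry_run, execute them in order."""
    steps = plan(repo_root, app, setting)
    for cwd, cmd in steps:
        print(f"[{cwd}] {' '.join(cmd)}")
        if not dry_run:
            subprocess.run(cmd, cwd=str(cwd), check=True)  # stop on the first failure
    return steps
```

For example, `run(".", "a-mnist", "ood")` lists the two commands for the Shifted-MNIST OOD experiments. Passing `dry_run=False` executes them. For hyperparameter changes, the commands inside the scripts and `python3 main.py --help` remain the reference.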