Understanding Bayesian Classification
This repository hosts the code to reproduce the results presented in the paper On Uncertainty, Tempering, and Data Augmentation in Bayesian Classification by Sanyam Kapoor, Wesley J Maddox, Pavel Izmailov, and Andrew Gordon Wilson.
Key Ideas
Aleatoric uncertainty captures the inherent randomness of the data, such as measurement noise. In Bayesian regression, we often use a Gaussian observation model, where we control the level of aleatoric uncertainty with a noise variance parameter. By contrast, for Bayesian classification we use a categorical distribution with no mechanism to represent our beliefs about aleatoric uncertainty. Our work shows that:
- Explicitly accounting for aleatoric uncertainty significantly improves the performance of Bayesian neural networks.
- We can match or exceed the performance of posterior tempering by using a Dirichlet observation model, where we explicitly control the level of aleatoric uncertainty, without any need for tempering.
- The cold posterior effect is effectively eliminated by properly accounting for aleatoric uncertainty in the likelihood model.
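To make the Dirichlet observation model concrete, here is a minimal sketch. It assumes the log-space Gaussian approximation of Milios et al. (2018): a one-hot label is turned into Dirichlet concentrations alpha_k = y_k + alpha_eps, each Gamma(alpha_k, 1) marginal is moment-matched by a lognormal, and the network's logits are then scored with an ordinary Gaussian likelihood. The function names are hypothetical, and treating the --noise flag as alpha_eps is an assumption, not a description of the repository's code.

```python
import math


def noisy_dirichlet_targets(label, num_classes, alpha_eps=1e-2):
    """Hypothetical helper: map an integer label to per-class Gaussian targets.

    Concentration: alpha_k = 1[k == label] + alpha_eps (assumed parameterization).
    A lognormal matching the mean and variance of Gamma(alpha_k, 1) has
    sigma_k^2 = log(1 / alpha_k + 1) and mu_k = log(alpha_k) - sigma_k^2 / 2.
    """
    mus, variances = [], []
    for k in range(num_classes):
        alpha = (1.0 if k == label else 0.0) + alpha_eps
        var = math.log(1.0 / alpha + 1.0)
        mus.append(math.log(alpha) - var / 2.0)
        variances.append(var)
    return mus, variances


def gaussian_log_likelihood(logits, mus, variances):
    """Sum of independent Gaussian log-densities of the logits under N(mu_k, sigma_k^2)."""
    total = 0.0
    for f, mu, var in zip(logits, mus, variances):
        total += -0.5 * (math.log(2.0 * math.pi * var) + (f - mu) ** 2 / var)
    return total


mus, variances = noisy_dirichlet_targets(label=0, num_classes=3, alpha_eps=1e-2)
```

Here alpha_eps plays the role of the regression noise variance: a larger value raises the concentration of the incorrect classes, which narrows the gap between the targets and encodes more aleatoric uncertainty. At prediction time, a softmax over the logits maps them back to class probabilities, since normalizing independent Gamma variables yields a Dirichlet sample.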
Setup
All requirements are listed in environment.yml. Create a conda environment from it using:
conda env create -f environment.yml -n <env_name>
Next, make the Python modules under the src folder importable:
export PYTHONPATH="$(pwd)/src:${PYTHONPATH}"
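If you work in a Python session where the shell export does not apply (for example, a notebook), a minimal equivalent sketch is to prepend src to sys.path yourself. It assumes the current working directory is the repository root.

```python
import os
import sys

# Equivalent of: export PYTHONPATH="$(pwd)/src:${PYTHONPATH}"
# Assumes os.getcwd() is the repository root.
src_dir = os.path.join(os.getcwd(), "src")
if src_dir not in sys.path:
    sys.path.insert(0, src_dir)
```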
To use bnn_priors, follow its own installation instructions.
Usage
The main script to run all SGMCMC experiments is experiments/train_lik.py.
As an example, to run cyclical SGHMC with our proposed noisy Dirichlet likelihood on CIFAR-10 with label noise, run:
python experiments/train_lik.py --dataset=cifar10 \
--label_noise=0.2 \
--likelihood=dirichlet \
--noise=1e-2 \
--prior-scale=1 \
--sgld-epochs=1000 \
--sgld-lr=2e-7 \
--n-cycles=50 \
--n-samples=50
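Cyclical SGHMC restarts a large step size at the start of each cycle and decays it toward zero. The sketch below uses the cosine cyclical schedule of Zhang et al. (cSG-MCMC), evaluated once per epoch for simplicity (the original formulation is per iteration). It assumes --sgld-epochs counts total epochs across all cycles, so the example flags give 1000 / 50 = 20 epochs per cycle; the function name is hypothetical and train_lik.py may schedule step sizes differently.

```python
import math


def cyclical_lr(epoch, total_epochs, n_cycles, base_lr):
    """Hypothetical cosine cyclical step size: equals base_lr at the start of every cycle."""
    cycle_len = math.ceil(total_epochs / n_cycles)
    pos = epoch % cycle_len  # position within the current cycle
    return 0.5 * base_lr * (math.cos(math.pi * pos / cycle_len) + 1.0)


# Schedule matching the example flags: --sgld-epochs=1000 --n-cycles=50 --sgld-lr=2e-7
schedule = [cyclical_lr(e, total_epochs=1000, n_cycles=50, base_lr=2e-7) for e in range(1000)]
```

The large step size early in each cycle lets the sampler move to a new region of the posterior, while the small step size late in the cycle is where samples are typically collected.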
Because the script uses Fire, every argument of its main method can be passed as a command-line flag. All logging goes through Weights & Biases, and configurations for various Weights & Biases sweeps are available under configs.
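The Fire pattern can be illustrated with a minimal sketch. The main function below is hypothetical: its parameters simply mirror the flags in the example command, and the defaults are placeholders rather than the real signature in experiments/train_lik.py.

```python
def main(dataset="cifar10", label_noise=0.0, likelihood="dirichlet", noise=1e-2,
         prior_scale=1.0, sgld_epochs=1000, sgld_lr=2e-7, n_cycles=50, n_samples=50):
    """Hypothetical entry point; each keyword argument becomes a --flag under Fire."""
    config = dict(dataset=dataset, label_noise=label_noise, likelihood=likelihood,
                  noise=noise, prior_scale=prior_scale, sgld_epochs=sgld_epochs,
                  sgld_lr=sgld_lr, n_cycles=n_cycles, n_samples=n_samples)
    return config


if __name__ == "__main__":
    import fire  # pip install fire

    fire.Fire(main)
```

Fire treats hyphens and underscores in flag names interchangeably, which is why the command above can mix --label_noise and --prior-scale.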
License
Apache 2.0