Understanding Bayesian Classification
This repository hosts the code to reproduce the results presented in the paper *On Uncertainty, Tempering, and Data Augmentation in Bayesian Classification* by Sanyam Kapoor, Wesley J. Maddox, Pavel Izmailov, and Andrew Gordon Wilson.
Aleatoric uncertainty captures the inherent randomness of the data, such as measurement noise. In Bayesian regression, we often use a Gaussian observation model, where we control the level of aleatoric uncertainty with a noise variance parameter. By contrast, for Bayesian classification we use a categorical distribution with no mechanism to represent our beliefs about aleatoric uncertainty. Our work shows that:
- Explicitly accounting for aleatoric uncertainty significantly improves the performance of Bayesian neural networks.
- We can match or exceed the performance of posterior tempering by using a Dirichlet observation model, where we explicitly control the level of aleatoric uncertainty, without any need for tempering.
- The cold posterior effect is effectively eliminated by properly accounting for aleatoric uncertainty in the likelihood model.
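The noisy Dirichlet observation model builds on a Dirichlet-to-Gaussian transformation in the style of Milios et al. (2018): one-hot labels are smoothed into Dirichlet concentrations, which are then moment-matched in log-space to per-class Gaussian targets for the logits. The sketch below is illustrative only; the function name and NumPy implementation are assumptions, not this repository's API:

```python
import numpy as np

def noisy_dirichlet_targets(labels, num_classes, noise=1e-6):
    """Transform integer class labels into heteroscedastic Gaussian targets.

    Each one-hot label is smoothed into Dirichlet concentrations
    alpha_c = noise, plus 1 for the true class. Log-normal moment
    matching then yields a per-class mean and variance, so the logits
    can be modeled with a Gaussian likelihood whose noise level is
    controlled explicitly by the `noise` parameter.
    """
    alpha = np.full((len(labels), num_classes), noise)
    alpha[np.arange(len(labels)), labels] += 1.0
    # Log-normal moment matching of the Dirichlet marginals.
    sigma2 = np.log(1.0 / alpha + 1.0)
    mu = np.log(alpha) - sigma2 / 2.0
    return mu, sigma2
```

A small `noise` value expresses low aleatoric uncertainty (confident, peaked targets for the true class), while a larger value softens the targets, playing a role analogous to the temperature in a tempered softmax posterior.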
*Figure: BMA test accuracy for the noisy Dirichlet model with noise parameter 1e-6 and the softmax likelihood as a function of posterior temperature on CIFAR-10. The noisy Dirichlet model shows no cold posterior effect.*
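For reference, posterior tempering (standard notation, not specific to this repository) raises the likelihood to a power $1/T$:

```latex
p_T(\theta \mid \mathcal{D}) \;\propto\; p(\mathcal{D} \mid \theta)^{1/T}\, p(\theta)
```

The cold posterior effect refers to the common observation that temperatures $T < 1$ improve the predictive performance of Bayesian neural networks under a softmax likelihood.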
All requirements are listed in `environment.yml`. Create a conda environment using:

```shell
conda env create -n <env_name>
```
Next, ensure Python modules under the `src` folder are importable, e.g. by adding `src` to your `PYTHONPATH`.
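One common way to make the modules importable (an assumption here, not necessarily the repository's documented command) is to prepend `src` to `PYTHONPATH` from the repository root:

```shell
export PYTHONPATH="$(pwd)/src:${PYTHONPATH}"
```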
To use `bnn_priors`, see the respective installation instructions.
The main script to run all SGMCMC experiments is `experiments/train_lik.py`.
As an example, to run cyclical SGHMC with our proposed noisy Dirichlet likelihood on CIFAR-10 with label noise, run:
```shell
python experiments/train_lik.py --dataset=cifar10 \
                                --label_noise=0.2 \
                                --likelihood=dirichlet \
                                --noise=1e-2 \
                                --prior-scale=1 \
                                --sgld-epochs=1000 \
                                --sgld-lr=2e-7 \
                                --n-cycles=50 \
                                --n-samples=50
```
Each argument to the `main` method can be used as a command-line argument via Fire. Weights & Biases is used for all logging. Configurations for various Weights & Biases sweeps are also available under `configs`.