Epistemic Neural Networks

A library for uncertainty representation and training in neural networks.

Introduction

Many applications in deep learning require or benefit from going beyond a point estimate and representing uncertainty about the model. The coherent use of Bayes’ rule and probability theory is the gold standard for updating beliefs and estimating uncertainty, but exact computation quickly becomes infeasible even for simple problems. Modern machine learning has developed an effective toolkit for learning in high-dimensional spaces using a simple and coherent convention. Epistemic neural network (ENN) is a library that provides a similarly simple and coherent convention for defining and training neural networks that represent uncertainty over a hypothesis class of models.

Technical overview

In a supervised setting, for inputs x_i ∈ X and outputs y_i ∈ Y, a point estimate f_θ(x) is trained to fit the observed data D = {(x_i, y_i) for i = 1, ..., N} by minimizing a loss function l(θ, D) ∈ R. In epistemic neural networks we introduce the concept of an epistemic index z ∈ I ⊆ R^{n_z}, distributed according to some reference distribution p_z(·). An augmented epistemic function approximator then takes the form f_θ(x, z), where the function class f_θ(·, z) is a neural network. The index z allows unambiguous identification of a corresponding function value, and sampling z corresponds to sampling from the hypothesis class of functions.
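To make the role of the index concrete, here is a minimal sketch in plain NumPy. It does not use the ENN library's API. The toy linear model, the parameter names (`w_mean`, `w_index`), and the choice of a standard Gaussian for p_z are all assumptions made for illustration.

```python
import numpy as np


def init_params(rng, input_dim, index_dim):
    """Parameters theta of a toy index-dependent linear model (hypothetical)."""
    return {
        # Base weights shared by every index.
        "w_mean": rng.normal(size=(input_dim,)),
        # Maps the index z to a perturbation of the weights.
        "w_index": 0.1 * rng.normal(size=(index_dim, input_dim)),
    }


def sample_index(rng, index_dim):
    """Draw z from the reference distribution p_z (assumed standard Gaussian)."""
    return rng.normal(size=(index_dim,))


def f(theta, x, z):
    """Toy f_theta(x, z): returns one scalar per row of x."""
    weights = theta["w_mean"] + z @ theta["w_index"]
    return x @ weights


rng = np.random.default_rng(0)
theta = init_params(rng, input_dim=3, index_dim=4)
x = rng.normal(size=(5, 3))  # batch of 5 inputs

# Each sampled z picks out one function from the hypothesis class.
samples = np.stack([f(theta, x, sample_index(rng, 4)) for _ in range(100)])
mean = samples.mean(axis=0)   # average prediction over sampled indices
spread = samples.std(axis=0)  # disagreement across indices
```

Holding z fixed gives a deterministic function of x, while the spread of outputs across sampled indices gives one rough view of the model's uncertainty at each input.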

On some level, ENNs are purely a notational convenience, and most existing approaches to dealing with uncertainty in deep learning can be rephrased in this way. For example, an ensemble of point estimates {f_θ1, ..., f_θK} can be viewed as an ENN with θ = (θ1, ..., θK), z ∈ {1, ..., K}, and f_θ(x, z) := f_θz(x). However, this simplicity hides a deeper insight: the process of epistemic update itself can be tackled through the tools of machine learning typically reserved for point estimates, through the addition of this epistemic index. Further, since these machine learning tools were explicitly designed to scale to large and complex problems, they might provide tractable approximations to large-scale Bayesian inference even where the exact computations are intractable.
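The ensemble example can be sketched as follows, again in plain NumPy rather than with the library's own network classes. The linear members, the helper names, and the uniform reference distribution over member indices are assumptions for illustration. Indices are 0-based here, whereas the text uses 1, ..., K.

```python
import numpy as np


def make_ensemble(rng, num_ensemble, input_dim):
    """theta = (theta_1, ..., theta_K): one weight vector per member."""
    return [rng.normal(size=(input_dim,)) for _ in range(num_ensemble)]


def point_estimate(theta_k, x):
    """A single point estimate f_{theta_k}(x), here a linear model."""
    return x @ theta_k


def ensemble_enn(theta, x, z):
    """f_theta(x, z) := f_{theta_z}(x), where z selects an ensemble member."""
    return point_estimate(theta[z], x)


def sample_index(rng, num_ensemble):
    """Reference distribution p_z, assumed uniform over member indices."""
    return int(rng.integers(num_ensemble))


rng = np.random.default_rng(0)
K = 10
theta = make_ensemble(rng, K, input_dim=3)
x = rng.normal(size=(4, 3))  # batch of 4 inputs

z = sample_index(rng, K)
prediction = ensemble_enn(theta, x, z)  # output of the z-th member

# Evaluating every index recovers the full ensemble of predictions.
all_predictions = np.stack([ensemble_enn(theta, x, k) for k in range(K)])
```

Nothing about the ensemble itself changes. The index only makes explicit which member produced a given output, which is what lets the same interface cover other uncertainty methods.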

For a more comprehensive overview, see the accompanying paper.

Reproducing NeurIPS experiments

To reproduce the experiments from our paper please see experiments/neurips_2021.

Getting started

You can get started in our colab tutorial without installing anything on your machine.

Installation

We have tested ENN on Python 3.7. To install the dependencies:

  1. Optional: We recommend using a Python virtual environment to manage your dependencies, so as not to clobber your system installation:

    python3 -m venv enn
    source enn/bin/activate
    pip install --upgrade pip setuptools
  2. Install ENN directly from GitHub:

    pip install git+https://github.com/deepmind/enn
  3. Test that you can load ENN by training a simple ensemble ENN.

    from acme.utils.loggers.terminal import TerminalLogger
    
    from enn import losses
    from enn import networks
    from enn import supervised
    from enn.supervised import regression_data
    import optax
    
    # A small dummy dataset
    dataset = regression_data.make_dataset()
    
    # Logger
    logger = TerminalLogger('supervised_regression')
    
    # ENN
    enn = networks.MLPEnsembleMatchedPrior(
        output_sizes=[50, 50, 1],
        num_ensemble=10,
    )
    
    # Loss
    loss_fn = losses.average_single_index_loss(
        single_loss=losses.L2LossWithBootstrap(),
        num_index_samples=10
    )
    
    # Optimizer
    optimizer = optax.adam(1e-3)
    
    # Train the experiment for a fixed number of batches
    # (illustrative value; adjust for your problem)
    num_batch = 100
    experiment = supervised.Experiment(
        enn, loss_fn, optimizer, dataset, seed=0, logger=logger)
    experiment.train(num_batch)

More examples can be found in the colab tutorial.

  4. Optional: run the tests by executing ./test.sh from the ENN root directory.

Citing

If you use ENN in your work, please cite the accompanying paper:

@inproceedings{,
    title={Epistemic Neural Networks},
    author={Ian Osband and Zheng Wen and Mohammad Asghari and Morteza Ibrahimi and Xiyuan Lu and Benjamin Van Roy},
    booktitle={arxiv},
    year={2021},
    url={https://arxiv.org/abs/2107.08924}
}
Comments
  • Problem occurred in enn_demo.ipynb

    Hi!

    I was trying the enn_demo.ipynb on google colab. Everything seems fine until I run this block of code.

    # Train the experiment
    experiment.train(FLAGS.num_batch)
    

    and this error appears. Is there something wrong with the JAX version?

    AttributeError                            Traceback (most recent call last)
    /usr/local/lib/python3.8/dist-packages/enn/networks/ensembles.py in apply(params, states, inputs, index)
         82       sub_states = jax.tree_map(particle_selector, states)
         83       out, new_sub_states = model.apply(sub_params, sub_states, inputs)
    ---> 84       new_states = jax.tree_multimap(
         85           lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
         86       return out, new_states
    
    AttributeError: module 'jax' has no attribute 'tree_multimap'
    

    Thanks, Adam

    opened by fazaghifari 0
  • Updated occurrences of tree_multimap to tree_map

    I was playing around with the colabs but noticed they were not working on account of tree_multimap being deprecated in the latest version of JAX. I've just updated every occurrence of tree_multimap with tree_map. With this change, the colab seems to work as intended.

    opened by shyams2 1
  • Cleaned up ipynb notebooks

    Hi. In notebooks enn/colabs/{enn,epinet}_demo.ipynb:

    • Remove unused imports.
    • Split the Import section into General imports, Development imports and ENN imports (was before).

    Tested in the Colab environment: "Runtime -> Run all" passed without errors.

    Fix the about-pull-requests link to point to the docs.github page.

    opened by amrzv 0
  • Colab error with haiku/jax dependency

    The following error appears when executing the Colab notebook:

    AttributeError                            Traceback (most recent call last)
    <ipython-input-7-4863fdbb2553> in <module>()
         13     num_ensemble=FLAGS.index_dim,
         14     prior_scale=FLAGS.prior_scale,
    ---> 15     seed=FLAGS.seed,
         16 )
         17 
    
    4 frames
    /usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in __init__(self, output_sizes, dummy_input, num_ensemble, prior_scale, seed, w_init, b_init)
        137     """Ensemble of MLPs with matched prior functions."""
        138     mlp_priors = make_mlp_ensemble_prior_fns(
    --> 139         output_sizes, dummy_input, num_ensemble, seed)
        140     enn = priors.EnnWithAdditivePrior(
        141         enn=MLPEnsembleEnn(
    
    /usr/local/lib/python3.7/dist-packages/enn/networks/ensembles.py in make_mlp_ensemble_prior_fns(output_sizes, dummy_input, num_ensemble, seed, w_init, b_init)
         90     return hk.Sequential(layers)(x)
         91 
    ---> 92   transformed = hk.without_apply_rng(hk.transform(net_fn))
         93 
         94   prior_fns = []
    
    /usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform(f, apply_rng)
        301         "Replace hk.transform(..., apply_rng=True) with hk.transform(...).")
        302 
    --> 303   return without_state(transform_with_state(f))
        304 
        305 
    
    /usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in transform_with_state(f)
        359   """
        360   analytics.log_once("transform_with_state")
    --> 361   check_not_jax_transformed(f)
        362 
        363   unexpected_tracer_hint = (
    
    /usr/local/lib/python3.7/dist-packages/haiku/_src/transform.py in check_not_jax_transformed(f)
        306 def check_not_jax_transformed(f):
        307   # TODO(tomhennigan): Consider `CompiledFunction = type(jax.jit(lambda: 0))`.
    --> 308   if isinstance(f, (jax.xla.xe.CompiledFunction, jax.xla.xe.PmapFunction)):  # pytype: disable=name-error
        309     raise ValueError("A common error with Haiku is to pass an already jit "
        310                      "(or pmap) decorated function into hk.transform (e.g. "
    
    AttributeError: module 'jaxlib.xla_extension' has no attribute 'PmapFunction'
    
    opened by hstojic 0
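
For readers who hit the `tree_multimap` error reported above: the failing line maps a two-argument function over two pytrees with the same structure. A minimal sketch of that update with `jax.tree_util.tree_map`, which accepts several pytrees, is shown below. The state arrays are made-up placeholders, not the library's real states, and the merged fix refers to `tree_map` without the `tree_util` prefix.

```python
import jax
import jax.numpy as jnp

# Hypothetical ensemble state: the leading axis indexes the ensemble member.
states = {"counter": jnp.zeros((3, 2)), "scale": jnp.ones((3,))}
# Hypothetical updated state for a single member (no leading axis).
new_sub_states = {"counter": jnp.array([5.0, 7.0]), "scale": jnp.array(2.0)}
index = 1

# tree_map applies the function leaf by leaf across both pytrees,
# writing each member-level leaf into slot `index` of the full state.
new_states = jax.tree_util.tree_map(
    lambda s, nss: s.at[index, ...].set(nss), states, new_sub_states)
```

Only the slot for `index` changes in each leaf. The other members' states are carried over unchanged.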
Owner
DeepMind