Bayes-Newton—A Gaussian process library in JAX, with a unifying view of approximate Bayesian inference as variants of Newton's algorithm.

Overview

Bayes-Newton

Bayes-Newton is a library for approximate inference in Gaussian processes (GPs) in JAX (with objax), built and actively maintained by Will Wilkinson.

Bayes-Newton provides a unifying view of approximate Bayesian inference, and allows for the combination of many models (e.g. GPs, sparse GPs, Markov GPs, sparse Markov GPs) with the inference method of your choice (VI, EP, Laplace, Linearisation). For a full list of the implemented methods, see the list of implemented models at the bottom of this page.

Installation

pip install bayesnewton

Example

Given some inputs x and some data y, you can construct a Bayes-Newton model as follows,

import bayesnewton

kern = bayesnewton.kernels.Matern52()
lik = bayesnewton.likelihoods.Gaussian()
model = bayesnewton.models.MarkovVariationalGP(kernel=kern, likelihood=lik, X=x, Y=y)

The training loop (inference and hyperparameter learning) is then set up using objax's built-in functionality:

import objax

lr_adam = 0.1
lr_newton = 1
inf_args = {}  # extra keyword arguments forwarded to inference() and energy(); left empty here
opt_hypers = objax.optimizer.Adam(model.vars())
energy = objax.GradValues(model.energy, model.vars())

@objax.Function.with_vars(model.vars() + opt_hypers.vars())
def train_op():
    model.inference(lr=lr_newton, **inf_args)  # perform inference and update variational params
    dE, E = energy(**inf_args)  # compute energy and its gradients w.r.t. hypers
    opt_hypers(lr_adam, dE)  # update the hyperparameters
    return E
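Conceptually, each call to `model.inference(lr=lr_newton)` takes a damped Newton-type step on the parameters of the approximate posterior, with `lr_newton = 1` corresponding to a full step. The NumPy sketch below illustrates that idea in the simplest setting: a Gaussian likelihood whose Gaussian "site" is updated in natural parameters as (1 - lr) * old + lr * Newton target. The function names are hypothetical and this is not BayesNewton's implementation.

```python
import numpy as np


def rbf_kernel(x1, x2, lengthscale=1.0, variance=1.0):
    # Squared exponential covariance between two sets of 1-D inputs
    d = x1[:, None] - x2[None, :]
    return variance * np.exp(-0.5 * (d / lengthscale) ** 2)


def newton_site_update(y, post_mean, site_prec, site_nat, noise_var, lr):
    # Gradient and Hessian of log N(y | f, noise_var) w.r.t. f at the current posterior mean
    grad = (y - post_mean) / noise_var
    hess = -np.ones_like(y) / noise_var
    # Newton target for the Gaussian site, in natural parameters
    target_prec = -hess
    target_nat = -hess * post_mean + grad
    # Damped step: lr = 1 jumps to the target, lr < 1 interpolates towards it
    new_prec = (1.0 - lr) * site_prec + lr * target_prec
    new_nat = (1.0 - lr) * site_nat + lr * target_nat
    return new_prec, new_nat


def posterior_mean(K, site_prec, site_nat):
    # Mean of q(f) proportional to N(f | 0, K) x site: K (I + diag(site_prec) K)^-1 site_nat
    n = K.shape[0]
    return K @ np.linalg.solve(np.eye(n) + np.diag(site_prec) @ K, site_nat)


x = np.linspace(0.0, 5.0, 6)
y = np.sin(x)
noise_var = 0.1
K = rbf_kernel(x, x)

# Full step (lr = 1) starting from empty sites
site_prec, site_nat = np.zeros_like(y), np.zeros_like(y)
site_prec, site_nat = newton_site_update(
    y, posterior_mean(K, site_prec, site_nat), site_prec, site_nat, noise_var, lr=1.0
)
mean = posterior_mean(K, site_prec, site_nat)

# Damped steps (lr = 0.5) approach the same answer geometrically
sp, sn = np.zeros_like(y), np.zeros_like(y)
for _ in range(60):
    sp, sn = newton_site_update(y, posterior_mean(K, sp, sn), sp, sn, noise_var, lr=0.5)
damped_mean = posterior_mean(K, sp, sn)

# Exact GP regression mean for comparison
exact_mean = K @ np.linalg.solve(K + noise_var * np.eye(len(x)), y)
```

With a Gaussian likelihood the Newton target does not depend on the current mean, so lr = 1 recovers exact GP regression in a single step. For non-conjugate likelihoods the target moves with the mean, which is why inference is iterated inside the training loop.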

As we are using JAX, we can JIT compile the training loop:

train_op = objax.Jit(train_op)

and then run the training loop,

iters = 20
for i in range(1, iters + 1):
    loss = train_op()

Full demos are available in the demos directory of the repository.

License

This software is provided under the Apache License 2.0. See the accompanying LICENSE file for details.

Citing Bayes-Newton

@article{wilkinson2021bayesnewton,
  title = {{B}ayes-{N}ewton Methods for Approximate {B}ayesian Inference with {PSD} Guarantees},
  author = {Wilkinson, William J. and S\"arkk\"a, Simo and Solin, Arno},
  journal={arXiv preprint arXiv:2111.01721},
  year={2021}
}

Implemented Models

For a full list of all the available models, see the model class list.

Variational GPs

  • Variational GP (Opper, Archambeau: The Variational Gaussian Approximation Revisited, Neural Computation 2009; Khan, Lin: Conjugate-Computation Variational Inference: Converting Variational Inference in Non-Conjugate Models to Inferences in Conjugate Models, AISTATS 2017)
  • Sparse Variational GP (Hensman, Matthews, Ghahramani: Scalable Variational Gaussian Process Classification, AISTATS 2015; Adam, Chang, Khan, Solin: Dual Parameterization of Sparse Variational Gaussian Processes, NeurIPS 2021)
  • Markov Variational GP (Chang, Wilkinson, Khan, Solin: Fast Variational Learning in State Space Gaussian Process Models, MLSP 2020)
  • Sparse Markov Variational GP (Adam, Eleftheriadis, Durrande, Artemev, Hensman: Doubly Sparse Variational Gaussian Processes, AISTATS 2020; Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Spatio-Temporal Variational GP (Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)

Expectation Propagation GPs

  • Expectation Propagation GP (Minka: A Family of Algorithms for Approximate Bayesian Inference, PhD thesis 2000)
  • Sparse Expectation Propagation GP (energy not working) (Csato, Opper: Sparse on-line Gaussian processes, Neural Computation 2002; Bui, Yan, Turner: A Unifying Framework for Gaussian Process Pseudo Point Approximations Using Power Expectation Propagation, JMLR 2017)
  • Markov Expectation Propagation GP (Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Expectation Propagation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Laplace/Newton GPs

  • Laplace GP (Rasmussen, Williams: Gaussian Processes for Machine Learning, 2006)
  • Sparse Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Markov Laplace GP (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
  • Sparse Markov Laplace GP

Linearisation GPs

  • Posterior Linearisation GP (García-Fernández, Tronarp, Särkkä: Gaussian Process Classification Using Posterior Linearization, IEEE Signal Processing 2019; Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Sparse Posterior Linearisation GP
  • Markov Posterior Linearisation GP (García-Fernández, Svensson, Särkkä: Iterated Posterior Linearization Smoother, IEEE Automatic Control 2016; Wilkinson, Chang, Riis Andersen, Solin: State Space Expectation Propagation, ICML 2020)
  • Sparse Markov Posterior Linearisation GP (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)
  • Taylor Expansion / Analytical Linearisation GP (Steinberg, Bonilla: Extended and Unscented Gaussian Processes, NeurIPS 2014)
  • Markov Taylor GP / Extended Kalman Smoother (Bell: The Iterated Kalman Smoother as a Gauss-Newton method, SIAM Journal on Optimization 1994)
  • Sparse Taylor GP
  • Sparse Markov Taylor GP / Sparse Extended Kalman Smoother (Wilkinson, Solin, Adam: Sparse Algorithms for Markovian Gaussian Processes, AISTATS 2021)

Gauss-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Gauss-Newton
  • Variational Gauss-Newton
  • PEP Gauss-Newton
  • 2nd-order PL Gauss-Newton
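A short sketch of why Gauss-Newton helps with the PSD issue that motivates the paper. A Newton (Laplace-type) update takes the site precision from minus the second derivative of the log-likelihood, and that quantity turns negative wherever the likelihood is not log-concave in f. If the log-likelihood is written as -½ r(f)² for a residual r, Gauss-Newton drops the r·r'' term, so the precision becomes r'², which is never negative. The observation model y = sin(f) + noise and all function names below are hypothetical, chosen only to make the effect visible; this is not the library's code.

```python
import numpy as np


def residual_and_derivs(f, y, sigma):
    # Whitened residual r(f) = (y - sin f) / sigma for a hypothetical model y = sin(f) + noise
    r = (y - np.sin(f)) / sigma
    dr = -np.cos(f) / sigma
    d2r = np.sin(f) / sigma
    return r, dr, d2r


def newton_site_precision(f, y, sigma):
    # log p = -0.5 r^2, so d2/df2 log p = -(r'^2 + r r''); the site precision is its negative
    r, dr, d2r = residual_and_derivs(f, y, sigma)
    return dr**2 + r * d2r


def gauss_newton_site_precision(f, y, sigma):
    # Gauss-Newton drops the r r'' term, so the precision r'^2 cannot be negative
    _, dr, _ = residual_and_derivs(f, y, sigma)
    return dr**2


f, y, sigma = 1.5, -1.0, 0.5
newton_prec = newton_site_precision(f, y, sigma)
gn_prec = gauss_newton_site_precision(f, y, sigma)
```

At this point the exact Newton precision is about -7.95 (an invalid Gaussian site), while the Gauss-Newton precision is a small positive number.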

Quasi-Newton GPs

(Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)

  • Quasi-Newton
  • Variational Quasi-Newton
  • PEP Quasi-Newton
  • PL Quasi-Newton
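The exact quasi-Newton update used by these methods is given in the paper. As a generic illustration of how a quasi-Newton curvature estimate can be kept positive definite, here is Powell's damped BFGS update in NumPy. This is a standard technique swapped in as an assumption, not code taken from BayesNewton. `B` plays the role of a site precision, `s` is the step in f, and `grad_change` is the change in the gradient of the negative log-likelihood along that step.

```python
import numpy as np


def damped_bfgs_update(B, s, grad_change):
    # Powell-damped BFGS update of a positive definite curvature matrix B
    Bs = B @ s
    sBs = s @ Bs
    sy = s @ grad_change
    if sy >= 0.2 * sBs:
        theta = 1.0
    else:
        # Blend the gradient change towards B s so that s^T r = 0.2 s^T B s > 0
        theta = 0.8 * sBs / (sBs - sy)
    r = theta * grad_change + (1.0 - theta) * Bs
    return B - np.outer(Bs, Bs) / sBs + np.outer(r, r) / (s @ r)


B = np.eye(2)
s = np.array([1.0, 0.0])
grad_change = np.array([-1.0, 0.5])  # negative curvature along s
B_new = damped_bfgs_update(B, s, grad_change)
```

With these values the undamped BFGS formula would give the indefinite matrix [[-1, 0.5], [0.5, 0.75]], whereas the damped update returns [[0.2, 0.2], [0.2, 1.2]], which is positive definite and satisfies the secant condition with the damped gradient change.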

GPs with PSD Constraints via Riemannian Gradients

  • VI Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • Newton/Laplace Riemann Grad (Lin, Schmidt, Khan: Handling the Positive-Definite Constraint in the Bayesian Learning Rule, ICML 2020)
  • PEP Riemann Grad (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
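The Riemannian-gradient variants build on Lin, Schmidt and Khan (ICML 2020), whose update adds a second-order correction term that keeps the precision matrix positive definite. A NumPy sketch of that idea in the form we assume here: G is the difference between a target precision and the current one, and the library's exact parameterisation may differ. The key identity is S + βG + (β²/2) G S⁻¹ G = ½S + ½(S + βG) S⁻¹ (S + βG), a positive definite matrix plus a positive semi-definite one.

```python
import numpy as np


def naive_precision_update(S, S_target, beta):
    # Plain interpolation towards the target: can lose positive definiteness
    return (1.0 - beta) * S + beta * S_target


def riemannian_precision_update(S, S_target, beta):
    # S + beta*G + (beta^2 / 2) * G S^-1 G with G = S_target - S;
    # equals 0.5*S + 0.5*(S + beta*G) S^-1 (S + beta*G), so it stays positive definite
    G = S_target - S
    return S + beta * G + 0.5 * beta**2 * G @ np.linalg.solve(S, G)


S = np.eye(2)
S_target = np.diag([2.0, -3.0])  # indefinite target, e.g. from a non-log-concave likelihood
beta = 0.5
S_naive = naive_precision_update(S, S_target, beta)
S_riem = riemannian_precision_update(S, S_target, beta)
```

Here the naive step gives diag(1.5, -1), while the corrected step gives diag(1.625, 1.0).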

Others

  • Infinite Horizon GP (Solin, Hensman, Turner: Infinite-Horizon Gaussian Processes, NeurIPS 2018)
  • Parallel Markov GP (with VI, EP, PL, ...) (Särkkä, García-Fernández: Temporal parallelization of Bayesian smoothers; Corenflos, Zhao, Särkkä: Gaussian Process Regression in Logarithmic Time; Hamelijnck, Wilkinson, Loppi, Solin, Damoulas: Spatio-Temporal Variational Gaussian Processes, NeurIPS 2021)
  • 2nd-order Posterior Linearisation GP (sparse, Markov, ...) (Wilkinson, Särkkä, Solin: Bayes-Newton Methods for Approximate Bayesian Inference with PSD Guarantees)
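For regularly spaced inputs, the Kalman filter covariance of a stationary state-space GP converges to a steady state, which is what lets the Infinite-Horizon GP use a fixed gain instead of recomputing covariances at every step. A NumPy sketch that finds that steady state by iterating the Riccati recursion for a Matern-3/2 model; the kernel, step size, noise level and function names are illustrative assumptions, and the process noise is formed as Q = Pinf - A Pinf Aᵀ.

```python
import numpy as np


def matern32_ssm(lengthscale, variance, dt):
    # Discrete-time Matern-3/2 state-space model with state [f, f']
    lam = np.sqrt(3.0) / lengthscale
    x = lam * dt
    A = np.exp(-x) * np.array([[1.0 + x, dt],
                               [-lam**2 * dt, 1.0 - x]])
    Pinf = np.diag([variance, variance * lam**2])  # stationary covariance
    Q = Pinf - A @ Pinf @ A.T                      # process noise over one step
    H = np.array([[1.0, 0.0]])                     # observe f only
    return A, Q, H, Pinf


def steady_state_filter(A, Q, H, R, P0, tol=1e-12, max_iter=10_000):
    # Iterate the Kalman filter covariance recursion until it stops changing
    P = P0
    for _ in range(max_iter):
        P_pred = A @ P @ A.T + Q                   # predict
        S = H @ P_pred @ H.T + R                   # innovation covariance
        K = np.linalg.solve(S, H @ P_pred).T       # gain P_pred H^T S^-1
        P_new = P_pred - K @ S @ K.T               # update
        if np.max(np.abs(P_new - P)) < tol:
            return P_new, K
        P = P_new
    raise RuntimeError("Riccati recursion did not converge")


A, Q, H, Pinf = matern32_ssm(lengthscale=1.0, variance=1.0, dt=0.1)
R = np.array([[0.1]])  # observation noise variance
P, K = steady_state_filter(A, Q, H, R, Pinf)
```

Once P and K are known, every filtering step for equally spaced data reuses the same gain K.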
Comments
  • Addition of a Squared Exponential Kernel?

    This is not an issue per se, but I was wondering if there was a specific reason that there wasn't a Squared Exponential kernel as part of the package?

    If applicable, I would be happy to submit a PR adding one.

    Just let me know.

    opened by mathDR 7
  • Add sq exponential kernel

    This PR does two things:

    1. Adds a Squared Exponential kernel, adopting the GPflow naming convention, i.e. SquaredExponential inheriting from StationaryKernel. As of now, only the K_r method is populated (others may be populated as needed).
    2. The marathon.py demo was changed to take the SquaredExponential kernel (with lengthscale = 40). This serves as a "test" to ensure that the kernel runs.
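For reference, the squared exponential covariance written as a function of the lengthscale-scaled distance r is k(r) = σ² exp(-r²/2). A minimal NumPy sketch of a K_r-style function under that convention; the function names and the 1-D input assumption are ours, not the PR's, and the lengthscale of 40 only mirrors the value mentioned for the marathon.py demo.

```python
import numpy as np


def scaled_distance(x1, x2, lengthscale):
    # Pairwise |x1_i - x2_j| / lengthscale for 1-D inputs
    return np.abs(x1[:, None] - x2[None, :]) / lengthscale


def squared_exponential_K_r(r, variance=1.0):
    # Squared exponential covariance as a function of scaled distance r
    return variance * np.exp(-0.5 * r**2)


x = np.array([0.0, 1.0, 2.5])
K = squared_exponential_K_r(scaled_distance(x, x, lengthscale=40.0), variance=2.0)
```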
    opened by mathDR 2
  • question about the equation (64) in `Bayes-Newton` paper

    I'm a little confused by equation (64). It is computed from equation (63), but where do the other terms in equation (64) come from, such as the denominator, which comes from the covariance of p(fn|u)?

    opened by Fangwq 2
  • jitted predict

    Hi, I'm starting to explore your framework. I'm familiar with JAX, but not with objax. I noticed that the training ops are jitted with objax.Jit, but since my goal is to have fast prediction embedded in some larger JAX code, I wonder whether predict() can also be jitted. Thanks in advance,

    Regards,

    opened by soldierofhell 2
  • error in heteroscedastic.py with MarkovPosteriorLinearisationQuasiNewtonGP method

    As the title says, running heteroscedastic.py with the MarkovPosteriorLinearisationQuasiNewtonGP method produces an error:

    File "heteroscedastic.py", line 101, in train_op
      model.inference(lr=lr_newton, damping=damping)  # perform inference and update variational params
    File "/BayesNewton-main/bayesnewton/inference.py", line 871, in inference
      mean, jacobian, hessian, quasi_newton_state = self.update_variational_params(batch_ind, lr, **kwargs)
    File "/BayesNewton-main/bayesnewton/inference.py", line 1076, in update_variational_params
      jacobian_var = transpose(solve(omega, dmu_dv)) @ residual
    ValueError: The arguments to solve must have shapes a=[..., m, m] and b=[..., m, k] or b=[..., m]; got a=(117, 1, 1) and b=(117, 2, 2)
    

    Can you tell me what is wrong? Thanks in advance.

    opened by Fangwq 2
  • How to set initial `Pinf` variable in kernel?

    I note that the initial Pinf variable for the Matern-5/2 kernel is as follows:

            Pinf = np.array([[self.variance,    0.0,   -kappa],
                             [0.0,    kappa, 0.0],
                             [-kappa, 0.0,   25.0*self.variance / self.lengthscale**4.0]])
    

    Why is it defined like this? Are there any references I should follow up on?

    PS: by the way, the data ../data/aq_data.csv is missing.
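For context, Pinf is the stationary covariance of the state-space (SDE) form of the Matern-5/2 kernel, i.e. the solution of the Lyapunov equation F P + P Fᵀ + L Qc Lᵀ = 0; this construction is described by Hartikainen and Särkkä (Kalman filtering and smoothing solutions to temporal Gaussian process regression models, MLSP 2010). The snippet does not show how kappa is defined; the value that makes the matrix solve the equation is κ = 5σ²/(3ℓ²). A NumPy sketch that solves the Lyapunov equation numerically and checks it against the closed form, with function names of our own choosing:

```python
import numpy as np


def matern52_sde(variance, lengthscale):
    # Matern-5/2 kernel as a linear SDE with state [f, f', f'']
    lam = np.sqrt(5.0) / lengthscale
    F = np.array([[0.0, 1.0, 0.0],
                  [0.0, 0.0, 1.0],
                  [-lam**3, -3.0 * lam**2, -3.0 * lam]])
    L = np.array([[0.0], [0.0], [1.0]])
    Qc = 16.0 / 3.0 * variance * lam**5  # spectral density of the driving white noise
    return F, L, Qc


def stationary_covariance(F, L, Qc):
    # Solve F P + P F^T + L Qc L^T = 0 by vectorising with Kronecker products
    n = F.shape[0]
    eye = np.eye(n)
    lhs = np.kron(F, eye) + np.kron(eye, F)
    rhs = -(Qc * (L @ L.T)).reshape(-1)
    return np.linalg.solve(lhs, rhs).reshape(n, n)


variance, lengthscale = 1.5, 0.7
F, L, Qc = matern52_sde(variance, lengthscale)
Pinf = stationary_covariance(F, L, Qc)

# Closed form, assuming kappa = 5/3 * variance / lengthscale^2
kappa = 5.0 / 3.0 * variance / lengthscale**2
Pinf_closed = np.array([[variance, 0.0, -kappa],
                        [0.0, kappa, 0.0],
                        [-kappa, 0.0, 25.0 * variance / lengthscale**4]])
```

The entries have a direct reading: variance is Var(f), kappa is Var(f') and -Cov(f, f''), and 25σ²/ℓ⁴ is Var(f'').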

    opened by Fangwq 2
  • How to install Newt in conda virtual environment?

    Hi, thank you for sharing your great work.

    I am a little confused about how to install Newt in a conda virtual environment. I would really appreciate it if you could provide some guidance. Thank you

    opened by mohammad-saber 2
  • How to understand the function `cavity_distribution_tied` in file `basemodels.py` ?

    As the title says, how should I understand cavity_distribution_tied in basemodels.py? Is there any reference I should follow up on? I note that this code is similar to equation (64) in the Bayes-Newton paper. Where does it come from?
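In EP-style methods, the cavity is the approximate posterior with a site's contribution divided out (in power EP, only a fraction of it). A generic NumPy sketch in natural parameters, assuming full-rank Gaussian sites; it illustrates the concept only and is not the library's cavity_distribution_tied, whose handling of tied sites may differ.

```python
import numpy as np


def cavity(post_mean, post_cov, site_prec, site_nat, power=1.0):
    # Remove (a fraction `power` of) a Gaussian site from the posterior, in natural parameters
    post_prec = np.linalg.inv(post_cov)
    cav_prec = post_prec - power * site_prec
    cav_nat = post_prec @ post_mean - power * site_nat
    cav_cov = np.linalg.inv(cav_prec)
    return cav_cov @ cav_nat, cav_cov


prior_mean = np.array([0.5, -1.0])
prior_cov = np.array([[2.0, 0.3], [0.3, 1.0]])
site_prec = np.diag([1.0, 0.5])
site_nat = np.array([0.2, -0.4])

# Posterior = prior x site, formed by adding natural parameters
post_cov = np.linalg.inv(np.linalg.inv(prior_cov) + site_prec)
post_mean = post_cov @ (np.linalg.solve(prior_cov, prior_mean) + site_nat)

cav_mean, cav_cov = cavity(post_mean, post_cov, site_prec, site_nat)
```

With power = 1, removing the only site from the posterior recovers the prior exactly.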

    opened by Fangwq 1
  • issue with SparseVariationalGP method

    When I run demos/regression.py with SparseVariationalGP, the following error occurs:

    AssertionError: Assignments to variable must be an instance of JaxArray, but received f<class 'numpy.ndarray'>.
    

    It seems there is a mistake in the SparseVariationalGP method. Can you help fix the problem? Thank you very much!

    opened by Fangwq 1
  • Sparse EP energy is incorrect

    The current implementation of the sparse EP energy is not giving sensible results. This is a reminder to look into the reasons why and check against implementations elsewhere. PRs very welcome for this issue.

    Note: the EP energy is correct for all other models (GP, Markov GP, SparseMarkovGP)

    bug 
    opened by wil-j-wil 1
  • Double Precision Issues

    Hi!

    Many thanks for open-sourcing this package.

    I've been using code from this amazing package for my research (preprint here) and have found that the default single precision/float32 is insufficient for the Kalman filtering and smoothing operations, causing numerical instabilities. In particular,

    • for the Periodic kernel, it is rather sensitive to the matrix operations in _sequential_kf() and _sequential_rts().
    • The same happens when the lengthscales of the Matern32 kernel are too large.

    However, reverting to float64 by setting config.update("jax_enable_x64", True) makes everything quite slow, especially when I use objax neural network modules, due to the fact that doing so puts all arrays into double precision.

    Currently, my solution is to set the neural network weights to float32 manually, and convert input arrays into float32 before entering the network and the outputs back into float64. However, I was wondering if there could be a more elegant solution as is done in https://github.com/thomaspinder/GPJax, where all arrays are assumed to be float64. My understanding is that their package depends on Haiku, but I'm unsure how they got around the computational scalability issue.
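The observation about long Matern32 lengthscales is consistent with catastrophic cancellation: when the lengthscale is long relative to the time step, the transition matrix A is close to the identity, so the process noise Q = Pinf - A Pinf Aᵀ is a tiny difference of order-one numbers. A NumPy sketch, assuming Q is formed this way (an assumption about where the error enters, not a statement about BayesNewton's code), compares float32 and float64 against the small-step expansion Q[0, 0] ≈ (4/3)σ²(λΔt)³ with λ = √3/ℓ.

```python
import numpy as np


def matern32_process_noise(lengthscale, variance, dt, dtype):
    # Q = Pinf - A Pinf A^T for the Matern-3/2 state-space model, evaluated in `dtype`
    lam = np.sqrt(3.0) / lengthscale
    x = lam * dt
    A = (np.exp(-x) * np.array([[1.0 + x, dt],
                                [-lam**2 * dt, 1.0 - x]])).astype(dtype)
    Pinf = np.diag([variance, variance * lam**2]).astype(dtype)
    return Pinf - A @ Pinf @ A.T


lengthscale, variance, dt = 100.0, 1.0, 0.05  # long lengthscale relative to the step
x = np.sqrt(3.0) / lengthscale * dt
q_series = variance * 4.0 / 3.0 * x**3  # leading small-step term of Q[0, 0]

Q64 = matern32_process_noise(lengthscale, variance, dt, np.float64)
Q32 = matern32_process_noise(lengthscale, variance, dt, np.float32)
rel_err64 = abs(Q64[0, 0] - q_series) / q_series
rel_err32 = abs(float(Q32[0, 0]) - q_series) / q_series
```

In float64 the result agrees with the expansion up to the series truncation error, while in float32 the true value (about 9e-10) is below the spacing of representable numbers near 1 (about 6e-8), so Q[0, 0] comes out as zero or as rounding noise.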

    Software and hardware details:

    objax==1.6.0
    jax==0.3.13
    jaxlib==0.3.10+cuda11.cudnn805
    
    NVIDIA-SMI 460.56       Driver Version: 460.56       CUDA Version: 11.2
    GeForce RTX 3090 GPUs
    

    Thanks in advance.

    Best, Harrison

    opened by harrisonzhu508 1
  • Cannot run demo, possible incompatibility with latest Jax

    Dear all,

    I am trying to run the demo examples, but I run into the following error:


    ImportError Traceback (most recent call last)
    Input In [22], in <cell line: 1>()
    ----> 1 import bayesnewton
          2 import objax
          3 import numpy as np

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/__init__.py:1, in <module>
    ----> 1 from . import (
          2     kernels,
          3     utils,
          4     ops,
          5     likelihoods,
          6     models,
          7     basemodels,
          8     inference,
          9     cubature
         10 )
         13 def build_model(model, inf, name='GPModel'):
         14     return type(name, (inf, model), {})

    File ~/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/bayesnewton/kernels.py:5, in <module>
          3 import jax.numpy as np
          4 from jax.scipy.linalg import cho_factor, cho_solve, block_diag, expm
    ----> 5 from jax.ops import index_add, index
          6 from .utils import scaled_squared_euclid_dist, softplus, softplus_inv, rotation_matrix
          7 from warnings import warn

    ImportError: cannot import name 'index_add' from 'jax.ops' (/Users/Daniel/opt/anaconda3/envs/aurora/lib/python3.9/site-packages/jax/ops/__init__.py)

    I think it's related to this note from the JAX website:

    The functions jax.ops.index_update, jax.ops.index_add, etc., which were deprecated in JAX 0.2.22, have been removed. Please use the jax.numpy.ndarray.at property on JAX arrays instead.
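The replacement that note points to is the `.at` property on JAX arrays. A minimal sketch of how calls to the removed functions translate (the array and indices are arbitrary examples); like the old functions, `.at[...]` returns a new array rather than modifying x in place.

```python
import jax.numpy as jnp

x = jnp.zeros(5)

# Previously: x = jax.ops.index_add(x, jax.ops.index[1], 2.0)
x = x.at[1].add(2.0)

# Previously: x = jax.ops.index_update(x, jax.ops.index[3], 7.0)
x = x.at[3].set(7.0)
```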

    opened by daniel-trejobanos 3
  • Latest versions of JAX and objax cause compile slow down

    It is recommended to use the following versions of jax and objax:

    jax==0.2.9
    jaxlib==0.1.60
    objax==1.3.1
    

    This is because of this objax issue, which causes the model to JIT compile "twice", i.e. on the first two iterations rather than just the first. This causes a slowdown for large models, but is not a problem otherwise.

    bug 
    opened by wil-j-wil 0
Releases (v1.2.0)
Owner
AaltoML
Machine learning group at Aalto University led by Prof. Solin