TyXe: Pyro-based BNNs for Pytorch users


TyXe aims to simplify the process of turning Pytorch neural networks into Bayesian neural networks by leveraging the model definition and inference capabilities of Pyro. Our core design principle is to cleanly separate the construction of neural architecture, prior, inference distribution and likelihood, enabling a flexible workflow where each component can be exchanged independently. Defining a BNN in TyXe takes as little as 5 lines of code:

import torch.nn as nn
import pyro.distributions as dist
import tyxe

net = nn.Sequential(nn.Linear(1, 50), nn.Tanh(), nn.Linear(50, 1))
prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
likelihood = tyxe.likelihoods.HomoskedasticGaussian(scale=0.1)
inference = tyxe.guides.AutoNormal
bnn = tyxe.VariationalBNN(net, prior, likelihood, inference)

In the following, we assume that you (roughly) know what a BNN is mathematically.

Motivating example

Standard neural networks give us a single function that fits the data, but many different ones are typically plausible. With only a single fit, we don't know for which inputs the model is 'certain' (because there is training data nearby) and where it is uncertain.

[Figure: maximum likelihood fit vs. posterior samples]

The former can be implemented in a few lines of Pytorch code, but training a BNN that gives a distribution over different fits is typically more complicated; this is precisely what we aim to simplify.

Training

The example above shows how to construct a BNN object. For fitting the posterior approximation, we provide a high-level .fit method similar to libraries such as scikit-learn or Keras:

optim = pyro.optim.Adam({"lr": 1e-3})
bnn.fit(data_loader, optim, num_epochs)
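Optionally, .fit takes a callback that is invoked once per epoch, as in the UNet example in the comments below; a minimal sketch for recording the ELBO trace (the third argument is assumed to be the epoch's loss, matching the callback(bnn, i, e) signature used there):

elbos = []

def callback(bnn, epoch, elbo):
    # record the per-epoch loss so training progress can be inspected later
    elbos.append(elbo)

bnn.fit(data_loader, optim, num_epochs, callback)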

Prediction & evaluation

Further, we provide .predict and .evaluate methods, which make predictions based on multiple samples from the approximate posterior, average them according to the observation model, and return log likelihoods and an error measure:

predictions = bnn.predict(x_test, num_samples)
error, log_likelihood = bnn.evaluate(x_test, y_test, num_samples)

Local reparameterization

We implement local reparameterization for factorized Gaussians as a poutine, which reduces gradient noise during training. This means it can be enabled or disabled, both during training and at prediction time, with a context manager:

with tyxe.poutine.local_reparameterization():
    bnn.fit(data_loader, optim, num_epochs)
    bnn.predict(x_test, num_predictions)

At the moment, this poutine does not work with the AutoNormal and AutoDiagonalNormal guides in pyro, since those draw the weights from a Delta distribution, so you need to use tyxe.guides.ParameterwiseDiagonalNormal as your guide.
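A minimal sketch of combining that guide with the poutine, assuming ParameterwiseDiagonalNormal accepts an init_scale keyword like the AutoNormal guide in the UNet example further down:

from functools import partial

# a guide that samples the weights, making it compatible with local reparameterization
guide = partial(tyxe.guides.ParameterwiseDiagonalNormal, init_scale=1e-2)
bnn = tyxe.VariationalBNN(net, prior, likelihood, guide)

with tyxe.poutine.local_reparameterization():
    bnn.fit(data_loader, optim, num_epochs)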

MCMC

We provide a unified interface to pyro's MCMC implementations; simply use the tyxe.MCMC_BNN class instead and provide a kernel in place of the guide:

kernel = pyro.infer.mcmc.NUTS
bnn = tyxe.MCMC_BNN(net, prior, likelihood, kernel)

Any parameters that pyro's MCMC class accepts can be passed through the keyword arguments of the .fit method.
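For example, a hedged sketch (adapt_step_size is a NUTS option, while num_samples and warmup_steps are arguments of pyro's MCMC class, assumed to be forwarded by .fit as described above):

import functools

# kernel options go into the kernel factory, MCMC options into .fit
kernel = functools.partial(pyro.infer.mcmc.NUTS, adapt_step_size=True)
bnn = tyxe.MCMC_BNN(net, prior, likelihood, kernel)
bnn.fit(data_loader, num_samples=1000, warmup_steps=200)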

Continual learning

Because our design cleanly separates the prior from the guide, architecture and likelihood, it is easy to update the prior in a continual setting. For example, you can construct a tyxe.priors.DictPrior by extracting the distributions over all weights and biases from a ParameterwiseDiagonalNormal instance using the get_detached_distributions method and pass it to bnn.update_prior to implement Variational Continual Learning in a few lines of code, as sketched below. See examples/vcl.py for a basic example on split-MNIST and split-CIFAR.
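A hedged sketch of that update; the net_guide attribute shows up in the state_dict keys quoted in the comments below, while the exact argument list of get_detached_distributions is assumed:

# after fitting on task t, reuse the approximate posterior as the prior for task t + 1
posterior = bnn.net_guide.get_detached_distributions()  # argument list assumed
bnn.update_prior(tyxe.priors.DictPrior(posterior))
bnn.fit(next_task_loader, optim, num_epochs)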

Network architectures

We don't implement any layer classes. You construct your network in Pytorch and then turn it into a BNN, which makes it easy to apply the same prior and inference strategies to different neural networks.

Inference

In tyxe.guides, we mainly provide an equivalent to pyro's AutoDiagonalNormal that is compatible with local reparameterization. This module also contains a few helper functions for initializing Gaussian mean parameters, e.g. to the values of a pre-trained network. It should be possible to use any of pyro's autoguides for variational inference, as in the sketch below; see examples/resnet.py for a few options as well as for initializing to pre-trained weights.
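For example, the ResNet issue in the comments below plugs one of pyro's autoguides in directly (the rank value here is illustrative):

import functools
import pyro

guide = functools.partial(pyro.infer.autoguide.AutoLowRankMultivariateNormal, rank=10)
bnn = tyxe.VariationalBNN(net, prior, likelihood, guide)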

Priors

The priors can be found in tyxe.priors. We currently only support placing priors on the parameters. Through the expose and hide arguments of the init method, you can select layers, types of layers, and specific parameters over which to place a prior. This helps, for example, to learn the parameters of BatchNorm layers deterministically, as sketched below.
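A sketch grounded in the ResNet issue in the comments below, which places a prior only on the final layer and leaves everything else deterministic; hide counterparts for, e.g., BatchNorm modules follow the same pattern:

# only net.fc receives a prior; all other parameters remain point estimates
prior = tyxe.priors.IIDPrior(dist.Normal(0., 1.),
                             expose_all=False, expose_modules=[net.fc])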

Likelihoods

tyxe.likelihoods contains classes that wrap the most common torch.distributions for specifying the noise model of the data.
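For example (the dataset-size argument follows the ResNet issue in the comments below; Categorical as the classification counterpart is an assumption based on the wrapped torch.distributions):

# regression with fixed observation noise
regression_likelihood = tyxe.likelihoods.HomoskedasticGaussian(len(trainset), scale=0.1)
# classification over discrete labels (name assumed)
classification_likelihood = tyxe.likelihoods.Categorical(len(trainset))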

Installation

We recommend installing TyXe using conda with the provided environment.yml, which also installs all the dependencies for the examples except for Pytorch3d, which needs to be added manually. The environment assumes that you are using CUDA 11.0; if this is not the case, simply change the cudatoolkit and dgl-cuda versions before running:

conda env create -f environment.yml
conda activate tyxe
pip install -e .

Citation

If you use TyXe, please consider citing:

@article{ritter2021tyxe,
  author  = {Hippolyt Ritter and Theofanis Karaletsos},
  title   = {TyXe: Pyro-based Bayesian neural nets for Pytorch},
  journal = {International Conference on Probabilistic Programming (ProbProg)},
  year    = {2020},
  url     = {https://arxiv.org/abs/2110.00276}
}
Comments
  • Update to pyro 1.8.1

    I've started the process of updating the dependencies and added an example for classification.

    Things changed to enable the move to pyro-ppl 1.8.1:

    • [x] fixed reference to deep_setattr in pyro.infer.autoguide.guides.deep_setattr
    • [x] updated reference for torch.cholesky to torch.linalg.cholesky
    • [x] added trange for fit in SVI to bring it in line with the MCMC behaviour from pyro; tqdm is a pyro dependency, so no new dependency is added to TyXe.
    • [x] figured out why vs in tests/test_bnn.py:69 has changed value (to look more like old vs + torch.eye)
    • [x] fixed error in example/classification.py for likelihood.data shape
    • [x] check that all examples still work (some don't work for me with pyro==1.4.0)

    bnn.py:207 triggers

    ValueError: at site "likelihood.data", invalid log_prob shape
      Expected [20], actual [20, 20]
      Try one of the following fixes:
      - enclose the batched tensor in a with pyro.plate(...): context
      - .to_event(...) the distribution being sampled
      - .permute() data dimensions
    

    I haven't found what it refers to, as when I step through it, all pyro.sample statements are in plates. The relevant eight_schools example in pyro has also been updated for the newer versions of pyro to include these pyro.plate contexts.

    I'd appreciate input on the approach for the classification model as well as help with fixing these last two errors.

    opened by icfly2 12
  • ValueError: Error while computing log_prob at site 'likelihood.data'

    Hi,

    First, I want to say that the work you've done is clearly amazing.

    I have been playing a bit with TyXe and I am trying to convert the FC layer of a pretrained ResNet to a probabilistic layer.

    I have been implementing the code shown in this paper, but unfortunately I was unsuccessful. When using a homoskedastic Gaussian for the likelihood function I get this error:

    ValueError: Error while computing log_prob at site 'likelihood.data':
    Value is not broadcastable with batch_shape+event_shape: torch.Size([128]) vs torch.Size([128, 2]).
    
    ...
    
    Sample Sites:                      
                  net.fc.weight dist             |   2 512
                               value             |   2 512
                            log_prob             |        
                    net.fc.bias dist             |   2    
                               value             |   2    
                            log_prob             |        
                likelihood.data dist             | 128   2
                               value             | 128    
    

    My code is very minimalistic:

    import torch
    import torch.nn as nn
    import albumentations as A
    
    import pyro
    from pyro.distributions import Normal
    
    from torch.utils.data import DataLoader
    
    from TyXe.tyxe.priors import IIDPrior
    from TyXe.tyxe.likelihoods import HomoskedasticGaussian
    from TyXe.tyxe import VariationalBNN
    
    from utils.loader import ImgLoader
    
    import glob
    
    transform = A.Compose(
                [
                    A.SmallestMaxSize(max_size=260),
                    A.ShiftScaleRotate(shift_limit=0.05, scale_limit=0.05, rotate_limit=15, p=0.5),
                    A.RandomCrop(height=224, width=224),
                    A.RGBShift(r_shift_limit=15, g_shift_limit=15, b_shift_limit=15, p=0.5),
                    A.RandomBrightnessContrast(p=0.5),
                    A.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
                ]
            )
    
    # Prepare the dataset
    data_path = "/Data/train"
    list_images = glob.glob(data_path + "/*.jpg")
    
    trainset = ImgLoader(list_images, transform=transform)
    
    trainLoader = DataLoader(trainset,
                            batch_size = 128, 
                            shuffle=True)
    
    # Load pre-trained model ResNet and add FC layers on top
    torch.hub._validate_not_a_forked_repo=lambda a,b,c: True # Fix bug for Pytorch 1.9
    NN = torch.hub.load('pytorch/vision:v0.9.0', 'resnet18', pretrained=True)
    NN.to('cpu')
    
    # Freeze the layers
    for param in NN.parameters():
        param.requires_grad = False
    
    NN.fc = nn.Linear(512,2)
    
    ll_prior = IIDPrior(Normal(0, 1), expose_all=False, expose_modules=[NN.fc])
    likelihood = HomoskedasticGaussian(len(list_images), event_dim=2, scale=1)
    lr_guide = pyro.infer.autoguide.AutoLowRankMultivariateNormal
    bnn = VariationalBNN(NN, ll_prior, likelihood, lr_guide)
    optim = pyro.optim.Adam({"lr": 1e-3})
    
    # fit the model
    bnn.fit(trainLoader, optim, 5)
    

    Thanks for your help already!

    opened by BenCretois 5
  • Development status

    TyXe is recommended to be used instead of Pyro HiddenLayers, but it looks like it is pinned to an older version of Pyro. What is the development status of this library?

    I'd be happy to contribute; there are some issues marked as good first issues, but before adding work I'd like to find out what the original authors have in mind for this package.

    opened by icfly2 3
  • turning UNet into Bayesian UNet

    Hi,

    I am trying to use your library to turn UNet into a Bayesian UNet. I paste the code below: in this implementation, UNet works as a pixel-to-pixel translator for 3D data. The code follows your regression example (as I am also doing regression, but for higher-dimensional data).

    When I run the code I got a run-time error: ValueError: Expected parameter scale (Tensor of shape (4, 1, 32, 32, 16)) of distribution Normal(loc: torch.Size([4, 1, 32, 32, 16]), scale: torch.Size([4, 1, 32, 32, 16])) to satisfy the constraint GreaterThan(lower_bound=0.0), but found invalid values...

    I expect that the problem is a wrong choice of prior and/or guide. I would appreciate any suggestion that makes the model learn.

    Regards, Zbisław

    The code:

    from functools import partial

    import torch
    import torch.nn as nn
    import torch.utils.data as data

    import pyro
    import pyro.distributions as dist

    import tyxe

    def double_convolution(in_channels, out_channels):
        """
        In the original paper implementation, the convolution operations were
        not padded but we are padding them here. This is because we need the
        output result size to be the same as the input size.
        """
        conv_op = nn.Sequential(
            nn.Conv3d(in_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True),
            nn.Conv3d(out_channels, out_channels, kernel_size=3, padding=1),
            nn.ReLU(inplace=True)
        )
        return conv_op

    class UNet(nn.Module):
        def __init__(self, num_classes):
            super(UNet, self).__init__()

        self.max_pool3d = nn.MaxPool3d(kernel_size=2, stride=2)
    
        # contracting path
        # each convolution is applied twice
        self.down_convolution_1 = double_convolution(3, 64)
        self.down_convolution_2 = double_convolution(64, 128)
        self.down_convolution_3 = double_convolution(128, 256)
        self.down_convolution_4 = double_convolution(256, 512)
        self.down_convolution_5 = double_convolution(512, 1024)
    
        # expanding path
        self.up_transpose_1 = nn.ConvTranspose3d(
            in_channels=1024, out_channels=512,
            kernel_size=2, 
            stride=2)
        # below, `in_channels` again becomes 1024 as we are concatinating
        self.up_convolution_1 = double_convolution(1024, 512)
        self.up_transpose_2 = nn.ConvTranspose3d(
            in_channels=512, out_channels=256,
            kernel_size=2, 
            stride=2)
        self.up_convolution_2 = double_convolution(512, 256)
        self.up_transpose_3 = nn.ConvTranspose3d(
            in_channels=256, out_channels=128,
            kernel_size=2, 
            stride=2)
        self.up_convolution_3 = double_convolution(256, 128)
        self.up_transpose_4 = nn.ConvTranspose3d(
            in_channels=128, out_channels=64,
            kernel_size=2, 
            stride=2)
        self.up_convolution_4 = double_convolution(128, 64)
    
        # output => increase the `out_channels` as per the number of classes
        self.out = nn.Conv3d(
            in_channels=64, out_channels=num_classes, 
            kernel_size=1
        ) 
    
    def forward(self, x):
        down_1 = self.down_convolution_1(x)
        down_2 = self.max_pool3d(down_1)
        down_3 = self.down_convolution_2(down_2)
        down_4 = self.max_pool3d(down_3)
        down_5 = self.down_convolution_3(down_4)
        down_6 = self.max_pool3d(down_5)
        down_7 = self.down_convolution_4(down_6)
        #down_8 = self.max_pool3d(down_7)
        #down_9 = self.down_convolution_5(down_8)        
        
        #up_1 = self.up_transpose_1(down_9)
        #x = self.up_convolution_1(torch.cat([down_7, up_1], 1))
    
        #up_2 = self.up_transpose_2(x)
        up_2 = self.up_transpose_2(down_7)
        x = self.up_convolution_2(torch.cat([down_5, up_2], 1))
    
        up_3 = self.up_transpose_3(x)
        x = self.up_convolution_3(torch.cat([down_3, up_3], 1))
    
        up_4 = self.up_transpose_4(x)
        x = self.up_convolution_4(torch.cat([down_1, up_4], 1))
    
        out = self.out(x)
        return out
    

    ################################################################################

    if __name__ == '__main__':

    pyro.set_rng_seed(42)
    
    x = torch.randn((4,3,32,32,32)) * 0.5 + 0.5
    y = torch.mean(x,1,keepdim=True) + torch.randn((4,1,32,32,32))*0.01
    
    for _ in range(10):
        x2 = torch.randn((4,3,32,32,32)) * 0.5 + 0.5
        x = torch.cat([x, x2])
    
        y2 = torch.mean(x2,1,keepdim=True) + torch.randn((4,1,32,32,32))*0.01
        y = torch.cat([y, y2])
    
    batchSize = 4
    dataset = data.TensorDataset(x, y)
    loader = data.DataLoader(dataset, batch_size=batchSize)
    
    
    
    
    net = UNet(1)
    prior = tyxe.priors.IIDPrior(dist.Normal(0, 1))
    obs_model = tyxe.likelihoods.HeteroskedasticGaussian((4,1,32,32,32))
    guide = partial(tyxe.guides.AutoNormal, init_scale=0.01)
    bnn = tyxe.VariationalBNN(net, prior, obs_model, guide)
    
    
    pyro.clear_param_store()
    optim = pyro.optim.Adam({"lr": 1e-4})
    elbos = []
    def callback(bnn, i, e):
        elbos.append(e)
        
    with tyxe.poutine.local_reparameterization():
        bnn.fit(loader, optim, 10000, callback)
    
    opened by taborzbislaw 2
  • Unable to save and load models

    Is there a way to save and load models, especially the ones using MCMC? I was hoping to use something along the lines of pyro.get_param_store but it does not work with MCMC apparently.

    opened by canbooo 2
  • Important data leakage in resnet example.

    Hi !

    First thank you very much for this repo, it is very helpful for people who are new to BNNs like me.

    I have started to use TyXe for a convolutional BNN, starting from your resnet.py example. After getting some unexpected behavior during training, I noticed that in the callback function the network was not set into evaluation mode. If I am correct, this results in significant data leakage, since the network is effectively trained on the test dataset as well.

    I would recommend starting and finishing the callback function with b.eval() and b.train(), respectively.

    Hope this helps, Regards!

    opened by Cam-B04 1
  • Correct issue #21

    Sets the model to evaluation mode during the callback and consequently avoids data leakage.

    (Pull request might be overkill for this little change, this is my first one so do not hesitate to tell me.)

    Regards!

    opened by Cam-B04 0
  • DGL deprecation warnings

    Deprecation warnings need to be addressed:

        features = torch.FloatTensor(data.features)
        labels = torch.LongTensor(data.labels)
        train_mask = torch.BoolTensor(data.train_mask)
        test_mask = torch.BoolTensor(data.test_mask)
        val_mask = torch.BoolTensor(data.val_mask)
        g = dgl.from_networkx(data.graph)
    

    I'll make a PR after the pyro upgrade.

    opened by icfly2 2
  • Errors when trying to load a VariationalBNN

    Hi all, I'm new to TyXe, and I'm experiencing an issue when trying to load a (previously) trained model from disk.

    To be more precise, the returned error is as in the following:

    RuntimeError: Error(s) in loading state_dict for VariationalBNN:
        Unexpected key(s) in state_dict: net_guide.rnn.weight_ih_l0.loc_unconstrained etc.

    In particular, to save the model, I use code like this:

    pyro.get_param_store().save(os.path.join(output_dir, "param_store.pt"))
    torch.save(model.state_dict(), os.path.join(output_dir, "best_mode.pt"))

    To load the model (defined as tyxe.VariationalBNN(net, prior, likelihood, guide)) instead:

    pyro.clear_param_store()
    model.load_state_dict(torch.load(os.path.join(save_model_path, "best_model.pt")))
    pyro.get_param_store().load(os.path.join(save_model_path, "param_store.pt"))

    Where is the error?

    Thank you so much.

    opened by francescofolino 4
  • Measuring Uncertainty - epistemic & aleatoric

    Firstly, thanks for the amazing repo. I'm a complete beginner, and this repo really breaks BNNs down in a simple, hands-on manner.

    Could you guide me on how to calculate the uncertainty and break it into its epistemic & aleatoric components using TyXe?

    Thanks!

    opened by Saadmaghani 1
  • Implementing Radial BNN

    Hi,

    I’m trying to fit a radial BNN posterior variational approximation as per this paper.

    However, since I’ll be training a BNN, I don’t want to have to write a custom guide and define this variational approximation for all of my layers, and so was trying to implement a custom AutoGuide which automatically puts a radial BNN approximation on all of my weights.

    The radial approximation is defined as w = mu + sigma * (epsilon_MFVI / ||epsilon_MFVI||) * r, where I just need to sample all epsilon_MFVI from an independent standard normal distribution, normalize them, and multiply them by r, which is a scalar sampled from a standard normal.

    How could I go about implementing this in TyXe? Is there a smarter way of implementing this variational approximation?

    P.S. Big fan of this project!

    Thanks in advance.

    opened by silasbrack 2