Storchastic is a PyTorch library for stochastic gradient estimation in Deep Learning


Stochastic Deep Learning for PyTorch

Documentation on Read the Docs. Storchastic is a PyTorch library for stochastic gradient estimation in Deep Learning [1]. Many state-of-the-art deep learning models use gradient estimation, in particular within the fields of Variational Inference and Reinforcement Learning. While PyTorch computes gradients of deterministic computation graphs automatically, it will not estimate gradients on stochastic computation graphs [2].

With Storchastic, you can easily define any stochastic deep learning model and let it estimate the gradients for you. Storchastic provides a large range of gradient estimation methods that you can plug and play, to figure out which one works best for your problem. Storchastic provides automatic broadcasting of sampled batch dimensions, which increases code readability and allows implementing complex models with ease.

When dealing with continuous random variables and differentiable functions, the popular reparameterization method [3] is usually very effective. However, this method is not applicable when dealing with discrete random variables or non-differentiable functions. This is why Storchastic has a focus on gradient estimators for discrete random variables, non-differentiable functions and sequence models.
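To make the contrast with reparameterization concrete, here is a minimal sketch of the score function (REINFORCE) estimator for a single Bernoulli variable. It is written in plain Python and does not use the Storchastic API; the names `score_function_grad`, `exact_grad` and `cost` are made up for this illustration. The cost only needs to be evaluated at the samples, never differentiated, which is why this family of estimators also applies to discrete variables.

```python
import random


def score_function_grad(theta, cost, n_samples, rng):
    """Monte Carlo estimate of d/dtheta E_{x ~ Bernoulli(theta)}[cost(x)]."""
    total = 0.0
    for _ in range(n_samples):
        x = 1 if rng.random() < theta else 0
        # Score: d/dtheta log p(x; theta) is 1/theta for x = 1
        # and -1/(1 - theta) for x = 0.
        score = 1.0 / theta if x == 1 else -1.0 / (1.0 - theta)
        total += cost(x) * score
    return total / n_samples


def exact_grad(cost):
    # E[cost] = theta * cost(1) + (1 - theta) * cost(0),
    # so the derivative with respect to theta is cost(1) - cost(0).
    return cost(1) - cost(0)


def cost(x):
    # Illustrative cost; it is only evaluated, never differentiated.
    return (x - 0.3) ** 2


rng = random.Random(0)
estimate = score_function_grad(0.4, cost, 100_000, rng)
exact = exact_grad(cost)
print(f"estimate={estimate:.3f} exact={exact:.3f}")
```

The estimate is unbiased but noisy, which is the motivation for the baselines and control variates listed among the estimators.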

Example: Discrete Variational Auto-Encoder


pip install storchastic

Requires PyTorch 1.5 (older versions will not do!) and Pyro. The code is built on Python 3.7. The master branch works with PyTorch 1.7, but the version on pip is not compatible. Binaries will be updated soon.


Feel free to create an issue if an estimator is missing here.

  • Reparameterization [1, 3]
  • Score Function (REINFORCE) with Moving Average baseline [1, 4]
  • Score Function with Batch Average Baseline [5, 6]
  • Expected value for enumerable distributions
  • (Straight through) Gumbel Softmax [7, 8]
  • LAX, RELAX [9]
  • REBAR [10]
  • REINFORCE Without Replacement [6]
  • Unordered Set Estimator [13]
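
As an illustration of the "expected value for enumerable distributions" entry, the sketch below computes an exact expectation over independent Bernoulli variables by enumerating every configuration. It is plain Python rather than Storchastic code, and `config_prob` and `cost` are hypothetical helper names. For four binary variables this gives 16 configurations, each with weight 0.0625 when all probabilities are 0.5.

```python
import itertools
import math


def config_prob(probs, x):
    """Probability of configuration x under independent Bernoulli(probs)."""
    return math.prod(p if xi == 1 else 1.0 - p for p, xi in zip(probs, x))


def cost(x):
    # Illustrative cost: the number of ones in the configuration.
    return sum(x)


probs = [0.5, 0.5, 0.5, 0.5]
# Every joint assignment of the four binary variables.
configs = list(itertools.product([0, 1], repeat=len(probs)))
weights = [config_prob(probs, x) for x in configs]
# Exact expectation: a weighted sum over the whole support, with no sampling noise.
expected = sum(w * cost(x) for w, x in zip(weights, configs))
print(len(configs), weights[0], expected)
```

The number of configurations doubles with every extra binary variable, so enumeration is only practical for small supports.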

In development

  • Memory Augmented Policy Optimization [11]
  • Rao-Blackwellized REINFORCE [12]


  • Measure valued derivatives [1, 14]
  • ARM [15]
  • Automatic Credit Assignment [16]
  • ...


    I use python=3.8 and pytorch=1.8.1. The error happens in the following line: log_probs = tensor.distribution.log_prob(tensor). The error is: ValueError: The value argument to log_prob must be a Tensor

    It seems .log_prob() wants a tensor object as input. When I print the 'tensor', it outputs:

    x: Stochastic tensor([[0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 1., 0.],
            [0., 0., 1., 1.],
            [0., 1., 0., 0.],
            [0., 1., 0., 1.],
            [0., 1., 1., 0.],
            [0., 1., 1., 1.],
            [1., 0., 0., 0.],
            [1., 0., 0., 1.],
            [1., 0., 1., 0.],
            [1., 0., 1., 1.],
            [1., 1., 0., 0.],
            [1., 1., 0., 1.],
            [1., 1., 1., 0.],
            [1., 1., 1., 1.]]) Batch links: [('x', 16, tensor(0.0625))]

    Could you offer some solutions? Thank you

    opened by zhaoguangyuan123 8
    Running backward() when there are a lot of cost nodes causes a recursion-limit-reached exception to be thrown.

    I'm able to consistently generate this exception running the following:

    import storch
    import torch
    from torch.distributions import Bernoulli
    from storch.method import GumbelSoftmax
    p = torch.tensor(0.5, requires_grad=True)
    for i in range(1000):
        sample = GumbelSoftmax(f"sample_{i}")(Bernoulli(p))
        storch.add_cost(sample, f"cost_{i}")
    # The exception is raised here, once all cost nodes are registered
    storch.backward()

    which produces the following trace:

    Traceback (most recent call last):
      File "...", line 15, in <module>
      File ".../lib/python3.8/site-packages/storch/", line 217, in backward
      File ".../lib/python3.8/site-packages/storch/", line 401, in _clean
      File ".../lib/python3.8/site-packages/storch/", line 401, in _clean
      File ".../lib/python3.8/site-packages/storch/", line 401, in _clean
      [Previous line repeated 994 more times]
      File ".../lib/python3.8/site-packages/storch/", line 399, in _clean
    RecursionError: maximum recursion depth exceeded
    opened by csmith49 4
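
The trace in the report above shows `_clean` calling itself once per node until Python's recursion limit is hit. The sketch below reproduces that failure mode on a toy linked chain and shows the usual remedy of replacing recursion with an explicit stack. It is not Storchastic's actual `_clean`; `Node`, `clean_recursive`, `clean_iterative` and `build_chain` are hypothetical names for illustration only.

```python
import sys


class Node:
    """Toy graph node with a single parent (hypothetical, for illustration)."""

    def __init__(self, parent=None):
        self.parent = parent
        self.cleaned = False


def clean_recursive(node):
    # One Python stack frame per node: fails on chains longer than the limit.
    if node is None or node.cleaned:
        return
    node.cleaned = True
    clean_recursive(node.parent)


def clean_iterative(node):
    # Same traversal with an explicit stack, so depth is bounded only by memory.
    stack = [node]
    while stack:
        current = stack.pop()
        if current is None or current.cleaned:
            continue
        current.cleaned = True
        stack.append(current.parent)


def build_chain(length):
    node = None
    for _ in range(length):
        node = Node(parent=node)
    return node


depth = sys.getrecursionlimit() + 100
try:
    clean_recursive(build_chain(depth))
    recursive_failed = False
except RecursionError:
    recursive_failed = True

tail = build_chain(depth)
clean_iterative(tail)
print(recursive_failed, tail.cleaned)
```

Raising the limit with `sys.setrecursionlimit` can serve as a stopgap, but it only moves the threshold.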
    I'm currently working with CNNs, which allow input x with shape x.shape -> [N, C, H, W].

    However, when I set N as the batch dim (use storch.denote_independent(x, 0, "data")) and perform RELAX("z", n_samples=1) sampling (it adds another plate), the sampled tensor becomes z.shape -> [Z, N, C, H, W] and is not compatible with Conv2D any more.

    Is there any way to reshape z to be compatible with Conv2D?

    Minimal reproduction code:

    import torch
    from torch import nn
    from torch.distributions import Categorical
    from storch.method import RELAX

    class RelaxSample(nn.Module):
        def __init__(self, k: int, c: int):
            # nn.Module must be initialized before parameters or submodules are assigned
            super().__init__()
            self._k = k
            self._wv = nn.Parameter(torch.empty(k, c))
            self._method = RELAX("z", n_samples=1, in_dim=[k], rebar=True)
        def forward(self, latent):
            # [n, c, h, w]
            n, c, h, w = latent.shape
            # [n, h, w, c]
            q = latent.permute(0, 2, 3, 1)
            # [n, h, w, c], [k, c] -> [n, h, w, k]
            logit = torch.einsum("nhwc,kc->nhwk", q, self._wv)
            varPosterior = Categorical(logits=logit)
            # [z, n, h, w, k]
            z = self._method(varPosterior)
            # [z, n, c, h, w]
            z = torch.einsum("znhwk,kc->znhwc", z, self._wv).permute(0, 1, 4, 2, 3)
            # ********* z is not compatible with Conv2D *********
            return z
    opened by xiaosu-zhu 3
    RELAX and REBAR gradient estimators don't appear to be usable on Bernoulli random variables - some operation expects them to have the mc_sample attribute, which is missing.

    I can reproduce this result with the following:

    import storch
    import torch
    from torch.distributions import Bernoulli
    from storch.method import RELAX
    p = torch.tensor(0.5, requires_grad=True)
    d = Bernoulli(p)
    sample = RELAX("sample")(d)
    storch.add_cost(sample, "cost")

    which produces the following trace:

    Traceback (most recent call last):
      File "...", line 13, in <module>
        sample = RELAX("sample")(d)
      File ".../lib/python3.8/site-packages/storch/method/", line 210, in __init__
        super().__init__(plate_name, sampling_method.set_mc_sample(self.mc_sample))
      File ".../lib/python3.8/site-packages/torch/nn/modules/", line 771, in __getattr__
        raise ModuleAttributeError("'{}' object has no attribute '{}'".format(
    torch.nn.modules.module.ModuleAttributeError: 'RELAX' object has no attribute 'mc_sample'
    opened by csmith49 2
    Often, samples from a distribution are conditionally independent: for example in LDA. While it is possible to create a list of K conditionally independent variables by looping over sample steps, this is not vectorized and thus not efficient.

    Plating could possibly be done by calling storch.plate() on a storch tensor.

    opened by HEmile 2
    Running either of the examples with Bernoulli distributions (examples/ and examples/) results in a ValueError, specifically:

    ValueError: Input arguments must all be instances of numbers.Number or torch.tensor.


    Running Python 3.8.5, with Pyro 1.5.0 and Torch 1.7.0.


    Running python3 examples/ from the root of the repo produces (with better_exceptions enabled):

    Traceback (most recent call last):
      File "examples/", line 29, in <module>
        │          └ <class 'storch.method.method.Expect'>
        └ <function experiment at 0x7fbf47b3a0d0>
      File "examples/", line 19, in experiment
        x = method(b)
            │      └ Bernoulli(probs: torch.Size([4]), logits: torch.Size([4]))
            └ Expect(
      (sampling_method): Enumerate()
      File ".../python3.8/site-packages/torch/nn/modules/", line 727, in _call_impl
        result = self.forward(*input, **kwargs)
                 │             │        └ {}
                 │             └ (Bernoulli(probs: torch.Size([4]), logits: torch.Size([4])),)
                 └ Expect(
      (sampling_method): Enumerate()
      File ".../python3.8/site-packages/storch/method/", line 52, in forward
        return self.sample(distr)
               │           └ Bernoulli(probs: torch.Size([4]), logits: torch.Size([4]))
               └ Expect(
      (sampling_method): Enumerate()
      File ".../python3.8/site-packages/storch/method/", line 114, in sample
        batch_weighting = self.sampling_method.plate_weighting(s_tensor, plate)
                          │                                    │         └ ('x', 16, tensor(0.0625))
                          │                                    └ tensor([[0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 1., 0.],
            [0., 0., 1., 1.],
            [0., 1., 0., 0.]...
                          └ Expect(
      (sampling_method): Enumerate()
      File ".../python3.8/site-packages/storch/sampling/", line 82, in plate_weighting
        log_probs = tensor.distribution.log_prob(tensor)
                    │                            └ tensor([[0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 1., 0.],
            [0., 0., 1., 1.],
            [0., 1., 0., 0.]...
                    └ tensor([[0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 1., 0.],
            [0., 0., 1., 1.],
            [0., 1., 0., 0.]...
      File ".../python3.8/site-packages/torch/distributions/", line 94, in log_prob
        logits, value = broadcast_all(self.logits, value)
                │       │             │            └ tensor([[0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 1., 0.],
            [0., 0., 1., 1.],
            [0., 1., 0., 0.]...
                │       │             └ Bernoulli(probs: torch.Size([4]), logits: torch.Size([4]))
                │       └ <function broadcast_all at 0x7fbe57163430>
                └ tensor([[0., 0., 0., 0.],
            [0., 0., 0., 1.],
            [0., 0., 1., 0.],
            [0., 0., 1., 1.],
            [0., 1., 0., 0.]...
      File ".../python3.8/site-packages/torch/distributions/", line 24, in broadcast_all
        raise ValueError('Input arguments must all be instances of numbers.Number or torch.tensor.')
    ValueError: Input arguments must all be instances of numbers.Number or torch.tensor.
    opened by csmith49 1
    We should abstract sampling procedures so that they are no longer part of a Method's inheritance tree, but rather as a separate object to plug and play. This would make it easier to combine various estimators with various sampling methods, whether biased or not.

    Example applications:

    • uses beam search with a biased REINFORCE method. Allowing any estimator to be used with this beam search would help modularity and debugging.
    • Usually, normal MC sampling is enough, but when decoding text, for example, we have to deal with EOS tokens that complicate the sampling procedure.
    • Sampling without replacement is applicable to any unbiased estimator when correcting for the bias through the weighting. By making sampling modular, we can more easily test combinations of estimators.
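
For the sampling-without-replacement case, one well-known technique is the Gumbel-top-k trick: perturb each log-probability with independent Gumbel noise and keep the k largest. The sketch below is plain Python and only covers the sampling step. It is not Storchastic's implementation, `sample_without_replacement` is a hypothetical name, and the reweighting needed to keep an estimator unbiased is left out.

```python
import math
import random


def sample_without_replacement(log_probs, k, rng):
    """Draw k distinct category indices using Gumbel-top-k perturbation."""
    perturbed = []
    for index, log_p in enumerate(log_probs):
        # Clamp away from 0 so that log(u) is always defined.
        u = max(rng.random(), 1e-300)
        gumbel = -math.log(-math.log(u))
        perturbed.append((log_p + gumbel, index))
    # The k largest perturbed log-probabilities form the sample.
    perturbed.sort(reverse=True)
    return [index for _, index in perturbed[:k]]


probs = [0.5, 0.3, 0.1, 0.1]
log_probs = [math.log(p) for p in probs]
sample = sample_without_replacement(log_probs, 3, random.Random(0))
print(sample)
```

Because the sampler only needs log-probabilities, it can be written independently of whichever estimator later consumes the samples, which is the kind of separation the proposal above argues for.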
    opened by HEmile 1
    Methods like reparameterization do not require computing an estimator. We could skip some transposes by manually checking for this using a function in Method

    opened by HEmile 1
    Tensors resulting from sampling can not correspond with its plates. Solution: Create a "sampling" context in which the wrappers are temporarily disabled so that sampling is no longer a problem.

    opened by HEmile 1
    Implementing Expect is not as easy as I hoped. Unfortunately, in discrete VAEs, for example, you pass logits with shape [batch, latents, categories]. enumerate_support() does not recognize that batch is independent, but the different latents are dependent! It only enumerates the support over the categories dimension, but it should also enumerate over the latents dimension. Doing this automatically, however, requires independence assumptions, which probably brings us back to Issue #4.

    opened by HEmile 1
    The unordered set estimator is mostly a sampling strategy. It probably also provides variance reduction for other estimators than the score function. Is it possible to make sampling modular from the rest of the estimators? In this case, reweighting the costs is probably enough...

    It won't actually work with reparameterization: The reparameterization will not properly weight the different samples. Should rethink this.

    enhancement help wanted 
    opened by HEmile 1
    Bumps certifi from 2020.12.5 to 2022.12.7.



    opened by dependabot[bot] 0
    Storchastic uses a rather intricate system for batching over multiple dimensions, but it is rather buggy and hard for end users to work with. PyTorch 1.12 recently introduced torchdim, with first-class dimension objects that should serve the same purpose as plates in Storchastic. Adopting this new standard will likely give us faster code that is much cleaner and easier to read, write and debug. See

    It implements

    • Implicit batching: Two batch dimensions are joined together, just like in Storchastic.
    • Mixed named tensors: Only batch dimensions need to be named, event dimensions in Storchastic can just be numeric
    opened by HEmile 0
    Hi @HEmile, I tested the example in examples/vae/ and found that Gumbel softmax performs much better than REBAR, RELAX and REINFORCE (test loss after 10 epochs: 98 for Gumbel softmax, 165 for REBAR, 208 for RELAX, 209 for REINFORCE). This is not consistent with the results in your guide, where Gumbel softmax is slightly worse than REINFORCE in terms of training loss. Is there anything wrong with the example? I did not change any parameters. Looking forward to your response.

    opened by fnzhan 4
    Hi! I need to concatenate samples from these two methods into one tensor, so that I can sample from a continuous distribution using reparameterization, and sample from a discrete distribution using UnorderedSetEstimator. Is this functionality something that can be built in? It seems that this is not possible given the current iteration (see below for example error). Thank you!

    import storch
    import torch
    import torch.distributions as td
    method1 = storch.method.Reparameterization
    method2 = storch.method.UnorderedSetEstimator
    method1 = method1(plate_name="1",n_samples=25)
    method2 = method2(plate_name="1",k=25)
    p1 = td.Independent(td.Normal(loc=torch.zeros([1000,2]),scale=torch.ones([1000,2])),0)
    p2 = td.Independent(td.OneHotCategorical(probs=torch.zeros([1000,3]).uniform_()),0)
    samp1 = method1(p1)
    samp2 = method2(p2)
    samp1.shape # torch.Size([25, 1000, 2])
    samp2.shape # torch.Size([25, 1000, 3])
    [samp1,samp2],2)
    ValueError: Received a plate with name 1 that is not also an AncestralPlate.
    opened by DavidKLim 1
    Implement: For multivariate Bernoulli, Sample a single normal sample, then for each dimension, use this sample and flip the corresponding dimension.

    Run the normal sample and the flipped samples. Weight the original sample with 1. The multiplicative is: the sum of the parameters for the normal sample. These are negated if the corresponding dimension is 1. Then for the flipped samples, use -parameter if it flipped to 0, otherwise use just the parameter (ie, not negated).

    To make the zeroth-order correct, use importance sampling for the flipped samples. In the multiplicative estimator, divide again by this importance sampling to ensure correctness.

    This doesn't use a baseline.

    opened by HEmile 1
Emile van Krieken
PhD AI student, into combining Knowledge Representation with Machine Learning.
