functorch is a prototype of JAX-like composable function transforms for PyTorch.

functorch

This library is currently under heavy development. If you have suggestions on the API or use cases you'd like to be covered, please open a GitHub issue or reach out. We'd love to hear about how you're using the library.

It aims to provide composable vmap and grad transforms that work with PyTorch modules and PyTorch autograd with good eager-mode performance. Because this project requires some investment, we'd love to hear from and work with early adopters to shape the design. Please reach out on the issue tracker if you're interested in using this for your project.

In addition, there is experimental functionality to trace through these transformations using FX in order to capture the results of these transforms ahead of time. This would allow us to compile the results of vmap or grad to improve performance.

Why composable function transforms?

There are a number of use cases that are tricky to do in PyTorch today:

  • computing per-sample-gradients (or other per-sample quantities)
  • running ensembles of models on a single machine
  • efficiently batching together tasks in the inner-loop of MAML
  • efficiently computing Jacobians and Hessians
  • efficiently computing batched Jacobians and Hessians

Composing vmap, grad, and vjp transforms allows us to express the above without designing a separate subsystem for each. This idea of composable function transforms comes from the JAX framework.

Install

Colab

Follow the instructions in this Colab notebook.

Binaries

First, set up an environment. We will be installing a nightly PyTorch binary as well as functorch. If you're using conda, create a conda environment:

conda create --name functorch
conda activate functorch

If you wish to use venv instead:

python -m venv functorch-env
source functorch-env/bin/activate

Next, install one of the following PyTorch nightly binaries.

# For CUDA 10.2
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu102/torch_nightly.html
# For CUDA 11.1
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cu111/torch_nightly.html
# For CPU-only build
pip install --pre torch -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html

If you already have a nightly of PyTorch installed and want to upgrade it (recommended!), append --upgrade to one of those commands.

Install functorch:

pip install ninja  # Makes the build go faster
pip install --user "git+https://github.com/facebookresearch/functorch.git"

Run a quick sanity check in Python:

>>> import torch
>>> from functorch import vmap
>>> x = torch.randn(3)
>>> y = vmap(torch.sin)(x)
>>> assert torch.allclose(y, x.sin())

From Source

functorch is a PyTorch C++ Extension module. To install:

  • Install PyTorch from source. functorch usually runs on the latest development version of PyTorch.
  • Run python setup.py install. You can use DEBUG=1 to compile in debug mode.

Then, try to run some tests to make sure all is OK:

pytest test/test_vmap.py -v
pytest test/test_eager_transforms.py -v

What are the transforms?

Right now, we support the following transforms:

  • grad, vjp, jacrev
  • vmap

Furthermore, we have some utilities for working with PyTorch modules.

  • make_functional(model)
  • make_functional_with_buffers(model)

vmap

Note: vmap imposes restrictions on the code that it can be used on. For more details, please read its docstring.

vmap(func)(*inputs) is a transform that adds a dimension to all Tensor operations in func. vmap(func) returns a new function that maps func over some dimension (default: 0) of each Tensor in inputs.

vmap is useful for hiding batch dimensions: one can write a function func that runs on examples and then lift it to a function that can take batches of examples with vmap(func), leading to a simpler modeling experience:

>>> import torch
>>> from functorch import vmap
>>> batch_size, feature_size = 3, 5
>>> weights = torch.randn(feature_size, requires_grad=True)
>>>
>>> def model(feature_vec):
...     # Very simple linear model with activation
...     assert feature_vec.dim() == 1
...     return feature_vec.dot(weights).relu()
...
>>> examples = torch.randn(batch_size, feature_size)
>>> result = vmap(model)(examples)
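
vmap also takes an in_dims argument that says which dimension of each input to map over. None means the input is not mapped and is shared by every example. Here is a minimal sketch; the dot function and the tensor shapes are illustrative assumptions:

```python
import torch
from functorch import vmap

def dot(a, b):
    # Written for single 1-D vectors; vmap supplies the batching
    return torch.dot(a, b)

xs = torch.randn(4, 5)  # 4 vectors, batch dimension 0
ys = torch.randn(5, 4)  # 4 vectors, batch dimension 1
w = torch.randn(5)      # one shared vector, not batched

# Map over dim 0 of xs and dim 1 of ys
out_pairs = vmap(dot, in_dims=(0, 1))(xs, ys)
# None: reuse w unchanged for every row of xs
out_shared = vmap(dot, in_dims=(0, None))(xs, w)
```

Both results have shape (4,): one dot product per mapped example.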

grad

grad(func)(*inputs) assumes func returns a single-element Tensor. It computes the gradient of the output of func with respect to inputs[0].

>>> from functorch import grad
>>> x = torch.randn([])
>>> cos_x = grad(lambda x: torch.sin(x))(x)
>>> assert torch.allclose(cos_x, x.cos())
>>>
>>> # Second-order gradients
>>> neg_sin_x = grad(grad(lambda x: torch.sin(x)))(x)
>>> assert torch.allclose(neg_sin_x, -x.sin())
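
By default grad differentiates only with respect to inputs[0]. The following sketch assumes your functorch build supports grad's argnums argument. Passing a tuple of positions then returns one gradient per listed input. The function f is a hypothetical example:

```python
import torch
from functorch import grad

def f(x, y):
    # Scalar output: sum of elementwise products
    return (x * y).sum()

x, y = torch.randn(3), torch.randn(3)
# Differentiate with respect to both the first and the second argument
gx, gy = grad(f, argnums=(0, 1))(x, y)
# d/dx sum(x*y) = y and d/dy sum(x*y) = x
```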

When composed with vmap, grad can be used to compute per-sample-gradients:

>>> from functorch import grad, vmap
>>> batch_size, feature_size = 3, 5
>>>
>>> def model(weights, feature_vec):
...     # Very simple linear model with activation
...     assert feature_vec.dim() == 1
...     return feature_vec.dot(weights).relu()
...
>>> def compute_loss(weights, example, target):
...     y = model(weights, example)
...     return ((y - target) ** 2).mean()  # MSELoss
...
>>> weights = torch.randn(feature_size, requires_grad=True)
>>> examples = torch.randn(batch_size, feature_size)
>>> targets = torch.randn(batch_size)
>>> inputs = (weights, examples, targets)
>>> grad_weight_per_example = vmap(grad(compute_loss), in_dims=(None, 0, 0))(*inputs)
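
One way to convince yourself that vmap(grad(...)) really yields per-example gradients is to compare it with a plain loop that runs one backward pass per example. The sketch below restates the toy model above so it runs on its own:

```python
import torch
from functorch import grad, vmap

batch_size, feature_size = 3, 5

def model(weights, feature_vec):
    return feature_vec.dot(weights).relu()

def compute_loss(weights, example, target):
    y = model(weights, example)
    return ((y - target) ** 2).mean()

weights = torch.randn(feature_size)
examples = torch.randn(batch_size, feature_size)
targets = torch.randn(batch_size)

# Vectorized: weights shared (None), examples/targets mapped over dim 0
per_sample = vmap(grad(compute_loss), in_dims=(None, 0, 0))(weights, examples, targets)

# Reference: a separate autograd call for each example
loop_grads = []
for example, target in zip(examples, targets):
    w = weights.clone().requires_grad_()
    (g,) = torch.autograd.grad(compute_loss(w, example, target), w)
    loop_grads.append(g)
loop_grads = torch.stack(loop_grads)
```

Both tensors have shape (batch_size, feature_size), with one gradient row per example.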

vjp and jacrev

>>> from functorch import vjp
>>> outputs, vjp_fn = vjp(func, *inputs); vjps = vjp_fn(*cotangents)

The vjp transform applies func to inputs and returns a new function that computes vjps given some cotangent Tensors.
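
A concrete instance of that pattern, as a sketch with torch.sin as func. With a cotangent of all ones, the vector-Jacobian product of an elementwise function is just its derivative, cos(x):

```python
import torch
from functorch import vjp

x = torch.randn(5)
# vjp evaluates func at the primal x and returns the output plus vjp_fn
outputs, vjp_fn = vjp(torch.sin, x)

cotangents = torch.ones(5)
# vjp_fn returns a tuple with one entry per primal input
(vjps,) = vjp_fn(cotangents)
```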

>>> from functorch import jacrev
>>> x = torch.randn(5)
>>> jacobian = jacrev(torch.sin)(x)
>>> expected = torch.diag(torch.cos(x))
>>> assert torch.allclose(jacobian, expected)

Use jacrev to compute the Jacobian. This can be composed with vmap to produce batched Jacobians:

>>> x = torch.randn(64, 5)
>>> jacobian = vmap(jacrev(torch.sin))(x)
>>> assert jacobian.shape == (64, 5, 5)

jacrev can be composed with itself to produce Hessians:

>>> def f(x):
...     return x.sin().sum()
...
>>> x = torch.randn(5)
>>> hessian = jacrev(jacrev(f))(x)
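
For f(x) = sum(sin(x)), the Hessian is diag(-sin(x)), which makes a handy check. The sketch below assumes your functorch version also provides jacfwd and hessian. It compares reverse-over-reverse with forward-over-reverse composition:

```python
import torch
from functorch import hessian, jacfwd, jacrev

def f(x):
    return x.sin().sum()

x = torch.randn(5)
h_rev_rev = jacrev(jacrev(f))(x)  # reverse-over-reverse
h_fwd_rev = jacfwd(jacrev(f))(x)  # forward-over-reverse
h_builtin = hessian(f)(x)         # convenience wrapper, if available

expected = torch.diag(-x.sin())
```

Forward-over-reverse is often preferred for Hessians because forward mode is efficient for the outer, square Jacobian.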

Tracing through the transformations

We can also trace through these transformations in order to capture the results as new code using make_fx. There is also experimental integration with the NNC compiler (only works on CPU for now!).

>>> from functorch import make_fx, grad
>>> def f(x):
...     return torch.sin(x).sum()
...
>>> x = torch.randn(100)
>>> grad_f = make_fx(grad(f))(x)
>>> print(grad_f.code)

def forward(self, x_1):
    sin = torch.ops.aten.sin(x_1)
    sum_1 = torch.ops.aten.sum(sin, None);  sin = None
    cos = torch.ops.aten.cos(x_1);  x_1 = None
    _tensor_constant0 = self._tensor_constant0
    mul = torch.ops.aten.mul(_tensor_constant0, cos);  _tensor_constant0 = cos = None
    return mul

We can also try compiling it with NNC (even more experimental!).

>>> from functorch import nnc_jit
>>> jit_f = nnc_jit(grad(f))

Check examples/nnc for some example benchmarks.

Working with NN modules: make_functional and friends

Sometimes you may want to perform a transform with respect to the parameters and/or buffers of an nn.Module. This can happen for example in:

  • model ensembling, where all of your weights and buffers have an additional dimension
  • per-sample-gradient computation where you want to compute per-sample-grads of the loss with respect to the model parameters

Our solution to this right now is an API that, given an nn.Module, creates a stateless version of it that can be called like a function.

  • make_functional(model) returns a functional version of model and the model.parameters()
  • make_functional_with_buffers(model) returns a functional version of model and the model.parameters() and model.buffers().
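
As a sketch of the buffers variant, assuming functorch's make_functional_with_buffers returns (fmodel, params, buffers): the hypothetical model below includes a BatchNorm1d, whose running statistics are buffers rather than parameters. The functional call takes both explicitly:

```python
import torch
from functorch import make_functional_with_buffers

torch.manual_seed(0)
# Hypothetical model: BatchNorm1d owns running_mean/running_var buffers
model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.BatchNorm1d(8))
model.eval()  # use running statistics, so output does not depend on batch stats

fmodel, params, buffers = make_functional_with_buffers(model)
x = torch.randn(16, 4)
# Parameters and buffers are passed in as ordinary arguments
out = fmodel(params, buffers, x)
```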

Here's an example where we compute per-sample-gradients using an nn.Linear layer:

import torch
from functorch import make_functional, vmap, grad

model = torch.nn.Linear(3, 3)
data = torch.randn(64, 3)
targets = torch.randn(64, 3)

func_model, params = make_functional(model)

def compute_loss(params, data, targets):
    preds = func_model(params, data)
    return torch.mean((preds - targets) ** 2)

per_sample_grads = vmap(grad(compute_loss), in_dims=(None, 0, 0))(params, data, targets)

If you're making an ensemble of models, you may find combine_state_for_ensemble useful.
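
A minimal ensembling sketch follows. It assumes combine_state_for_ensemble takes a list of same-class models and returns (fmodel, params, buffers), with each parameter stacked along a new leading dimension. Four small nn.Linear models stand in for a real ensemble:

```python
import torch
from functorch import combine_state_for_ensemble, vmap

torch.manual_seed(0)
# Hypothetical ensemble: four independently initialized linear models
models = [torch.nn.Linear(5, 3) for _ in range(4)]
fmodel, params, buffers = combine_state_for_ensemble(models)

x = torch.randn(64, 5)  # one minibatch shared by every model
# Map over the stacked params/buffers (dim 0); x is broadcast (None)
preds = vmap(fmodel, in_dims=(0, 0, None))(params, buffers, x)
```

preds has shape (4, 64, 3): one prediction per model. To give each model its own minibatch instead, map x over dim 0 as well.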

Debugging

  • functorch._C.dump_tensor: Dumps dispatch keys on stack
  • functorch._C._set_vmap_fallback_warning_enabled(False) if the vmap warning spam bothers you.

Future Plans

In the end state, we'd like to upstream this into PyTorch once we iron out the design details. To figure out the details, we need your help -- please send us your use cases by starting a conversation in the issue tracker or by trying out the prototype.

License

Functorch has a BSD-style license, as found in the LICENSE file.

Citing functorch

If you use functorch in your publication, please cite it by using the following BibTeX entry.

@Misc{functorch2021,
  author =       {Horace He and Richard Zou},
  title =        {functorch: JAX-like composable function transforms for PyTorch},
  howpublished = {\url{https://github.com/facebookresearch/functorch}},
  year =         {2021}
}
Comments
  • ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv

    Hi All,

    I was running an older version of PyTorch ( - built from source) with FuncTorch ( - built from source), and somehow I've broken the older version of functorch. When I import functorch I get the following error,

    import functorch
    #returns ImportError: ~/.local/lib/python3.9/site-packages/functorch/_C.so: undefined symbol: _ZNK3c1010TensorImpl16sym_sizes_customEv
    

    The version I had of functorch was 0.2.0a0+9d6ee76. Is there a way to perhaps re-install it to fix this ImportError? I do have the latest version of PyTorch/FuncTorch in a separate conda environment, but I wanted to check how it compares to the older version in this 'older' conda environment, where the PyTorch and functorch versions were 1.12.0a0+git7c2103a and 0.2.0a0+9d6ee76, respectively.

    Is there a way to download a specific version of functorch with https://github.com/pytorch/functorch.git ? Or another way to fix this issue?

    opened by AlphaBetaGamma96 24
  • Hessian (w.r.t inputs) calculation in PyTorch differs from FuncTorch

    Hi All,

    I've been trying to calculate the Hessian of the output of my network with respect to its inputs within FuncTorch. I have a version within PyTorch that supports batches; however, the two disagree, and I have no idea why they don't give the same results. Something is clearly wrong. I know my PyTorch version is right, so either there's an issue in my version of FuncTorch or I've implemented it wrong in FuncTorch.

    Also, how can I use the has_aux flag in jacrev to return the jacobian from the first jacrev so I don't have to repeat the jacobian calculation?
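
    Here is a minimal sketch of the has_aux pattern. It assumes a functorch version whose jacrev accepts has_aux. The inner Jacobian is returned twice: once as the output for the outer jacrev to differentiate, and once as aux. It therefore comes back alongside the Hessian without a second computation. The toy function f is an assumption, not the network from this issue:

```python
import torch
from functorch import jacrev

def f(x):
    # Hypothetical scalar function standing in for logabs
    return x.sin().sum()

def jac_with_aux(x):
    j = jacrev(f)(x)
    # First element: differentiated by the outer jacrev. Second: passed through as aux
    return j, j

x = torch.randn(4)
hess, jac = jacrev(jac_with_aux, has_aux=True)(x)
```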

    The only problem with my example is that it uses torch.linalg.slogdet, and from what I remember FuncTorch can't vmap over .item(). I do have my own fork of PyTorch where I edited the backward to remove the .item() call so it works with vmap. It's not the greatest implementation, though, as I just set it to the default nonsingular_case_backward, like so:

    Tensor slogdet_backward(const Tensor& grad_logabsdet,
                            const Tensor& self,
                            const Tensor& signdet, const Tensor& logabsdet) {
      auto singular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
        Tensor u, sigma, vh;
        std::tie(u, sigma, vh) = at::linalg_svd(self, false);
        Tensor v = vh.mH();
        // sigma has all non-negative entries (also with at least one zero entry)
        // so logabsdet = \sum log(abs(sigma))
        // but det = 0, so backward logabsdet = \sum log(sigma)
        auto gsigma = grad_logabsdet.unsqueeze(-1).div(sigma);
        return svd_backward({}, gsigma, {}, u, sigma, vh);
      };
    
      auto nonsingular_case_backward = [&](const Tensor& grad_logabsdet, const Tensor& self) -> Tensor {
        // TODO: replace self.inverse with linalg_inverse
        return unsqueeze_multiple(grad_logabsdet, {-1, -2}, self.dim()) * self.inverse().mH();
      };
    
      auto nonsingular = nonsingular_case_backward(grad_logabsdet, self);
      return nonsingular;
    }
    

    My 'minimal' reproducible script is below with the output shown below that. It computes the Laplacian via a PyTorch method and via FuncTorch for a single sample of size [A,1] where A is the number of input nodes to the network.

    import torch
    import torch.nn as nn
    from torch import Tensor
    import functorch
    from functorch import jacrev, jacfwd, hessian, make_functional, vmap
    import time 
    
    _ = torch.manual_seed(0)
    
    print("PyTorch version:   ", torch.__version__)
    print("CUDA version:      ", torch.version.cuda)
    print("FuncTorch version: ", functorch.__version__)
    
    def sync_time() -> float:
      torch.cuda.synchronize()
      return time.perf_counter()
    
    B=1 #batch
    A=3 #input nodes
    
    device=torch.device("cuda")
    
    class model(nn.Module):
    
      def __init__(self, num_inputs, num_hidden):
        super(model, self).__init__()
        
        self.num_inputs=num_inputs
        self.func = nn.Tanh()
        
        self.fc1 = nn.Linear(2, num_hidden)
        self.fc2 = nn.Linear(num_hidden, num_inputs)
      
      def forward(self, x):
        """
        Takes x in [B,A,1] and maps it to sign/logabsdet value in Tuple([B,], [B,])
        """
        
        idx=len(x.shape)
        rep=[1 for _ in range(idx)]
        rep[-2] = self.num_inputs
        g = x.mean(dim=(idx-2), keepdim=True).repeat(*rep)
        f = torch.cat((x,g), dim=-1)
    
        h = self.func(self.fc1(f))
        
        mat = self.fc2(h)
        sgn, logabs = torch.linalg.slogdet(mat)
        return sgn, logabs
    
    net = model(A, 64)
    net = net.to(device)
    
    fnet, params = make_functional(net)
    
    def logabs(params, x):
      _, logabs = fnet(params, x)
      #print("functorch logabs: ",logabs)
      return logabs
    
    
    def kinetic_pytorch(xs: Tensor) -> Tensor:
      """Method to calculate the local kinetic energy values of a netork function, f, for samples, x.
      The values calculated here are 1/f d2f/dx2 which is equivalent to d2log(|f|)/dx2 + (dlog(|f|)/dx)^2
      within the log-domain (rather than the linear-domain).
    
      :param xs: The input positions of the many-body particles
      :type xs: class: `torch.Tensor`
      """
      xis = [xi.requires_grad_() for xi in xs.flatten(start_dim=1).t()]
      xs_flat = torch.stack(xis, dim=1)
    
      _, ys = net(xs_flat.view_as(xs))
      #print("pytorch logabs: ",ys)
      ones = torch.ones_like(ys)
    
      #df_dx calculation
      (dy_dxs, ) = torch.autograd.grad(ys, xs_flat, ones, retain_graph=True, create_graph=True)
    
    
      #d2f_dx2 calculation (diagonal only)
      lay_ys = sum(torch.autograd.grad(dy_dxi, xi, ones, retain_graph=True, create_graph=False)[0] \
                    for xi, dy_dxi in zip(xis, (dy_dxs[..., i] for i in range(len(xis))))
      )
      #print("(PyTorch): ",lay_ys, dy_dxs)
      
      ek_local_per_walker = -0.5 * (lay_ys + dy_dxs.pow(2).sum(-1)) #move const out of loop?
      return ek_local_per_walker
      
    jacjaclogabs = jacrev(jacrev(logabs, argnums=1), argnums=1)
    jaclogabs = jacrev(logabs, argnums=1)
      
    def kinetic_functorch(params, x):
      d2f_dx2 = vmap(jacjaclogabs, in_dims=(None, 0))(params, x)
      df_dx = vmap(jaclogabs, in_dims=(None, 0))(params, x)
      #print("(FuncTorch): ", d2f_dx2.squeeze(-3).squeeze(-1).diagonal(-2,-1).sum(-1), df_dx)
      #remove the trailing 1's so it's an A by A matrix 
      return -0.5 * d2f_dx2.squeeze(-3).squeeze(-1).diagonal(-2,-1).sum(-1) + df_dx.squeeze(-1).pow(2).sum(-1)
    
    x = torch.randn(B,A,1,device=device) #input Tensor 
    
    print("\nd2f/dx2, df/dx: ")
    t1=sync_time()
    kin_pt = kinetic_pytorch(x)
    t2=sync_time()
    t3=sync_time()
    kin_ft = kinetic_functorch(params, x)
    t4=sync_time()
    
    print("\nWalltime: ")
    print("PyTorch:   ",t2-t1)
    print("FuncTorch: ",t4-t3, "\n")
    
    print("Results: ")
    print("PyTorch: ",kin_pt)
    print("FuncTorch: ",kin_ft)
    

    This script returns

    PyTorch version:    1.12.0a0+git7c2103a
    CUDA version:       11.6
    FuncTorch version:  0.2.0a0+9d6ee76
    
    d2f/dx2, df/dx: 
    
    Walltime: 
    PyTorch:    0.4822753759999614
    FuncTorch:  0.004898710998531897 
    
    Results: 
    PyTorch:  tensor([1.3737], device='cuda:0', grad_fn=<MulBackward0>)    # should be the same values
    FuncTorch:  tensor([7.8411], device='cuda:0', grad_fn=<AddBackward0>) # the jacobian matches, but hessian does not
    

    Thanks for the help in advance! :)
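
    One difference between the two return statements is visible in the script itself. kinetic_pytorch returns -0.5 * (lay_ys + dy_dxs.pow(2).sum(-1)), so -0.5 scales both terms. kinetic_functorch returns -0.5 * (second-derivative term) + (first-derivative term), so -0.5 scales only the first. The sketch below swaps the network for a hypothetical toy, log|f(x)| = -|x|^2/2, whose answer is known in closed form. It shows the functorch computation with both terms inside the parentheses. It does not test whether the modified slogdet backward also contributes to the mismatch:

```python
import torch
from functorch import jacrev, vmap

def log_abs_f(x):
    # Hypothetical toy: log|f(x)| for f(x) = exp(-|x|^2 / 2), x of shape [A]
    return -0.5 * (x ** 2).sum()

def kinetic(x):
    grad_log = jacrev(log_abs_f)(x)          # d log|f| / dx, shape [A]
    hess_log = jacrev(jacrev(log_abs_f))(x)  # d2 log|f| / dx2, shape [A, A]
    laplacian = hess_log.diagonal().sum()
    # -0.5 multiplies BOTH terms, as in kinetic_pytorch
    return -0.5 * (laplacian + grad_log.pow(2).sum())

xs = torch.randn(4, 3)  # batch of 4 samples, A = 3
ek = vmap(kinetic)(xs)
# Closed form for the toy: -0.5 * (|x|^2 - A)
expected = -0.5 * ((xs ** 2).sum(-1) - xs.shape[-1])
```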

    opened by AlphaBetaGamma96 18
  • Semantic discrepancy on requires_grad after compiling Tensor.detach

    Reproduce:

    import torch
    from functorch.compile import aot_function
    
    def fn(x):
        return x.detach()
    
    aot_fn = aot_function(fn, fw_compiler=lambda fx_module, _: fx_module)
    
    x = torch.randn(1, requires_grad=True)
    ref = fn(x)
    res = aot_fn(x)
    
    assert(ref.requires_grad == res.requires_grad)
    

    PyTorch version: 1.13.0.dev20220929+cu116

    Not sure if this is related to #376.

    opened by sangongs 14
  • add batching rule for block_diag, kill DECOMPOSE_FUNCTIONAL

    Companion core PR: https://github.com/pytorch/pytorch/pull/77716

    The above PR makes block_diag composite compliant, and this PR adds a batching rule for it.

    Those two changes together should let us fully remove the DECOMPOSE_FUNCTIONAL macro, which was preventing me from moving the Functionalize dispatch key below FuncTorchBatched (which I want to do as part of XX, in order to properly get functionalization working with LTC/XLA).

    cla signed 
    opened by bdhirsh 13
  • svd-related op regression in functorch

    https://github.com/pytorch/pytorch/pull/69827 and https://github.com/pytorch/pytorch/pull/70253 caused svd-related tests in functorch to fail:

    • https://app.circleci.com/pipelines/github/pytorch/functorch/1277/workflows/5aaf2c43-6c6a-4ab1-94f7-e0493b8049ff/jobs/7659

    The main problem seems to be that the backward pass uses in-place operations that are incompatible with vmap (aka Composite Compliance problems). There are some other failures that seem to be because some other operations are not Composite Compliant but somehow these weren't a problem previously.
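
    Here is the in-place restriction in miniature. Under vmap, an in-place op whose self tensor is not batched cannot absorb a batched argument, while the out-of-place equivalent is fine. A minimal sketch; the function names are illustrative:

```python
import torch
from functorch import vmap

def inplace_version(x):
    out = torch.zeros(3)  # created inside func, so NOT batched
    out.add_(x)           # writes batched x into unbatched out: rejected by vmap
    return out

def out_of_place_version(x):
    return torch.zeros(3) + x  # a new (batched) tensor is produced instead

xs = torch.randn(4, 3)
try:
    vmap(inplace_version)(xs)
    inplace_raised = False
except RuntimeError:
    inplace_raised = True

ys = vmap(out_of_place_version)(xs)
```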

    opened by zou3519 12
  • Installing functorch breaks torchaudio

    I'm following along with this Colab from the functorch installation docs.

    After installing and restarting, when I try to import torchaudio, the runtime crashes. At first, I got this error:

    OSError: /usr/local/lib/python3.7/dist-packages/torchaudio/lib/libtorchaudio.so: undefined symbol: _ZN2at4_ops7resize_4callERKNS_6TensorEN3c108ArrayRefIlEENS5_8optionalINS5_12MemoryFormatEEE
    

    Now, I'm just getting the runtime crashing with no visible error.

    I know functorch was merged into pytorch proper, but I don't see any instructions about how to use it from there. Would that fix the issue? If so, should the main docs be updated?
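
    For reference, PyTorch 2.0 and later expose these transforms under torch.func, so no separate functorch install is needed there. Here is a sketch of the README's sanity check written against torch.func, assuming PyTorch 2.0+:

```python
import torch
from torch.func import grad, vmap  # available in PyTorch 2.0+

x = torch.randn(3)
y = vmap(torch.sin)(x)
dx = grad(lambda t: t.sin().sum())(x)
```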

    actionable 
    opened by dellis23 11
  • functorch doesn't work in debug mode

    It's that autograd assert that we run into often:

    import torch
    from functorch import make_fx
    from functorch.compile import nnc_jit
    
    
    def f(x, y):
        return torch.broadcast_tensors(x, y)
    
    
    inp1 = torch.rand(())
    inp2 = torch.rand(3)
    
    print(f(inp1, inp2))  # without nnc compile everything works fine
    
    print(make_fx(f)(inp1, inp2))  # fails
    print(nnc_jit(f)(inp1, inp2))
    # RuntimeError: self__storage_saved.value().is_alias_of(result.storage())INTERNAL ASSERT FAILED at "autograd/generated/VariableType_3.cpp":3899, please report a bug to PyTorch.
    

    cc @albanD @soulitzer what's the chance we can add an option to turn these off? They've been more harmful (e.g. prevent debugging in debug mode) than useful for us.

    opened by zou3519 11
  • Index put vmap internal assert

    import torch
    from functorch import vmap
    self = torch.randn(4, 1, 1).cuda()
    idx = (torch.tensor([0]).cuda(),)
    value = torch.randn(1, 1).cuda()
    
    def foo(x):
        return x.index_put_(idx, value, accumulate=True)
    
    vmap(foo)(self)
    
    RuntimeError: linearIndex.numel()*sliceSize*nElemBefore == value.numel()INTERNAL ASSERT FAILED at "/raid/rzou/pt/debug-cuda/aten/src/ATen/native/cuda/Indexing.cu":249, please report a bug to PyTorch. number of flattened indices did not match number of elements in the value tensor41
    
    actionable 
    opened by zou3519 11
  • Add flake8 pre commit hook script

    PyTorch has pre-commit hooks; the scripts they call are in tools.

    Really, I selfishly just want the flake8 ones so I don't have to remember to run it against my changes each time. We could also get the clang-tidy info while we're in there.

    actionable 
    opened by samdow 10
  • Batching rule not implemented for aten::item.

    Hey, I would like to use functorch.vmap in a custom PyTorch activation function (the gradients are not needed, because the backward-pass is calculated differently). During the computation of the activation function, I do a lookup in a tensor X using a tensor Y.item() call, similar to the small dummy code below.

    Unfortunately I get the error message: RuntimeError: Batching rule not implemented for aten::item. We could not generate a fallback.

    Is it not possible to do an item() call in a vmap function or is something else wrong? Thanks a lot!

    import torch
    from functorch import vmap
    
    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    
    sum = torch.zeros([100, 10], dtype=torch.int32).to(device)
    lookup = torch.randint(100, (20, 1000, 10)).to(device)
    input_tensor = torch.randint(1000, (100, 20)).to(device)
    
    def test_fun(sum, input_tensor):
      for j in range(20):
        for i in range(10):
          sum[i] += lookup[j, input_tensor[j].item(), i]
      return sum
    
    # non-vectorized version
    for i in range(100):
      test_fun(sum[i], input_tensor[i])
    
    # vectorized version throws error
    test_fun_vec = vmap(test_fun)
    test_fun_vec(sum, input_tensor)
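
    .item() turns a tensor into a Python number, and vmap has no way to batch that. For this particular lookup, one alternative is to express the gather with broadcast advanced indexing, which vectorizes over all rows without vmap at all. The sketch uses smaller hypothetical sizes and int64 sums, and checks the result against the original loop:

```python
import torch

torch.manual_seed(0)
n_rows, n_j, n_i, n_lookup = 8, 20, 10, 1000  # hypothetical, smaller row count
lookup = torch.randint(100, (n_j, n_lookup, n_i))
input_tensor = torch.randint(n_lookup, (n_rows, n_j))

# Reference: the original loop with .item(), one row at a time
ref = torch.zeros(n_rows, n_i, dtype=torch.int64)
for r in range(n_rows):
    for j in range(n_j):
        for i in range(n_i):
            ref[r, i] += lookup[j, input_tensor[r, j].item(), i]

# Vectorized: index tensors broadcast to [rows, J, I]
j_idx = torch.arange(n_j).view(1, n_j, 1)
i_idx = torch.arange(n_i).view(1, 1, n_i)
rows = input_tensor.unsqueeze(-1)        # [rows, J, 1]
gathered = lookup[j_idx, rows, i_idx]    # gathered[r, j, i] = lookup[j, input[r, j], i]
out = gathered.sum(dim=1)                # sum over j -> [rows, I]
```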
    
    opened by hallojs 10
  • torch.atleast_1d batching rule implementation

    Hi functorch devs! I'm filing this issue because my code prints the following warning:

    UserWarning: There is a performance drop because we have not yet implemented the batching rule for aten::atleast_1d. Please file us an issue on GitHub so that we can prioritize its implementation. (Triggered internally at  /tmp/pip-req-build-ytawxmfk/functorch/csrc/BatchedFallback.cpp:106.)
    

    Why am I using atleast_1d?

    I'm subclassing torch.Tensor because my code needs to attach some extra data, named _block_variable, to that class. (I'm integrating PyTorch's AD system with another AD system to be able to call torch functions from inside a PDE solve, which is why I also inherit from a class called OverloadedType.) For example, the subclass looks like:

    class MyTensor(torch.Tensor, OverloadedType):
        _block_variable = None
    
        @staticmethod
        def __new__(cls, x, *args, **kwargs):
            return super().__new__(cls, x, *args, **kwargs)
    
        def __init__(self, x, block_var=None):
            super(OverloadedType, self).__init__()
            self._block_variable = block_var or BlockVariable(self)
            
    
        def to(self, *args, **kwargs):
            new = Tensor([])
            tmp = super(torch.Tensor, self).to(*args, **kwargs)
            new.data = tmp.data
            new.requires_grad = tmp.requires_grad
            new._block_variable = self._block_variable
            return new
    
         ... #some subclass-specific methods etc
    

    This causes problems when I have code like torch.tensor([torch.trace(x), torch.trace(x @ x)]) where x is a square MyTensor: the torch.tensor() call raises an exception related to taking the __len__ of a 0-dimensional tensor (the scalar traces). So instead, I do torch.cat([torch.atleast_1d(torch.trace(x)), torch.atleast_1d(torch.trace(x @ x))]), which works. However, this function is functorch.vmap-ed, which triggers the performance warning. It would be great if I could either get the naive implementation (using torch.tensor instead of torch.cat) to work, or if a batching rule for atleast_1d() were implemented.

    Thank you for any help you can provide!
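
    For plain tensors, torch.stack accepts 0-dimensional inputs and returns a 1-D result. That lets it replace both torch.tensor([...]) and the atleast_1d/cat combination. Here is a sketch with ordinary tensors; the MyTensor subclass is not exercised here, so its behavior is an open question:

```python
import torch
from functorch import vmap

def traces(x):
    # torch.stack takes the two 0-dim traces directly, no atleast_1d needed
    return torch.stack([torch.trace(x), torch.trace(x @ x)])

xs = torch.randn(4, 3, 3)
out = vmap(traces)(xs)
ref = torch.stack([traces(x) for x in xs])  # unbatched reference
```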

    opened by DiffeoInvariant 10
  • Add pytorch 1.13.1 compatibility

    When torch 1.13.1 is installed on my machine and I try to pip install a package with functorch as a dependency, the functorch 1.13.0 package is found, and it requires downloading torch 1.13.0. Since functorch >= 1.13 is a dummy package, I'd guess all that's needed is an update on PyPI.

    opened by scal444 1
  • batching over model parameters

    I have a use case for functorch. I would like to check possible iterations of model parameters in a very efficient way (I want to eliminate the loop). Here's example code for a simplified case that I got working:

    linear = torch.nn.Linear(10,2)
    default_weight = linear.weight.data
    sample_input = torch.rand(3,10)
    sample_add = torch.rand_like(default_weight)
    def interpolate_weights(alpha):
        with torch.no_grad():
            res_weight = torch.nn.Parameter(default_weight + alpha*sample_add)
            linear.weight = res_weight
            return linear(sample_input)
    

    Now I could do for alpha in np.linspace(0.0, 1.0, 100), but I want to vectorise this loop since my code is prohibitively slow. Is functorch applicable here? Executing:

    alphas = torch.linspace(0.0, 1.0, 100)
    vmap(interpolate_weights)(alphas)
    

    works, but doing something similar for a simple ResNet does not. I've tried using load_state_dict, but that's not working:

    from torchvision import models
    model_resnet = models.resnet18(pretrained=True)
    
    named_params = list(model_resnet.named_parameters())
    named_params_data = [(n,p.data.clone()) for (n,p) in named_params]
    
    sample_data = torch.rand(10,3,224,244)
    
    def test_resnet(new_params):
        def interpolate(alpha):
            with torch.no_grad():
                p_dict = {name:(old + alpha*new_params[i]) for i,(name, old) in enumerate(named_params_data)}
                model_resnet.load_state_dict(p_dict, strict=False)
                out = model_resnet(sample_data)
                return out
        return interpolate
    
    rand_tensor = [torch.rand_like(p) for n,p in named_params_data]
    
    to_vamp_resnet = test_resnet(rand_tensor)
    vmap(to_vamp_resnet)(alphas)
    

    results in:

    While copying the parameter named "fc.bias", whose dimensions in the model are torch.Size([1000]) and whose dimensions in the checkpoint are torch.Size([1000]), an exception occurred : ('vmap: inplace arithmetic(self, *extra_args) is not possible because there exists a Tensor `other` in extra_args that has more elements than `self`. This happened due to `other` being vmapped over but `self` not being vmapped over in a vmap. Please try to use out-of-place operators instead of inplace arithmetic. If said operator is being called inside the PyTorch framework, please file a bug report instead.',).
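
    load_state_dict copies into the module's existing, unbatched parameters in place, and that is the kind of write vmap rejects. One way around it is to make the model functional and build the interpolated parameters out of place. Here is a minimal sketch for the nn.Linear case, assuming functorch's make_functional. directions is a hypothetical stand-in for sample_add, with one random tensor per parameter:

```python
import torch
from functorch import make_functional, vmap

torch.manual_seed(0)
linear = torch.nn.Linear(10, 2)
sample_input = torch.rand(3, 10)

fmodel, params = make_functional(linear)  # params: (weight, bias)
directions = tuple(torch.rand_like(p) for p in params)

def interpolate(alpha):
    # Out-of-place: new tensors, nothing is copied into the module
    new_params = tuple(p + alpha * d for p, d in zip(params, directions))
    return fmodel(new_params, sample_input)

alphas = torch.linspace(0.0, 1.0, 100)
with torch.no_grad():
    outs = vmap(interpolate)(alphas)  # shape [100, 3, 2]
```

    For a model with BatchNorm layers such as ResNet, the same pattern would presumably use make_functional_with_buffers so the buffers are passed in too. That variant is not shown here.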

    opened by LeanderK 2
  • Make vmap tests use dtype `any_one`

    In #1069, @kshitij12345 smartly pointed out that it's disturbing that these batch rules aren't caught by test_op_has_batch_rule. From looking at it, the bitwise ops in particular aren't being tested because the only allowed_dtype is torch.float.

    Steps

    1. First, please update both test_vmap and test_op_has_batch_rule to have their allowed_dtypes (in the @ops decorator) be OpDTypes.any_one instead of torch.float32.
    2. We expect this to lead to new failures. Please update the corresponding xfail list for the test.
       i. In the case of test_op_has_batch_rule, if the failure looks to occur on an in-place function, please first try adding it only to the inplace_failures list. If this does not work, you can xfail it.
    actionable 
    opened by samdow 0
  • [testing] Insufficient coverage in test suite

    In the functorch test suite, we use sample_inputs to get samples from an OpInfo. The problem is that sample_inputs may or may not cover all the cases/overloads for an operator. I think we should use reference_inputs, which is a superset of sample_inputs and more comprehensive (though this will increase the test times).

    Switching sample_inputs to reference_inputs leads to a bunch of failures for test_op_has_batch_rule, including the ones mentioned in https://github.com/pytorch/functorch/issues/1080 and https://github.com/pytorch/functorch/issues/1069.

    Refer to https://github.com/pytorch/pytorch/pull/91355 for failures.

    cc: @zou3519

    opened by kshitij12345 3
  • vmap + GRU

    Hi everyone, I was trying to retrieve per-sample gradients following the functorch documentation for a GRU-like model, but I get the following error:

    Traceback (most recent call last):
      File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 51, in <module>
        ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 362, in wrapped
        return _flat_vmap(
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 35, in fn
        return f(*args, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 489, in _flat_vmap
        batched_outputs = func(*batched_inputs, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\eager_transforms.py", line 1241, in wrapper
        results = grad_and_value(func, argnums, has_aux=has_aux)(*args, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\vmap.py", line 35, in fn
        return f(*args, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\eager_transforms.py", line 1111, in wrapper
        output = func(*args, **kwargs)
      File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 26, in compute_loss_stateless_model
        prediction = fmodel(params, buffers, sample.unsqueeze(1), state.unsqueeze(1))
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\functorch\_src\make_functional.py", line 282, in forward
        return self.stateless_model(*args, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "c:/Users/Ospite/Desktop/temp/funct/examples/functorch_gru.py", line 20, in forward
        x, _ = self.recurrent(x, hx)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\module.py", line 1190, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\Ospite\Desktop\temp\funct\.venv\lib\site-packages\torch\nn\modules\rnn.py", line 955, in forward
        result = _VF.gru(input, hx, self._flat_weights, self.bias, self.num_layers,
    RuntimeError: Batching rule not implemented for aten::unsafe_split.Tensor. We could not generate a fallback.
    

    A vanilla RNN works correctly. The code I've used is the following:

    from functools import partial
    from typing import Type, Union
    
    import torch
    from functorch import grad, make_functional_with_buffers, vmap
    
    
    class Recurrent(torch.nn.Module):
        def __init__(
            self,
            recurrent_layer: Union[Type[torch.nn.GRU], Type[torch.nn.RNN]],
            input_size: int,
            hidden_size: int,
            output_size: int,
        ) -> None:
            super().__init__()
            self.recurrent = recurrent_layer(input_size=input_size, hidden_size=hidden_size, batch_first=False)
            self.fc = torch.nn.Linear(hidden_size, output_size)
    
        def forward(self, x: torch.Tensor, hx: torch.Tensor) -> torch.Tensor:
            x, _ = self.recurrent(x, hx)
            x = self.fc(torch.relu(x))
            return x
    
    
    def compute_loss_stateless_model(fmodel, params, buffers, sample, target, state):
        prediction = fmodel(params, buffers, sample.unsqueeze(1), state.unsqueeze(1))
        loss = torch.nn.functional.mse_loss(prediction, target.unsqueeze(1))
        return loss
    
    
    if __name__ == "__main__":
        T, B, D, H, O = 128, 64, 64, 256, 1
        x = torch.rand(T, B, D)
        t = torch.ones(T, B, O)
        hx = torch.zeros(1, B, H)
        gru = Recurrent(torch.nn.GRU, D, H, O)
        rnn = Recurrent(torch.nn.RNN, D, H, O)
    
        # functional RNN + vmap
        frnn, params, buffers = make_functional_with_buffers(rnn)
        ft_compute_grad = grad(partial(compute_loss_stateless_model, frnn))
        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1, 1))
        ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
        for g in ft_sample_grads:
            print(g.shape)
    
        # functional GRU + vmap
        fgru, params, buffers = make_functional_with_buffers(gru)
        ft_compute_grad = grad(partial(compute_loss_stateless_model, fgru))
        ft_compute_sample_grad = vmap(ft_compute_grad, in_dims=(None, None, 1, 1, 1))
        ft_sample_grads = ft_compute_sample_grad(params, buffers, x, t, hx)
        for g in ft_sample_grads:
            print(g.shape)
    

    The collected environment is the following:

    PyTorch version: 1.13.0+cpu
    Is debug build: False
    CUDA used to build PyTorch: Could not collect
    ROCM used to build PyTorch: N/A
    
    OS: Microsoft Windows 10 Pro
    GCC version: Could not collect
    Clang version: Could not collect
    CMake version: Could not collect
    Libc version: N/A
    
    Python version: 3.8.10 (tags/v3.8.10:3d8993a, May  3 2021, 11:48:03) [MSC v.1928 64 bit (AMD64)] (64-bit runtime)
    Python platform: Windows-10-10.0.19044-SP0
    Is CUDA available: False
    CUDA runtime version: Could not collect
    CUDA_MODULE_LOADING set to: N/A
    GPU models and configuration: GPU 0: NVIDIA GeForce RTX 2070 SUPER
    Nvidia driver version: 516.94
    cuDNN version: Could not collect
    HIP runtime version: N/A
    MIOpen runtime version: N/A
    Is XNNPACK available: True
    
    Versions of relevant libraries:
    [pip3] functorch==1.13.0
    [pip3] mypy==0.931
    [pip3] mypy-extensions==0.4.3
    [pip3] numpy==1.23.5
    [pip3] pytorch-lightning==1.8.3.post1
    [pip3] torch==1.13.0
    [pip3] torchmetrics==0.11.0
    [conda] Could not collect
    

    Thank you, Federico

    opened by belerico 0
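    The traceback shows that the fused _VF.gru call inside nn.GRU reaches aten::unsafe_split, which had no batching rule in the reported setup (functorch 1.13). One possible workaround, sketched here assuming PyTorch 2.0 or later (functorch's transforms as torch.func), is to write the GRU recurrence out of ordinary operators (Linear, chunk, sigmoid, tanh) so that vmap never hits the fused kernel. ManualGRU and compute_loss are hypothetical names, the module uses its own parameter names rather than nn.GRU's weight_ih_l0 etc., and hx is given per-sample shape (H,) (batch dim 0) instead of (1, B, H).

```python
import torch
from torch.func import functional_call, grad, vmap


class ManualGRU(torch.nn.Module):
    # Hypothetical stand-in for Recurrent(torch.nn.GRU, ...): a single-layer GRU
    # built from basic ops instead of the fused _VF.gru kernel.
    def __init__(self, input_size: int, hidden_size: int, output_size: int) -> None:
        super().__init__()
        # Packed projections for the reset (r), update (z) and candidate (n) gates.
        self.ih = torch.nn.Linear(input_size, 3 * hidden_size)
        self.hh = torch.nn.Linear(hidden_size, 3 * hidden_size)
        self.fc = torch.nn.Linear(hidden_size, output_size)

    def forward(self, x: torch.Tensor, h: torch.Tensor) -> torch.Tensor:
        # Per-sample shapes: x is (T, input_size), h is (hidden_size,).
        outputs = []
        for x_t in x.unbind(0):
            i_r, i_z, i_n = self.ih(x_t).chunk(3, dim=-1)
            h_r, h_z, h_n = self.hh(h).chunk(3, dim=-1)
            r = torch.sigmoid(i_r + h_r)
            z = torch.sigmoid(i_z + h_z)
            n = torch.tanh(i_n + r * h_n)
            h = (1 - z) * n + z * h
            outputs.append(h)
        return self.fc(torch.relu(torch.stack(outputs)))


def compute_loss(params, buffers, sample, target, state):
    # Run the module-level `model` with the given parameters (stateless call).
    prediction = functional_call(model, (params, buffers), (sample, state))
    return torch.nn.functional.mse_loss(prediction, target)


T, B, D, H, O = 16, 8, 4, 32, 1
model = ManualGRU(D, H, O)
params = {k: v.detach() for k, v in model.named_parameters()}
buffers = {k: v.detach() for k, v in model.named_buffers()}

x = torch.rand(T, B, D)  # time-major, as in the report
t = torch.ones(T, B, O)
hx = torch.zeros(B, H)   # batch-major here, so its in_dim is 0

# Map over the batch dim of x and t (dim 1) and hx (dim 0); share params/buffers.
per_sample_grad = vmap(grad(compute_loss), in_dims=(None, None, 1, 1, 0))
grads = per_sample_grad(params, buffers, x, t, hx)
for name, g in grads.items():
    print(name, tuple(g.shape))
```

    Writing the cell by hand trades the fused kernel's speed for using only operators that vmap can batch; with a Python loop over time steps it is likely to be slower for long sequences.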
  • Add vmap support for PyTorch operators


    We're looking for more motivated open-source developers to help build out functorch (and PyTorch, since functorch is now just a part of PyTorch). Below is a selection of good first issues.

    • [x] https://github.com/pytorch/pytorch/issues/91174
    • [ ] https://github.com/pytorch/pytorch/issues/91175
    • [ ] https://github.com/pytorch/pytorch/issues/91176
    • [ ] https://github.com/pytorch/pytorch/issues/91177
    • [ ] https://github.com/pytorch/pytorch/issues/91402
    • [ ] https://github.com/pytorch/pytorch/issues/91403
    • [x] https://github.com/pytorch/pytorch/issues/91404
    • [ ] https://github.com/pytorch/pytorch/issues/91415
    • [ ] https://github.com/pytorch/pytorch/issues/91700

    In general, there's a high barrier to developing PyTorch and/or functorch. We've collected topics and information over at the PyTorch Developer Wiki.

    good first issue 
    opened by zou3519 2
Releases: v1.13.0
Owner: Facebook Research