Differentiable ODE solvers with full GPU support and O(1)-memory backpropagation.

Overview

PyTorch Implementation of Differentiable ODE Solvers

This library provides ordinary differential equation (ODE) solvers implemented in PyTorch. Backpropagation through ODE solutions is supported using the adjoint method for constant memory cost. For usage of ODE solvers in deep learning applications, see reference [1].

Because the solvers are implemented in PyTorch, all algorithms in this repository fully support running on the GPU.

Installation

To install the latest stable version:

pip install torchdiffeq

To install the latest version from GitHub:

pip install git+https://github.com/rtqichen/torchdiffeq

Examples

Examples are placed in the examples directory.

If you are interested in using this library, take a look at examples/ode_demo.py to see how to use torchdiffeq to fit a simple spiral ODE.

[ODE Demo animation]

Basic usage

This library provides one main interface, odeint, which contains general-purpose algorithms for solving initial value problems (IVPs), with gradients implemented for all main arguments. An initial value problem consists of an ODE and an initial value:

dy/dt = f(t, y)    y(t_0) = y_0.

The goal of an ODE solver is to find a continuous trajectory satisfying the ODE that passes through the initial condition.

To solve an IVP using the default solver:

from torchdiffeq import odeint

odeint(func, y0, t)

where func is any callable implementing the ordinary differential equation f(t, y), y0 is a Tensor of any dimensionality representing the initial value, and t is a 1-D Tensor containing the evaluation points. The initial time is taken to be t[0].
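
As a concrete, minimal sketch, the following solves the scalar ODE dy/dt = -y with the default solver; the solution at each requested time is returned stacked along a new leading dimension:

import torch
from torchdiffeq import odeint

def f(t, y):
    return -y  # dy/dt = -y, so y(t) = y0 * exp(-t)

y0 = torch.tensor([1.0])
t = torch.linspace(0., 2., 10)  # evaluation points; t[0] is the initial time
ys = odeint(f, y0, t)           # shape (len(t), *y0.shape)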

Backpropagation through odeint goes through the internals of the solver. Note that this is not numerically stable for all solvers (though it should probably be fine with the default dopri5 method). Instead, we encourage the use of the adjoint method explained in [1], which allows solving with as many steps as necessary thanks to its O(1) memory usage.

To use the adjoint method:

from torchdiffeq import odeint_adjoint as odeint

odeint(func, y0, t)

odeint_adjoint simply wraps around odeint, but will use only O(1) memory in exchange for solving an adjoint ODE in the backward call.

The biggest gotcha is that func must be an nn.Module when using the adjoint method; this is used to collect the parameters of the differential equation.
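
Putting this together, a minimal training sketch looks as follows (the linear dynamics here are purely illustrative):

import torch
import torch.nn as nn
from torchdiffeq import odeint_adjoint as odeint

class ODEFunc(nn.Module):
    def __init__(self):
        super().__init__()
        self.net = nn.Linear(2, 2)  # parameters are collected via nn.Module

    def forward(self, t, y):
        return self.net(y)

func = ODEFunc()
y0 = torch.randn(2)
t = torch.linspace(0., 1., 5)
ys = odeint(func, y0, t)    # forward solve with O(1) memory
loss = ys[-1].pow(2).sum()
loss.backward()             # the backward pass solves the adjoint ODE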

Differentiable event handling

We allow terminating an ODE solution based on an event function. Backpropagation through most solvers is supported. For usage of event handling in deep learning applications, see reference [2].

This can be invoked with odeint_event:

from torchdiffeq import odeint_event
odeint_event(func, y0, t0, *, event_fn, reverse_time=False, odeint_interface=odeint, **kwargs)
  • func and y0 are the same as odeint.
  • t0 is a scalar representing the initial time value.
  • event_fn(t, y) returns a tensor, and is a required keyword argument.
  • reverse_time is a boolean specifying whether we should solve in reverse time. Default is False.
  • odeint_interface is one of odeint or odeint_adjoint, specifying whether adjoint mode should be used for differentiating through the ODE solution. Default is odeint.
  • **kwargs: any remaining keyword arguments are passed to odeint_interface.

The solve is terminated at an event time t and state y when an element of event_fn(t, y) is equal to zero. Multiple outputs from event_fn can be used to specify multiple event functions, of which the first to trigger will terminate the solve.

Both the event time and final state are returned from odeint_event, and can be differentiated. Gradients will be backpropagated through the event function.

The numerical precision for the event time is determined by the atol argument.
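
For illustration, here is a hedged sketch of detecting the time at which an exponentially decaying state crosses 0.5 (the exact return shape is best checked against examples/bouncing_ball.py):

import torch
from torchdiffeq import odeint, odeint_event

def f(t, y):
    return -y

def event_fn(t, y):
    return y - 0.5  # triggers when an element hits zero

y0 = torch.tensor([1.0])
t0 = torch.tensor(0.0)
event_t, solution = odeint_event(f, y0, t0, event_fn=event_fn,
                                 odeint_interface=odeint, atol=1e-8)
y_final = solution[-1]  # state at the event time
# y(t) = exp(-t), so event_t should be close to ln(2) ≈ 0.693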

See examples/bouncing_ball.py for an example of simulating and differentiating through a bouncing ball.

[Bouncing ball animation]

Keyword arguments for odeint(_adjoint)

Keyword arguments:

  • rtol Relative tolerance.
  • atol Absolute tolerance.
  • method One of the solvers listed below.
  • options A dictionary of solver-specific options; see the further documentation.

List of ODE Solvers:

Adaptive-step:

  • dopri8 Runge-Kutta of order 8 of Dormand-Prince.
  • dopri5 Runge-Kutta of order 5 of Dormand-Prince [default].
  • bosh3 Runge-Kutta of order 3 of Bogacki-Shampine.
  • fehlberg2 Runge-Kutta-Fehlberg of order 2.
  • adaptive_heun Runge-Kutta of order 2.

Fixed-step:

  • euler Euler method.
  • midpoint Midpoint method.
  • rk4 Fourth-order Runge-Kutta with 3/8 rule.
  • explicit_adams Explicit Adams-Bashforth.
  • implicit_adams Implicit Adams-Bashforth-Moulton.

Additionally, all solvers available through SciPy are wrapped for use with scipy_solver.

For most problems, good choices are the default dopri5, or rk4 with options=dict(step_size=...) set appropriately small. Adjusting the tolerances (adaptive solvers) or the step size (fixed solvers) allows trading off speed against accuracy.
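
Illustrative sketches of both knobs, reusing func, y0, and t from the adjoint example above:

from torchdiffeq import odeint

# Adaptive solver: tighten tolerances for accuracy, loosen for speed.
ys = odeint(func, y0, t, rtol=1e-7, atol=1e-9)  # method defaults to dopri5

# Fixed-step solver: accuracy is governed by the step size.
ys = odeint(func, y0, t, method='rk4', options=dict(step_size=0.01))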

Frequently Asked Questions

See the FAQ for answers to common questions.

Further documentation

For details of the adjoint-specific and solver-specific options, check out the further documentation.

References

Applications of differentiable ODE solvers and event handling are discussed in these two papers:

[1] Ricky T. Q. Chen, Yulia Rubanova, Jesse Bettencourt, David Duvenaud. "Neural Ordinary Differential Equations." Advances in Neural Information Processing Systems. 2018. [arxiv]

[2] Ricky T. Q. Chen, Brandon Amos, Maximilian Nickel. "Learning Neural Event Functions for Ordinary Differential Equations." International Conference on Learning Representations. 2021. [arxiv]


If you found this library useful in your research, please consider citing:

@article{chen2018neuralode,
  title={Neural Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Rubanova, Yulia and Bettencourt, Jesse and Duvenaud, David},
  journal={Advances in Neural Information Processing Systems},
  year={2018}
}

@article{chen2021eventfn,
  title={Learning Neural Event Functions for Ordinary Differential Equations},
  author={Chen, Ricky T. Q. and Amos, Brandon and Nickel, Maximilian},
  journal={International Conference on Learning Representations},
  year={2021}
}

Comments
  • torch 1.7, discontinuities, and grid_points

    Torch 1.7 has now introduced torch.nextafter. This offers a way to remove the ugly eps hack we've been using with grid_points: specifically, replacing point + eps with torch.nextafter(point, inf) and point - eps with torch.nextafter(point, -inf).
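
    A minimal sketch of the replacement (floating-point tensors assumed):

    import torch

    point = torch.tensor(1.0)
    inf = torch.tensor(float('inf'))

    after = torch.nextafter(point, inf)    # smallest float strictly above point
    before = torch.nextafter(point, -inf)  # largest float strictly below point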

    Using this, an ideal API would remove eps entirely, and have only grid_points and d_discontinuities. The latter corresponds to those grid points that we need to perturb.

    Removing eps would introduce a small amount of backward incompatibility, and using nextafter would bump the dependency up to the most recent version of PyTorch.

    Are you interested in making this change? If so, I can probably offer a PR this weekend; moreover, it'd be worth doing before publicly announcing the new version.

    opened by patrick-kidger 17
  • Tried to convert the latent_ode code to run on one dimension of the spiral.

    Hi Ricky! This could be me doing a bad job of debugging my modifications, but it could also be an indicator that something funky is going on with the time ordering of the points in the latent ODE example. So please take a look if you get some free time. :-)

    I tried to use only the x-axis of your generated spirals and run latent_ode on that one-dimensional data, and then plot the learned x-coordinate against time.

    This is the ground-truth x-coordinate plotted against the time axis: [figure: ground_truth]

    This is the post-learning plot, again against time. It looks as if the network has learned the spiral the opposite way around. I am almost certain that this is not me plotting it backwards, because that would not change the behaviour around t = 0.

    [figure: vis]

    I hope you find some time to look at what's happening. Cheers!

    Below is the modified code:

    import argparse
    import logging
    import os
    import time
    import numpy as np
    import numpy.random as npr
    import matplotlib
    matplotlib.use('agg')
    import matplotlib.pyplot as plt
    import torch
    import torch.nn as nn
    import torch.optim as optim
    import torch.nn.functional as F
    
    
    parser = argparse.ArgumentParser()
    parser.add_argument('--adjoint', type=eval, default=False)
    parser.add_argument('--visualize', type=eval, default=False)
    parser.add_argument('--niters', type=int, default=2000)
    parser.add_argument('--lr', type=float, default=0.01)
    parser.add_argument('--gpu', type=int, default=0)
    parser.add_argument('--train_dir', type=str, default=None)
    args = parser.parse_args()
    
    if args.adjoint:
        from torchdiffeq import odeint_adjoint as odeint
    else:
        from torchdiffeq import odeint
    
    
    def generate_spiral2d(nspiral=1000,
                          ntotal=500,
                          nsample=100,
                          start=0.,
                          stop=1,  # approximately equal to 6pi
                          noise_std=0.1,
                          a=0.,
                          b=1.,
                          savefig=True):
        """Parametric formula for 2d spiral is `r = a + b * theta`.
    
        Args:
          nspiral: number of spirals, i.e. batch dimension
          ntotal: total number of datapoints per spiral
          nsample: number of sampled datapoints for model fitting per spiral
          start: spiral starting theta value
          stop: spiral ending theta value
          noise_std: observation noise standard deviation
          a, b: parameters of the Archimedean spiral
          savefig: plot the ground truth for sanity check
    
        Returns:
          Tuple where first element is true trajectory of size (nspiral, ntotal, 1),
          second element is noisy observations of size (nspiral, nsample, 1),
          third element is timestamps of size (ntotal,),
          fourth element is timestamps of size (nsample,),
          and fifth element is per-spiral sample timestamps of size (nspiral, nsample).
        """
    
        # offset the reversed timestamps by 1 below to avoid division by 0
        orig_ts = np.linspace(start, stop, num=ntotal)
        samp_ts = orig_ts[:nsample]
        sampl_ts = []
        # generate clock-wise and counter clock-wise spirals in observation space
        # with two sets of time-invariant latent dynamics
        zs_cw = stop + 1. - orig_ts
        rs_cw = a + b * 50. / zs_cw
        xs, ys = rs_cw * np.cos(zs_cw) - 5., rs_cw * np.sin(zs_cw)
        orig_traj_cw = np.stack((xs, ys), axis=1)
    
        zs_cc = orig_ts
        rw_cc = a + b * zs_cc
        xs, ys = rw_cc * np.cos(zs_cc) + 5., rw_cc * np.sin(zs_cc)
        orig_traj_cc = np.stack((xs, ys), axis=1)
    
        if savefig:
            plt.figure()
            plt.plot(range(orig_traj_cw[:,0].shape[0]), orig_traj_cw[:, 0],  label='clock')
            plt.plot(range(orig_traj_cc[:,0].shape[0]), orig_traj_cc[:, 0],  label='counter clock')
            plt.legend()
            plt.savefig('./ground_truth.png', dpi=500)
            print('Saved ground truth spiral at {}'.format('./ground_truth.png'))
    
        # sample starting timestamps
        orig_trajs = []
        samp_trajs = []
        for _ in range(nspiral):
            # don't sample t0 very near the start or the end
            t0_idx = npr.multinomial(
                1, [1. / (ntotal - 2. * nsample)] * (ntotal - int(2 * nsample)))
            t0_idx = np.argmax(t0_idx) + nsample
            sampl_ts.append(orig_ts[t0_idx:t0_idx+nsample])
            cc = bool(npr.rand() > .5)  # uniformly select rotation
            orig_traj = orig_traj_cc if cc else orig_traj_cw
            orig_trajs.append(orig_traj)
    
            samp_traj = orig_traj[t0_idx:t0_idx + nsample, :].copy()
            samp_traj += npr.randn(*samp_traj.shape) * noise_std
            samp_trajs.append(samp_traj)
        sampl_ts = np.array(sampl_ts)
        # batching for sample trajectories is good for RNN; batching for original
        # trajectories only for ease of indexing
        orig_trajs = np.stack(orig_trajs, axis=0)
        samp_trajs = np.stack(samp_trajs, axis=0)
    
        return orig_trajs[:,:,:1], samp_trajs[:,:,:1], orig_ts, samp_ts, sampl_ts
    
    
    class LatentODEfunc(nn.Module):
    
        def __init__(self, latent_dim=4, nhidden=20):
            super(LatentODEfunc, self).__init__()
            self.elu = nn.ELU(inplace=True)
            self.fc1 = nn.Linear(latent_dim, nhidden)
            self.fc2 = nn.Linear(nhidden, nhidden)
            self.fc3 = nn.Linear(nhidden, latent_dim)
            self.nfe = 0
    
        def forward(self, t, x):
            self.nfe += 1
            out = self.fc1(x)
            out = self.elu(out)
            out = self.fc2(out)
            out = self.elu(out)
            out = self.fc3(out)
            return out
    
    
    class RecognitionRNN(nn.Module):
    
        def __init__(self, latent_dim=4, obs_dim=2, nhidden=25, nbatch=1):
            super(RecognitionRNN, self).__init__()
            self.nhidden = nhidden
            self.nbatch = nbatch
            self.i2h = nn.Linear(obs_dim + nhidden, nhidden)
            self.h2o = nn.Linear(nhidden, latent_dim * 2)
    
        def forward(self, x, h):
            #print(x.size())
            combined = torch.cat((x, h), dim = 1)
            h = torch.tanh(self.i2h(combined))
            out = self.h2o(h)
            return out, h
    
        def initHidden(self):
            return torch.zeros(self.nbatch, self.nhidden)
    
    
    class Decoder(nn.Module):
    
        def __init__(self, latent_dim=4, obs_dim=2, nhidden=20):
            super(Decoder, self).__init__()
            self.relu = nn.ReLU(inplace=True)
            self.fc1 = nn.Linear(latent_dim, nhidden)
            self.fc2 = nn.Linear(nhidden, obs_dim)
    
        def forward(self, z):
            out = self.fc1(z)
            out = self.relu(out)
            out = self.fc2(out)
            return out
    
    
    class RunningAverageMeter(object):
        """Computes and stores the average and current value"""
    
        def __init__(self, momentum=0.99):
            self.momentum = momentum
            self.reset()
    
        def reset(self):
            self.val = None
            self.avg = 0
    
        def update(self, val):
            if self.val is None:
                self.avg = val
            else:
                self.avg = self.avg * self.momentum + val * (1 - self.momentum)
            self.val = val
    
    
    def log_normal_pdf(x, mean, logvar):
        const = torch.from_numpy(np.array([2. * np.pi])).float().to(x.device)
        const = torch.log(const)
        return -.5 * (const + logvar + (x - mean) ** 2. / torch.exp(logvar))
    
    
    def normal_kl(mu1, lv1, mu2, lv2):
        v1 = torch.exp(lv1)
        v2 = torch.exp(lv2)
        lstd1 = lv1 / 2.
        lstd2 = lv2 / 2.
    
        kl = lstd2 - lstd1 + ((v1 + (mu1 - mu2) ** 2.) / (2. * v2)) - .5
        return kl
    
    
    if __name__ == '__main__':
        latent_dim = 4
        nhidden = 20
        rnn_nhidden = 25
        obs_dim = 1
        nspiral = 1000
        start = 0.
        stop = 6 * np.pi
        noise_std = 0.1
        a = 0.
        b = .3
        ntotal = 500
        nsample = 50
        device = torch.device('cuda:' + str(args.gpu)
                              if torch.cuda.is_available() else 'cpu')
        print(device)
        # generate toy spiral data
        orig_trajs, samp_trajs, orig_ts, samp_ts,sampl_ts = generate_spiral2d(
            nspiral=nspiral,
            start=start,
            stop=stop,
            noise_std=noise_std,
            a=a, b=b
        )
        orig_trajs = torch.from_numpy(orig_trajs).float().to(device)
        samp_trajs = torch.from_numpy(samp_trajs).float().to(device)
        samp_ts = torch.from_numpy(samp_ts).float().to(device)
    
        # model
        func = LatentODEfunc(latent_dim, nhidden).to(device)
        rec = RecognitionRNN(latent_dim, obs_dim, rnn_nhidden, nspiral).to(device)
        dec = Decoder(latent_dim, obs_dim, nhidden).to(device)
        params = (list(func.parameters()) + list(dec.parameters()) + list(rec.parameters()))
        optimizer = optim.Adam(params, lr=args.lr)
        loss_meter = RunningAverageMeter()
        
        #out = func(samp_trajs)
        #make_dot(out)
        from torchviz import make_dot
        
        if args.train_dir is not None:
            if not os.path.exists(args.train_dir):
                os.makedirs(args.train_dir)
            ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
            if os.path.exists(ckpt_path):
                checkpoint = torch.load(ckpt_path)
                func.load_state_dict(checkpoint['func_state_dict'])
                rec.load_state_dict(checkpoint['rec_state_dict'])
                dec.load_state_dict(checkpoint['dec_state_dict'])
                optimizer.load_state_dict(checkpoint['optimizer_state_dict'])
                orig_trajs = checkpoint['orig_trajs']
                samp_trajs = checkpoint['samp_trajs']
                orig_ts = checkpoint['orig_ts']
                samp_ts = checkpoint['samp_ts']
                print('Loaded ckpt from {}'.format(ckpt_path))
    
        try:
            for itr in range(1, args.niters + 1):
                optimizer.zero_grad()
                # backward in time to infer q(z_0)
                h = rec.initHidden().to(device)
                for t in reversed(range(samp_trajs.size(1))):
                    obs = samp_trajs[:, t, :]
                    out, h = rec.forward(obs, h)
                qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
                epsilon = torch.randn(qz0_mean.size()).to(device)
                z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
                #make_dot(z0)
                # forward in time and solve ode for reconstructions
                pred_z = odeint(func, z0, samp_ts).permute(1, 0, 2)
                
                pred_x = dec(pred_z)
                #dot = make_dot(pred_x[0,:])
            
                #print(dot)
                #dot.format = 'png'
                #dot.render("arch")
                # compute loss
                noise_std_ = torch.zeros(pred_x.size()).to(device) + noise_std
                noise_logvar = 2. * torch.log(noise_std_).to(device)
                logpx = log_normal_pdf(
                    samp_trajs, pred_x, noise_logvar).sum(-1).sum(-1)
                pz0_mean = pz0_logvar = torch.zeros(z0.size()).to(device)
                analytic_kl = normal_kl(qz0_mean, qz0_logvar,
                                        pz0_mean, pz0_logvar).sum(-1)
                loss = torch.mean(-logpx + analytic_kl, dim=0)
                loss.backward()
                optimizer.step()
                loss_meter.update(loss.item())
    
                print('Iter: {}, running avg elbo: {:.4f}'.format(itr, -loss_meter.avg))
    
        except KeyboardInterrupt:
            if args.train_dir is not None:
                ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
                torch.save({
                    'func_state_dict': func.state_dict(),
                    'rec_state_dict': rec.state_dict(),
                    'dec_state_dict': dec.state_dict(),
                    'optimizer_state_dict': optimizer.state_dict(),
                    'orig_trajs': orig_trajs,
                    'samp_trajs': samp_trajs,
                    'orig_ts': orig_ts,
                    'samp_ts': samp_ts,
                }, ckpt_path)
                print('Stored ckpt at {}'.format(ckpt_path))
            print('Training complete after {} iters.'.format(itr))
        if args.train_dir is not None:
            ckpt_path = os.path.join(args.train_dir, 'ckpt.pth')
            torch.save({
                'func_state_dict': func.state_dict(),
                'rec_state_dict': rec.state_dict(),
                'dec_state_dict': dec.state_dict(),
                'optimizer_state_dict': optimizer.state_dict(),
                'orig_trajs': orig_trajs,
                'samp_trajs': samp_trajs,
                'orig_ts': orig_ts,
                'samp_ts': samp_ts,
            }, ckpt_path)
        
        if args.visualize:
            with torch.no_grad():
                # sample from trajectorys' approx. posterior
                h = rec.initHidden().to(device)
                for t in reversed(range(samp_trajs.size(1))):
                    obs = samp_trajs[:, t, :]
                    out, h = rec.forward(obs, h)
                qz0_mean, qz0_logvar = out[:, :latent_dim], out[:, latent_dim:]
                epsilon = torch.randn(qz0_mean.size()).to(device)
                z0 = epsilon * torch.exp(.5 * qz0_logvar) + qz0_mean
                orig_ts = torch.from_numpy(orig_ts).float().to(device)
    
                # take first trajectory for visualization
                z0 = z0[0]
    
                ts_pos = np.linspace(0., 4. * np.pi, num=4000)
                ts_neg = np.linspace(-np.pi, 0., num=2000)[::-1].copy()
                ts_pos = torch.from_numpy(ts_pos).float().to(device)
                ts_neg = torch.from_numpy(ts_neg).float().to(device)
    
                zs_pos = odeint(func, z0, ts_pos)
                zs_neg = odeint(func, z0, ts_neg)
    
                xs_pos = dec(zs_pos)
                xs_neg = torch.flip(dec(zs_neg), dims=[0])
    
            xs_pos = xs_pos.cpu().numpy()
            xs_neg = xs_neg.cpu().numpy()
            orig_traj = orig_trajs[0].cpu().numpy()
            samp_traj = samp_trajs[0].cpu().numpy()
            orig_ts = orig_ts.cpu().numpy()
            ts_pos = ts_pos.cpu().numpy()
            ts_neg = ts_neg.cpu().numpy()
    
            plt.figure()
            plt.plot(orig_ts, orig_traj[:, 0],'g', label='true trajectory')
            #print(orig_ts,samp_ts,ts_pos)
            plt.plot( ts_pos, xs_pos[:, 0],'r',
                     label='learned trajectory (t>0)')
            plt.plot(ts_neg, xs_neg[:, 0][::-1], 'c',
                     label='learned trajectory (t<0)')
            
            plt.scatter(sampl_ts[0,:],samp_traj[:, 0], 
                    label='sampled data', s=3)
            plt.legend()
            plt.savefig('./vis.png', dpi=500)
            print('Saved visualization figure at {}'.format('./vis.png'))
    
    opened by timkartar 16
  • Jumps and Callbacks

    Hey Ricky. Mostly changes as discussed.

    • Deprecated eps and grid_points in favour of step_locations and jump_locations as per #131.
    • #125:
      • Tidied up how adjoint_options is created: rather than being scattered through adjoint.py it's associated with the solver that actually uses the option.
      • Fixed bug with the options for method / adjoint_method not playing nice if method != adjoint_method.
      • Added convenience option for seminorms.
    • Added some (for now minimal) event handling. The framework is there so adding extra events should be straightforward.
    • SciPy solvers now throw an error if you try to use them with tensors that require gradients.
    • Bumped the version number.

    Unrelatedly, have you thought about getting added to https://pytorch.org/ecosystem/? (Not sure how much it actually means.)

    opened by patrick-kidger 14
  • OOM during backward pass on a model with ~600k parameters

    Hey Ricky,

    I'm running out of memory during the backward pass on a 16 GB GPU when running the adjoint method with rtol=1e-5, atol=1e-5, and a network with 631,058 parameters.

    I'm not sure why this happens, given that augmented_dynamics runs within torch.no_grad() and the tensors saved during the forward pass should not be that large.

    Any thoughts on what is happening and how to debug it?

    The model is a 3D U-Net (UNet) feeding into a few 3D convolutions (node_layers).

    Model(
      (prenet_layers): PrenetLayer(
        (initval_layers): Identity()
        (image_layers): Sequential(
          (0): Conv3d(3, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
          (1): UNet(
            (in_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
            (in_act): ReLU(inplace)
            (down1): down(
              (mpconv): Sequential(
                (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
                (1): double_conv(
                  (conv): Sequential(
                    (0): Conv3d(8, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): LeakyReLU(negative_slope=0.1, inplace)
                    (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (5): LeakyReLU(negative_slope=0.1, inplace)
                  )
                )
              )
            )
            (down2): down(
              (mpconv): Sequential(
                (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
                (1): double_conv(
                  (conv): Sequential(
                    (0): Conv3d(16, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): LeakyReLU(negative_slope=0.1, inplace)
                    (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (5): LeakyReLU(negative_slope=0.1, inplace)
                  )
                )
              )
            )
            (down3): down(
              (mpconv): Sequential(
                (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
                (1): double_conv(
                  (conv): Sequential(
                    (0): Conv3d(32, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): LeakyReLU(negative_slope=0.1, inplace)
                    (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (5): LeakyReLU(negative_slope=0.1, inplace)
                  )
                )
              )
            )
            (down4): down(
              (mpconv): Sequential(
                (0): MaxPool3d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
                (1): double_conv(
                  (conv): Sequential(
                    (0): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (1): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (2): LeakyReLU(negative_slope=0.1, inplace)
                    (3): Conv3d(64, 64, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                    (4): BatchNorm3d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                    (5): LeakyReLU(negative_slope=0.1, inplace)
                  )
                )
              )
            )
            (up0): up(
              (up): Upsample(scale_factor=2.0, mode=trilinear)
              (conv): double_conv(
                (conv): Sequential(
                  (0): Conv3d(128, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (1): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (2): LeakyReLU(negative_slope=0.1, inplace)
                  (3): Conv3d(32, 32, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (4): BatchNorm3d(32, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (5): LeakyReLU(negative_slope=0.1, inplace)
                )
              )
            )
            (up1): up(
              (up): Upsample(scale_factor=2.0, mode=trilinear)
              (conv): double_conv(
                (conv): Sequential(
                  (0): Conv3d(64, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (1): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (2): LeakyReLU(negative_slope=0.1, inplace)
                  (3): Conv3d(16, 16, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (4): BatchNorm3d(16, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (5): LeakyReLU(negative_slope=0.1, inplace)
                )
              )
            )
            (up2): up(
              (up): Upsample(scale_factor=2.0, mode=trilinear)
              (conv): double_conv(
                (conv): Sequential(
                  (0): Conv3d(32, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (2): LeakyReLU(negative_slope=0.1, inplace)
                  (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (5): LeakyReLU(negative_slope=0.1, inplace)
                )
              )
            )
            (up3): up(
              (up): Upsample(scale_factor=2.0, mode=trilinear)
              (conv): double_conv(
                (conv): Sequential(
                  (0): Conv3d(16, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (1): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (2): LeakyReLU(negative_slope=0.1, inplace)
                  (3): Conv3d(8, 8, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
                  (4): BatchNorm3d(8, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
                  (5): LeakyReLU(negative_slope=0.1, inplace)
                )
              )
            )
            (out): Conv3d(8, 8, kernel_size=(1, 1, 1), stride=(1, 1, 1))
            (out_norm): GroupNorm(8, 8, eps=1e-05, affine=True)
          )
        )
      )
      (node_layers): Sequential(
        (0): ODEBlock(
          (odefunc): ODEfunc(
            (tanh): Tanh()
            (conv_emb1): ConcatConv3d(
              (_layer): Conv3d(2, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            )
            (norm_emb1): GroupNorm(4, 4, eps=1e-05, affine=True)
            (conv_emb2): ConcatConv3d(
              (_layer): Conv3d(5, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            )
            (norm_emb2): GroupNorm(4, 4, eps=1e-05, affine=True)
            (norm_img_pre): GroupNorm(8, 8, eps=1e-05, affine=True)
            (conv_img): ConcatConv3d(
              (_layer): Conv3d(9, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            )
            (norm_img): GroupNorm(4, 4, eps=1e-05, affine=True)
            (conv1): ConcatConv3d(
              (_layer): Conv3d(9, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            )
            (norm_conv1): GroupNorm(4, 4, eps=1e-05, affine=True)
            (conv2): ConcatConv3d(
              (_layer): Conv3d(5, 4, kernel_size=(3, 3, 3), stride=(1, 1, 1), padding=(1, 1, 1))
            )
            (norm_conv2): GroupNorm(4, 4, eps=1e-05, affine=True)
            (conv_out): ConcatConv3d(
              (_layer): Conv3d(5, 1, kernel_size=(1, 1, 1), stride=(1, 1, 1))
            )
          )
        )
      )
      (postnet_layers): Sequential(
        (0): Identity()
      )
    )
    
    opened by rafaelvalle 14
  • Receptive field in CNN-based architecture, and more about the usage of conv blocks

    Thanks for the very interesting paper and the implemetation - fabulous work!

    I just have a question about the usage of CNN blocks in the model. Please correct me if I'm wrong, but it seems that a neural ODE with a single conv block leads to an infinite receptive field even for a very short integration time, whereas for discrete ODE-based designs such as Euler Net and Runge-Kutta Net the receptive fields are all finite and depend on the number of layers/blocks. If so, (1) it seems that a conv block degenerates into an FC layer applied to the entire (flattened) input, since the concept of a "receptive field" no longer holds in this case, and (2) it seems there is no need to use a larger/deeper model for a neural ODE if the target is to cover a large, possibly global, receptive field: a single block (perhaps together with a HyperNet) should be enough for everything. I'm not sure whether this still holds for larger models such as ResNet50/ResNet101 on larger datasets, but my intuition is that a single conv-block ODE might find it hard to match their performance (e.g., the comment in #32 about performance on CIFAR10). So I'm also wondering whether you have run any numerical experiments on larger datasets comparing neural ODEs with larger, and especially deeper, models.

    Thanks in advance!

    opened by yluo42 14
  • Increasing divergence and oscillation between steps

    Hey @rtqichen!

    While doing experiments on a model, I found that the L2 norm of the distance between subsequent states keeps increasing, with the relative tolerance set to 0 and for different values of the absolute tolerance.

    I understand that this distance is proportionally upper bounded by atol + rtol * norm(state), but I would not expect it to grow, because to me that suggests the solution is diverging or oscillating. For example, some state at step T is more similar to the state at step 0 than to a state at a step in (0, T).

    Any thoughts? I'm using the adaptive solver with RK45. Here are some plots showing the infinity norm of the state, the L2 norm of the distance between subsequent states, and the respective error threshold.

    Absolute tolerance 1e-5, relative tolerance 0: [plot atol1e-5rtol0]

    Absolute tolerance 1e-9, relative tolerance 0: [plot atol1e-9rtol0]

    Absolute tolerance 1e-10, relative tolerance 0: [plot atol1e-10rtol0]

    opened by rafaelvalle 13
  • Reverse integration

    Hello,

    Thanks for the package, it is great. I have a question regarding how to perform reverse integration, which is not quite clear for me. Here is an example:

    Assume I have time t = [0, T], and the initial point is z0 = f(x), where x is the input to the network. Using your package, I can integrate the ODE, which goes through z0, and find zT. Now, given zT, how can I integrate backwards to get the corresponding z0 (the same one used to derive zT)? A sketch of what I mean is below.
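
    Here is a rough sketch, with func, z0, and T as above; as far as I understand, passing a decreasing time tensor performs a reverse-time solve:

    import torch
    from torchdiffeq import odeint

    t_fwd = torch.tensor([0., T])
    zT = odeint(func, z0, t_fwd)[-1]      # forward solve: z0 -> zT

    t_bwd = torch.tensor([T, 0.])         # decreasing times: reverse-time solve
    z0_rec = odeint(func, zT, t_bwd)[-1]  # ≈ z0, up to numerical error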

    Jurijs

    opened by JurijsNazarovs 8
  • Gradients for Parameters of Coupled ODEs

    Hi!

    I need to compute gradients for the parameters of coupled ODEs. My equation is of the form dy/dt = y * x(t), where x(t) is itself the solution of an ODE. x(t) needs to be integrated over the time range before being used in y(t), and the parameters whose gradients I want are located inside x(t). I saw #129 and tried to follow the steps there, but the parameters return nan gradients. Can you please let me know whether this is possible using the package? A sketch of the kind of system I mean is below.
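
    For concreteness, a rough sketch of the coupled system, solved jointly as a tuple state (the dynamics and the parameter theta are placeholders):

    import torch
    import torch.nn as nn
    from torchdiffeq import odeint_adjoint as odeint

    class Coupled(nn.Module):
        def __init__(self):
            super().__init__()
            self.theta = nn.Parameter(torch.tensor(0.5))  # lives inside dx/dt

        def forward(self, t, state):
            x, y = state
            dx = -self.theta * x  # placeholder dynamics for x(t)
            dy = y * x            # dy/dt = y * x(t)
            return dx, dy

    func = Coupled()
    state0 = (torch.tensor([1.0]), torch.tensor([1.0]))
    t = torch.linspace(0., 1., 10)
    xs, ys = odeint(func, state0, t)
    ys[-1].sum().backward()  # gradient should reach func.theta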

    opened by skresearcher 8
  • The gradient of odeint_adjoint is zero with multiple GPUs

    I found that using exactly the same code, I got the following results:

    1. Single GPU: odeint and odeint_adjoint worked just fine.
    2. Multiple GPUs: odeint worked fine, but odeint_adjoint always resulted in zero gradients.
    3. Using the adjoint sensitivity in torchdyn, multiple GPUs work fine.

    My PyTorch version is 1.5.0, torchdiffeq version is 0.1.0, CUDA version is 10.0.130, and Python version is 3.7.7.

    I noticed that in your implementation of the adjoint method, you put the odeint call under torch.no_grad, while torchdyn does not.

    This is your code:

    class OdeintAdjointMethod(torch.autograd.Function):
    
        @staticmethod
        def forward(ctx, shapes, func, y0, t, rtol, atol, method, options, adjoint_rtol, adjoint_atol, adjoint_method,
                    adjoint_options, t_requires_grad, *adjoint_params):
    
            ctx.shapes = shapes
            ctx.func = func
            ctx.adjoint_rtol = adjoint_rtol
            ctx.adjoint_atol = adjoint_atol
            ctx.adjoint_method = adjoint_method
            ctx.adjoint_options = adjoint_options
            ctx.t_requires_grad = t_requires_grad
    
            with torch.no_grad():
                y = odeint(func, y0, t, rtol=rtol, atol=atol, method=method, options=options)
            ctx.save_for_backward(t, y, *adjoint_params)
            return y
    

    This is their code: (https://github.com/DiffEqML/torchdyn/blob/master/torchdyn/sensitivity/adjoint.py)

        def _define_autograd_adjoint(self):
            class autograd_adjoint(torch.autograd.Function):
                @staticmethod
                def forward(ctx, h0, flat_params, s_span):
                    sol = odeint(self.func, h0, self.s_span, rtol=self.rtol, atol=self.atol,
                                 method=self.method, options=self.options)
                    ctx.save_for_backward(self.s_span, self.flat_params, sol)
                    return sol[-1]
    
                @staticmethod
                def backward(ctx, *grad_output):
                    s, flat_params, sol = ctx.saved_tensors
                    self.f_params = tuple(self.func.parameters())
                    adj0 = self._init_adjoint_state(sol, grad_output)
                    adj_sol = odeint(self.adjoint_dynamics, adj0, self.s_span.flip(0),
                                   rtol=self.rtol, atol=self.atol, method=self.method, options=self.options)
                    λ = adj_sol[1]
                    μ = adj_sol[2]
                    return (λ, μ, None)
            return autograd_adjoint
    

    Also, I found that your FFJORD code also worked with single GPU but failed with multiple GPUs:

    [Screenshot 2020-09-11 11:37 AM: the multi-GPU failure]

    The running command is:

    export CUDA_VISIBLE_DEVICES=6,7
    python train_cnf.py --data mnist --dims 64,64,64 --strides 1,1,1,1 --num_blocks 2 --layer_type concat --multiscale True --rademacher True
    
    opened by AtlantixJJ 8
  • Compute loss on intermediate states (RK-45)

    Ricky,

    what's the proper way to compute loss on intermediate states of RK-45?

    I assume it would involve storing the intermediate states in the trajectory computed by the adaptive solver, rather than modifying the number of steps of the integration time. A sketch of an alternative I have considered follows below.

    This is in the context of a moving endpoint control problem: https://arxiv.org/abs/1805.07709
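
    For concreteness, something like the following, where the times at which the loss is evaluated are simply passed in t (func, y0, and targets assumed defined); this evaluates at requested times rather than at the solver's own accepted steps:

    import torch
    from torchdiffeq import odeint

    t = torch.linspace(0., 1., 20)       # times at which loss terms are wanted
    ys = odeint(func, y0, t)             # states at every requested time
    loss = ((ys - targets) ** 2).mean()  # loss over intermediate states
    loss.backward()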

    opened by rafaelvalle 8
  • MNIST: ODEBlock possibly redundant?

    Hello,

    Thank you for your work. It introduces a very interesting concept.

    I have a question regarding your experimental section that acts as verification of the ODE model for MNIST classification.

    Your ODE MNIST model in the paper is model = nn.Sequential(*downsampling_layers, *feature_layers, *fc_layers).to(device), with the ODE block in the middle as feature_layers and downsampling_method == conv. It has 208,266 parameters overall and achieves a test error of 0.42%.

    However, if you get rid of the middle block altogether and instead construct model = nn.Sequential(*downsampling_layers, *fc_layers).to(device), with downsampling_layers and fc_layers exactly as before, you get a model with 132,874 parameters that achieves a similar test error of under 0.6% after roughly 100 epochs.

    Could it be that your experiment shows the remarkable efficiency of your downsampling_layers rather than of the ODE block?

    Thanks,

    Simon

    opened by simonvary 8
  • Can you provide a detailed example or description of odeint_adjoint?

    Hi, I need to use odeint_adjoint to propagate the gradient of the loss back to a quantity in the neural network, but the documentation is not a good guide to using odeint_adjoint. Could you provide a detailed example or description? Thank you!

    opened by rid-sun 0
  • Can torchdiffeq provide support for stiff ordinary differential equations?

    I am currently working on tasks involving the derivation of chemical reaction equations, and I have found that torchdiffeq does not seem to provide solvers for stiff ordinary differential equations, such as MATLAB's ode23tb and ode23s. My program seems to perform poorly with fixed-step solvers. Is there a good alternative? The closest thing I have found so far is sketched below.
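
    For reference, the closest option I have found is routing the solve through the SciPy wrapper mentioned in the README, which exposes stiff-capable methods (though, if I understand correctly, without gradient support and on CPU only):

    from torchdiffeq import odeint

    ys = odeint(func, y0, t, method='scipy_solver',
                options={'solver': 'LSODA'})  # or 'BDF' / 'Radau'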

    opened by Shirley-YFY 0
  • ComplexFloat implementation

    Is there a way to use vectors and differential equations with complex coefficients? Thanks!

    import torch
    from torchdiffeq import odeint

    a = torch.FloatTensor([[1, 0], [0, 1]])
    y0 = torch.FloatTensor([1, 0])

    def f(t, y):
        return -1.j * torch.matmul(a, y)

    t_list = torch.linspace(0, 1, 11)
    odeint(f, y0, t_list)
    

    RuntimeError: "lt_cpu" not implemented for 'ComplexFloat'
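
    A workaround I am considering: stack the real and imaginary parts into a real-valued system, since dy/dt = -i*A*y with y = u + i*v is equivalent to du/dt = A*v, dv/dt = -A*u. A sketch:

    import torch
    from torchdiffeq import odeint

    a = torch.tensor([[1., 0.], [0., 1.]])

    def f_real(t, y):
        re, im = y[0], y[1]
        return torch.stack([a @ im, -(a @ re)])

    y0 = torch.stack([torch.tensor([1., 0.]),   # real part
                      torch.tensor([0., 0.])])  # imaginary part
    t_list = torch.linspace(0, 1, 11)
    ys = odeint(f_real, y0, t_list)  # ys[:, 0] = real, ys[:, 1] = imaginary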

    opened by itrosen 0
  • Question about latent_ode.py noise_std?

    On line 266, when computing the loss, why can we use noise_std directly? I mean, noise_std is a pre-defined parameter used to generate the original (true) trajectories; when we calculate the loss, we should not use any pre-defined true parameter. Is that correct?

    opened by Yunyi-learner 0
  • adjoint method breaks after reaching certain performance

    Hi,

    I've been using this library for an image-based flow estimation task. The way I do it is to use an ODE solver to solve for an evolving spatial transformation, passing an initial zero flow and some additional features to odeint/odeint_adjoint. I use a convolutional neural network. My code looks roughly like this:

    class Model(nn.Module):
        def __init__(self):
            super().__init__()
            self.encoder = ...  # an image feature extractor
            self.odefunc = CNN()

        def forward(self, x):
            x = self.encoder(x)
            ode_x = torch.cat([x, zero_flow], 1)
            ode_y = odeint_adj(self.odefunc, ode_x, t)  # t: integration times
            flow = ode_y[:, :, -flow_dim:]
            return flow

    class CNN(nn.Module):
        def forward(self, t, x):
            delta_flow = self.layers(x)  # run through all layers
            # zero out the feature channels to match the input dimension; features remain static
            return torch.cat([torch.zeros_like(x[:, :-flow_dim]), delta_flow], 1)
    

    My network converges with the non-adjoint method. But when I use the adjoint method, the model converges initially, and then the losses always explode after reaching a certain performance (the network produces quite accurate results before it breaks). The loss functions I use are a typical image similarity loss and a regularization loss. I use a fixed-step Euler solver. Do you know what the reasons could be? I highly appreciate any suggestions!

    opened by Lancial 0