PyTorch implementation of i-RevNets.

i-RevNets define a family of fully invertible deep networks, built from a succession of homeomorphic layers.

Reference: Jörn-Henrik Jacobsen, Arnold Smeulders, Edouard Oyallon. i-RevNet: Deep Invertible Networks. International Conference on Learning Representations (ICLR), 2018. (https://iclr.cc/)


The i-RevNet and its dual. The inverse can be obtained from the forward model with minimal adaptation and is an i-RevNet as well. Read the paper for the theoretical background and a detailed analysis of the trained models.
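The sketch below shows why the inverse needs so little adaptation. It is a single stride-1 coupling block modelled on the repository's irevnet_block, which returns (x2, F(x2) + x1). The class name AdditiveBlock and the small conv stack used for F are illustrative assumptions, not the repository's code. The key point is that F is only ever evaluated, never inverted:

```python
import torch
import torch.nn as nn


class AdditiveBlock(nn.Module):
    """Toy stride-1 block in the style of i-RevNet: (x1, x2) -> (x2, x1 + F(x2))."""

    def __init__(self, channels: int):
        super().__init__()
        # F is an arbitrary, non-invertible residual function (assumed small conv stack).
        self.F = nn.Sequential(
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.Conv2d(channels, channels, kernel_size=3, padding=1),
        )

    def forward(self, x1, x2):
        # Only x1 is modified, by an amount that depends on x2 alone.
        return x2, x1 + self.F(x2)

    def inverse(self, y1, y2):
        # y1 is x2 unchanged, so F(x2) can be recomputed and subtracted.
        x2 = y1
        x1 = y2 - self.F(x2)
        return x1, x2


torch.manual_seed(0)
block = AdditiveBlock(channels=4)
x1, x2 = torch.randn(2, 4, 8, 8), torch.randn(2, 4, 8, 8)
with torch.no_grad():
    y1, y2 = block(x1, x2)
    r1, r2 = block.inverse(y1, y2)
print(torch.allclose(r1, x1, atol=1e-6), torch.equal(r2, x2))  # True True
```

Because F is never inverted, it can be any network. In the stride-2 case, the repository's block additionally applies the invertible downsampling psi to both halves.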

PyTorch i-RevNet Usage

Requirements: Python 3, NumPy, PyTorch, Torchvision

Download the ImageNet dataset and move validation images to labeled subfolders. To do this, you can use the following script: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh

We provide an ImageNet pre-trained model: Download
Save it to this folder.

Train a small i-RevNet on CIFAR-10 (takes about 5 hours and yields an accuracy of ~94.5%):

$ python CIFAR_main.py --nBlocks 18 18 18 --nStrides 1 2 2 --nChannels 16 64 256

Train a bijective i-RevNet on ImageNet (takes 7-10 days and yields a top-1 accuracy of ~74%):

$ python ILSVRC_main.py --data /path/to/ILSVRC2012/ --nBlocks 6 16 72 6 --nStrides 2 2 2 2 --nChannels 24 96 384 1536 --init_ds 2

Evaluate the pre-trained model on the ImageNet validation set (yields 74.018% top-1 accuracy):

$ bash scripts/evaluate_ilsvrc-2012.sh

Invert the output of the last layer on the ImageNet validation set and save example images:

$ bash scripts/invert_ilsvrc-2012.sh

Imagenet ILSVRC-2012 Results

i-RevNets perform on par with baseline RevNet and ResNet.

Model           Val Top-1 Error (%)
ResNet          24.7
RevNet          25.2
i-RevNet (a)    24.7
i-RevNet (b)    26.0

Reconstructions from the ILSVRC-2012 validation set. Top row: original images; bottom row: reconstructions from the final representation.



Contributions are very welcome.


title={i-RevNet: Deep Invertible Networks},
author={Jörn-Henrik Jacobsen and Arnold W.M. Smeulders and Edouard Oyallon},
booktitle={International Conference on Learning Representations},
  • Consider more efficient implementation of class psi

    I found that replacing the original implementation of models.model_utils.psi with the following implementation gave me a large speed-up in both forward() and inverse(), on both the GPU and the CPU (roughly 4-7x on the CPU and 7-50x on the GPU in the timings below):

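    from models.model_utils import psi

    class psi_suggested(psi):
        """Drop-in replacement for psi using a single view/permute/reshape."""
        def inverse(self, inpt):
            bl, bl_sq = self.block_size, self.block_size_sq
            bs, new_d, h, w = inpt.shape[0], inpt.shape[1] // bl_sq, inpt.shape[2], inpt.shape[3]
            # Unpack channels into (row offset, column offset, depth) and move the offsets back into space
            return inpt.view(bs, bl, bl, new_d, h, w).permute(0, 3, 4, 1, 5, 2).reshape(bs, new_d, h * bl, w * bl)
        def forward(self, inpt):
            bl, bl_sq = self.block_size, self.block_size_sq
            bs, d, new_h, new_w = inpt.shape[0], inpt.shape[1], inpt.shape[2] // bl, inpt.shape[3] // bl
            # Cut each spatial dim into (blocks, offset) and move both offsets in front of the depth
            return inpt.view(bs, d, new_h, bl, new_w, bl).permute(0, 3, 5, 1, 2, 4).reshape(bs, d * bl_sq, new_h, new_w)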

    I timed it as follows:

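    import timeit
    import torch
    from models.model_utils import psi  # psi_suggested is defined above

    block_size = 3  # consistent with the shapes below: 5 -> 45 channels, 192 -> 64 pixels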
    device = torch.device("cpu")
    # Forward
    t = torch.randn(64, 5, 192, 192, dtype=torch.float32).to(device)
    psi_instance = psi(block_size)
    psi_callable = lambda: psi_instance.forward(t)
    psi_suggested_instance = psi_suggested(block_size)
    psi_suggested_callable = lambda: psi_suggested_instance.forward(t)
    print("Timing forward, suggested psi:", timeit.Timer(psi_suggested_callable).timeit(100))
    print("Timing forward, original psi:", timeit.Timer(psi_callable).timeit(100))
    print("Same result in forward?", (psi_callable() == psi_suggested_callable()).all().item())
    # Inverse
    t = torch.randn(64, 45, 64, 64, dtype=torch.float32).to(device)
    psi_instance = psi(block_size)
    psi_callable = lambda: psi_instance.inverse(t)
    psi_suggested_instance = psi_suggested(block_size)
    psi_suggested_callable = lambda: psi_suggested_instance.inverse(t)
    print("Timing inverse, suggested psi:", timeit.Timer(psi_suggested_callable).timeit(100))
    print("Timing inverse, original psi:", timeit.Timer(psi_callable).timeit(100))
    print("Same result in inverse?", (psi_callable() == psi_suggested_callable()).all().item())

    On the CPU, this gave me (total seconds for 100 calls each):

    Timing forward, suggested psi: 1.7428924000000734
    Timing forward, original psi: 11.567388500000106
    Same result in forward? True
    Timing inverse, suggested psi: 2.7421231000000716
    Timing inverse, original psi: 10.155058100000133
    Same result in inverse? True

    And on the GPU:

    Timing forward, suggested psi: 0.010811799999828509
    Timing forward, original psi: 0.5521851000000879
    Same result in forward? True
    Timing inverse, suggested psi: 0.010437099999990096
    Timing inverse, original psi: 0.07210720000011861
    Same result in inverse? True

    If you can reproduce these results, you might consider reimplementing psi.
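psi is a space-to-depth rearrangement. As a cross-check on the suggested view/permute version, the sketch below shows that the same channel ordering can be obtained from the built-ins plus one channel regrouping, and that the inverse round-trips exactly. It assumes PyTorch >= 1.8, which provides torch.nn.functional.pixel_unshuffle and pixel_shuffle. The function names are hypothetical helpers, not part of the repository:

```python
import torch
import torch.nn.functional as F


def psi_forward_reshape(x: torch.Tensor, bl: int) -> torch.Tensor:
    # Same view/permute as psi_suggested.forward above: channel index = (i*bl + j)*d + c
    bs, d, h, w = x.shape
    return (x.view(bs, d, h // bl, bl, w // bl, bl)
             .permute(0, 3, 5, 1, 2, 4)
             .reshape(bs, d * bl * bl, h // bl, w // bl))


def psi_forward_builtin(x: torch.Tensor, bl: int) -> torch.Tensor:
    bs, d, h, w = x.shape
    # pixel_unshuffle orders output channels as c*bl*bl + (i*bl + j) ...
    y = F.pixel_unshuffle(x, bl)
    # ... so regroup (c, offset) -> (offset, c) to match psi's ordering.
    y = y.view(bs, d, bl * bl, h // bl, w // bl).transpose(1, 2)
    return y.reshape(bs, d * bl * bl, h // bl, w // bl)


def psi_inverse_builtin(y: torch.Tensor, bl: int) -> torch.Tensor:
    bs, c, h, w = y.shape
    d = c // (bl * bl)
    # Undo the channel regrouping, then let pixel_shuffle restore the spatial layout.
    y = y.view(bs, bl * bl, d, h, w).transpose(1, 2).reshape(bs, c, h, w)
    return F.pixel_shuffle(y, bl)


x = torch.randn(2, 5, 12, 12)
y = psi_forward_builtin(x, 3)
# prints: torch.Size([2, 45, 4, 4]) True True
print(y.shape, torch.equal(y, psi_forward_reshape(x, 3)), torch.equal(psi_inverse_builtin(y, 3), x))
```

These are pure data movements, so the results match bit for bit and torch.equal is the right check. When timing on a GPU, keep in mind that CUDA kernels launch asynchronously. Calling torch.cuda.synchronize() inside the timed callable makes the timings include the actual kernel execution.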

  • iRevNet.py's test code doesn't work

    I tried running the script you provide in models.iRevNet.py:

    model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                        nChannels=None, nClasses=1000, init_ds=2,
                        dropout_rate=0., affineBN=True, in_shape=[3, 224, 224],
                        mult=4)
    y = model(Variable(torch.randn(1, 3, 224, 224)))

    However, this seems to raise an error:

     == Building iRevNet 301 == 
    Traceback (most recent call last):
      File "iRevNet.py", line 158, in <module>
        y = model(Variable(torch.randn(1, 3, 224, 224)))
      File "/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "iRevNet.py", line 132, in forward
        out = block.forward(out)
      File "iRevNet.py", line 61, in forward
        y1 = Fx2 + x1
    RuntimeError: The size of tensor a (6) must match the size of tensor b (24) at non-singleton dimension 1

    Am I doing something wrong?


  • Pretrained models cannot be loaded

    Hi, I have a question about the pretrained models.

    Does it have the same architecture as the ILSVRC example? I can't load it into your example model:

    model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], nChannels=[24, 96, 384, 1536], nClasses=1000, init_ds=2, dropout_rate=0., affineBN=True, in_shape=[3, 224, 224], mult=4)

    Also, when I checked the "arch" key in your saved state, it says "resnet18". If the architecture is different, could you explain it so that I can use your pretrained model?

    Thanks for your help.
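One way to investigate is to inspect the checkpoint and let load_state_dict report which keys do not line up. The sketch below assumes the file is a dict with an "arch" entry, as described above. It also assumes, hypothetically, that the weights sit under a "state_dict" key. The helper name inspect_checkpoint is made up. Stripping a "module." prefix is a guess that only matters if the model was saved from nn.DataParallel. Note that parameter shape mismatches still raise a RuntimeError even with strict=False:

```python
import torch
import torch.nn as nn


def inspect_checkpoint(path: str, model: nn.Module):
    """Print top-level checkpoint keys and report how its weights match a model."""
    ckpt = torch.load(path, map_location="cpu")
    print("top-level keys:", sorted(ckpt.keys()))
    print("arch:", ckpt.get("arch"))
    # Assumption: weights live under "state_dict"; otherwise use the dict itself.
    state = ckpt.get("state_dict", ckpt)
    # Weights saved from nn.DataParallel carry a "module." prefix; strip it if present.
    state = {k[len("module."):] if k.startswith("module.") else k: v
             for k, v in state.items()}
    result = model.load_state_dict(state, strict=False)
    print("missing keys:", len(result.missing_keys))
    print("unexpected keys:", len(result.unexpected_keys))
    return result
```

A long list of missing or unexpected keys usually means the constructor arguments differ from the ones used for training.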

  • About model size of i-RevNet (b)

    As shown in Table 1 of your ICLR 2018 paper, i-RevNet (b) has 29M parameters. However, when I computed the model size of the released iRevNet.py from this repository, I got a different result: about 125.12MB, which is significantly larger than expected. I used the tool from https://github.com/Lyken17/pytorch-OpCounter to compute the model size. Could you help me clear up this discrepancy?

    Best Regards, Jiajun Deng
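When comparing against a reported parameter count, it helps to compute the count and the storage size separately. A count in millions and a size in megabytes differ by the bytes per element: 4 for float32, so 29M float32 parameters take roughly 116 MB. The sketch below uses a hypothetical helper, param_stats. Note that model.parameters() excludes buffers such as BatchNorm running statistics:

```python
import torch.nn as nn


def param_stats(model: nn.Module):
    """Return (number of parameters, parameter storage in MiB)."""
    n_params = sum(p.numel() for p in model.parameters())
    # element_size() is the number of bytes per element (4 for float32).
    n_bytes = sum(p.numel() * p.element_size() for p in model.parameters())
    return n_params, n_bytes / 2**20


count, mib = param_stats(nn.Linear(1000, 1000))
print(count, round(mib, 2))  # 1001000 3.82
```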

  • is it fully invertible? last layer is pooling+ linear

    It seems that the network is only "90% invertible": the last layer relies on pooling and a linear layer. Could we replace the pooling+ReLU+linear layers with more reversible downsampling and a 1x1 reversible convolution? @jhjacobsen, did you test this?
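For reference, a common invertible channel-mixing layer is the invertible 1x1 convolution from Glow (Kingma & Dhariwal, 2018). It is not part of this repository; the sketch below only illustrates the kind of layer the question refers to. It stays invertible only as long as the learned weight matrix remains non-singular. On its own it keeps the dimensionality unchanged, so producing class logits would still need some non-invertible step or a fixed split of the output:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class Invertible1x1Conv(nn.Module):
    """Glow-style invertible 1x1 convolution (illustrative, not from this repo)."""

    def __init__(self, channels: int):
        super().__init__()
        # A random orthogonal matrix is a well-conditioned, invertible starting weight.
        w, _ = torch.linalg.qr(torch.randn(channels, channels))
        self.weight = nn.Parameter(w)

    def forward(self, x):
        c = self.weight.shape[0]
        # A 1x1 convolution applies the same channel-mixing matrix at every pixel.
        return F.conv2d(x, self.weight.view(c, c, 1, 1))

    def inverse(self, y):
        c = self.weight.shape[0]
        w_inv = torch.inverse(self.weight)
        return F.conv2d(y, w_inv.view(c, c, 1, 1))


torch.manual_seed(0)
layer = Invertible1x1Conv(6)
x = torch.randn(2, 6, 5, 5)
with torch.no_grad():
    print(torch.allclose(layer.inverse(layer(x)), x, atol=1e-5))  # True
```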

  • Cifar exception in tensor size at beginning of training

    In utils_cifar.py line 105 (and 133): train_loss += loss.data[0]

    gives an error: IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
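As the error message itself suggests, since PyTorch 0.4 the loss is a 0-dim tensor, and loss.item() replaces loss.data[0]. A minimal sketch of the changed accumulation line, with made-up logits and targets standing in for the ones in utils_cifar.py:

```python
import torch
import torch.nn as nn

criterion = nn.CrossEntropyLoss()
logits = torch.randn(8, 10, requires_grad=True)  # stand-in for model outputs
targets = torch.randint(0, 10, (8,))             # stand-in for labels

train_loss = 0.0
loss = criterion(logits, targets)  # 0-dim tensor
loss.backward()
# loss.data[0] raises IndexError on a 0-dim tensor; .item() returns a Python float.
train_loss += loss.item()
```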

  • Hi, I don't understand permute function

    Hello, @jhjacobsen

    I am a newcomer to deep learning. I read your source code and there are a few things I don’t understand. Why is the permute function used so often?

    import torch.nn as nn

    class injective_pad(nn.Module):
        def __init__(self, pad_size):
            super(injective_pad, self).__init__()
            self.pad_size = pad_size
            self.pad = nn.ZeroPad2d((0, 0, 0, pad_size))
        def forward(self, x):
            x = x.permute(0, 2, 1, 3)
            x = self.pad(x)
            return x.permute(0, 2, 1, 3)
        def inverse(self, x):
            return x[:, :x.size(1) - self.pad_size, :, :]

    And this,

    import torch
    import torch.nn as nn

    class psi(nn.Module):
        def __init__(self, block_size):
            super(psi, self).__init__()
            self.block_size = block_size
            self.block_size_sq = block_size*block_size
        def inverse(self, input):
            output = input.permute(0, 2, 3, 1)
            (batch_size, d_height, d_width, d_depth) = output.size()
            s_depth = int(d_depth / self.block_size_sq)
            s_width = int(d_width * self.block_size)
            s_height = int(d_height * self.block_size)
            t_1 = output.contiguous().view(batch_size, d_height, d_width, self.block_size_sq, s_depth)
            spl = t_1.split(self.block_size, 3)
            stack = [t_t.contiguous().view(batch_size, d_height, s_width, s_depth) for t_t in spl]
            output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).contiguous().view(batch_size, s_height, s_width, s_depth)
            output = output.permute(0, 3, 1, 2)
            return output.contiguous()
        def forward(self, input):
            output = input.permute(0, 2, 3, 1)
            (batch_size, s_height, s_width, s_depth) = output.size()
            d_depth = s_depth * self.block_size_sq
            d_height = int(s_height / self.block_size)
            t_1 = output.split(self.block_size, 2)
            stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
            output = torch.stack(stack, 1)
            output = output.permute(0, 2, 1, 3)
            output = output.permute(0, 3, 1, 2)
            return output.contiguous()

    What is the role of these two classes? Thank you.
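Both classes only rearrange data along the channel dimension. psi is the invertible downsampling used by i-RevNet. Its forward pass moves each block_size x block_size spatial patch into the channel dimension (space-to-depth), and its inverse undoes this. injective_pad appends pad_size zero channels. Its permute calls exist because nn.ZeroPad2d pads only the last two dimensions, so the channels are temporarily swapped into the height position. The minimal sketch below uses illustrative variable names. It checks that this equals padding dim 1 directly with F.pad, and that slicing off the extra channels, as injective_pad.inverse does, recovers the input:

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

pad_size = 3
x = torch.randn(2, 4, 5, 5)  # (batch, channels, height, width)

# What injective_pad.forward does: swap channels into the "height" slot so that
# ZeroPad2d's bottom padding (0, 0, 0, pad_size) appends zero channels.
padded_via_permute = nn.ZeroPad2d((0, 0, 0, pad_size))(x.permute(0, 2, 1, 3)).permute(0, 2, 1, 3)

# Direct equivalent: F.pad lists padding from the last dim backwards, so the
# third pair (0, pad_size) applies to dim 1, the channel dimension.
padded_direct = F.pad(x, (0, 0, 0, 0, 0, pad_size))

# injective_pad.inverse: drop the appended channels.
recovered = padded_direct[:, :padded_direct.size(1) - pad_size, :, :]

print(padded_direct.shape, torch.equal(padded_via_permute, padded_direct), torch.equal(recovered, x))
```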

  • Inverse function has zero block size for final psi function initialization

    The issue refers to this line in the code. https://github.com/jhjacobsen/pytorch-i-revnet/blob/c21afaebca0c7dd81c17c0c2ddf1e19979fa5448/models/iRevNet.py#L146

    When using the code as is for CIFAR10, I get a ZeroDivisionError when computing the inverse.

    This is due to a call to a psi function that has an invalid initialisation, i.e. self.init_ds = 0.

    If I comment out this line the inverse seems to be computed correctly.

    Really cool work by the way!

  • Potential bug in model_utils.py

    Thanks for the great work. I think I found a potential bug in your code:


    Shouldn't it be return x[:, :x.size(1) - self.pad_size, :, :]?

  • The size of tensor are not identified

    When I run the code from the GitHub repository, such as:

    import torch
    from torch.autograd import Variable
    from models.iRevNet import iRevNet

    model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], nChannels=None, nClasses=1000, init_ds=2, dropout_rate=0., affineBN=True, in_shape=[3, 224, 224], mult=4)
    y = model(Variable(torch.randn(1, 3, 32, 32)))
    print(y.size())

    I get the following error:

    y1 = Fx2 + x1
    RuntimeError: The size of tensor a (6) must match the size of tensor b (24) at non-singleton dimension 1

    While debugging, I found that Fx2 has shape 1x6x8x8 while x1 has shape 1x24x8x8. How can I fix this? Please help.

  • Question: memory saving

    Thanks for open sourcing your code and congrats on the paper!

    From the paper: "For the same reasons as in Gomez et al. (2017), our scheme also allows avoiding storing any intermediate activations at training time, making memory consumption for very deep i-RevNets not an issue in practice."

    I was wondering where in the code this happens; I was expecting some backward functions implementing it.

    Thank you, Ignacio

  • Some confusion for case 'stride = 2'

    Dear authors:

    The forward procedure for $i$-RevNet described in the paper (Eq.(1)) is:

    $$ \tilde{x}_{j+1} = x_{j} + F_{j+1} \tilde{x}_{j} $$

    However, the code for case 'stride = 2' leads to the following form:

    class irevnet_block(nn.Module):
        def forward(self, x):
            """ bijective or injective block forward """
            if self.pad != 0 and self.stride == 1:
                x = merge(x[0], x[1])
                x = self.inj_pad.forward(x)
                x1, x2 = split(x)
                x = (x1, x2)
            x1 = x[0]
            x2 = x[1]
            Fx2 = self.bottleneck_block(x2)
            if self.stride == 2:
                x1 = self.psi.forward(x1)
                x2 = self.psi.forward(x2)
            y1 = Fx2 + x1
            return (x2, y1)

    which means

    $$ \tilde{x}_{j+1} = S_{j+1} x_{j} + F_{j+1} \tilde{x}_{j} $$

    Do I understand this correctly? I would appreciate an answer whenever you have time.
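The following is derived only from the quoted code, reading x[0] as $x_j$, x[1] as $\tilde{x}_j$ and self.psi as $S_{j+1}$. The stride-2 block returns the pair below. This pair can be inverted without ever inverting $F_{j+1}$, provided $S_{j+1}$ is invertible:

$$ x_{j+1} = S_{j+1} \tilde{x}_{j}, \qquad \tilde{x}_{j+1} = S_{j+1} x_{j} + F_{j+1} \tilde{x}_{j} $$

$$ \tilde{x}_{j} = S_{j+1}^{-1} x_{j+1}, \qquad x_{j} = S_{j+1}^{-1} \left( \tilde{x}_{j+1} - F_{j+1} \tilde{x}_{j} \right) $$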

  • question about --nChannels

    Dear authors,

    Thank you for your great work. The output channels of ResNet-50 are [256, 512, 1024, 2048]. Could you please explain why you set --nChannels to [24, 96, 384, 1536] instead of [128, 256, 512, 1024] to match ResNet-50? Thanks!
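Here is one possible reading of these numbers. It assumes that each nChannels entry counts one of the two halves the features are split into, and that init_ds=2 and every stride-2 stage apply a space-to-depth with block size 2 (channels x4, height and width halved). Under these assumptions, each of 24, 96, 384, 1536 is 4x the previous value. The total number of values also stays 3 x 224 x 224 = 150528 at every stage, which is what a bijective network requires. The hypothetical helper stage_shapes below just does that bookkeeping:

```python
def stage_shapes(in_shape, init_ds, n_strides, n_channels):
    """Return (channels, height, width) after the initial psi and after each stage.

    Assumes each nChannels entry is the width of one of the two split halves.
    """
    c, h, w = in_shape
    # Initial space-to-depth with block size init_ds.
    shapes = [(c * init_ds**2, h // init_ds, w // init_ds)]
    h, w = h // init_ds, w // init_ds
    for stride, half in zip(n_strides, n_channels):
        h, w = h // stride, w // stride
        shapes.append((2 * half, h, w))  # both halves together
    return shapes


shapes = stage_shapes((3, 224, 224), 2, [2, 2, 2, 2], [24, 96, 384, 1536])
for c, h, w in shapes:
    print(c, h, w, c * h * w)  # every line ends in 150528
```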

  • Questions about using the i-resnet for other applications.

    Dear authors,

    Thank you for your great work. I'm currently working on semantic segmentation, so I wonder whether the i-resnet can be applied directly to semantic segmentation by simply changing the final classification layer (i.e., modifying the following four lines). Or do you have other suggestions?


    Thank you very much for your help in advance.

