i-RevNet PyTorch Code

Overview

i-RevNet: Deep Invertible Networks

PyTorch implementation of i-RevNets.

i-RevNets define a family of fully invertible deep networks, built from a succession of homeomorphic layers.

Reference: Jörn-Henrik Jacobsen, Arnold Smeulders, Edouard Oyallon. i-RevNet: Deep Invertible Networks. International Conference on Learning Representations (ICLR), 2018. (https://iclr.cc/)

Algorithm

The i-RevNet and its dual. The inverse can be obtained from the forward model with minimal adaptation and is itself an i-RevNet. Read the paper for the theoretical background and a detailed analysis of the trained models.
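
In equations (a sketch following Eq. (1) of the paper, written here for a stride-1 block; strided blocks additionally pass both streams through the invertible downsampling operator discussed in the issues below), each block keeps two streams $(x_j, \tilde{x}_j)$ and computes

$$ x_{j+1} = \tilde{x}_{j}, \qquad \tilde{x}_{j+1} = x_{j} + F_{j+1}(\tilde{x}_{j}) $$

which is inverted exactly by

$$ \tilde{x}_{j} = x_{j+1}, \qquad x_{j} = \tilde{x}_{j+1} - F_{j+1}(x_{j+1}) $$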

PyTorch i-RevNet Usage

Requirements: Python 3, NumPy, PyTorch, Torchvision

Download the ImageNet dataset and move validation images to labeled subfolders. To do this, you can use the following script: https://raw.githubusercontent.com/soumith/imagenetloader.torch/master/valprep.sh

We provide an ImageNet pre-trained model: Download
Save it to this folder.
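
A quick way to sanity-check the downloaded checkpoint before loading it into a model (a minimal sketch; the file name used here is hypothetical, substitute whatever the download is called):

import torch

# Hypothetical path to the downloaded checkpoint; adjust to the actual file name.
ckpt_path = "i-revnet-pretrained.pth"

# Load on the CPU so no GPU is needed just to inspect the file.
ckpt = torch.load(ckpt_path, map_location="cpu")

# Checkpoints of this kind are usually dicts; print the top-level keys
# (e.g. 'state_dict', 'arch', 'epoch') before wiring it into the model.
print(list(ckpt.keys()) if isinstance(ckpt, dict) else type(ckpt))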

Train a small i-RevNet on CIFAR-10; takes about 5 hours and yields an accuracy of ~94.5%

$ python CIFAR_main.py --nBlocks 18 18 18 --nStrides 1 2 2 --nChannels 16 64 256

Train the bijective i-RevNet on ImageNet; takes 7-10 days and yields a top-1 accuracy of ~74%

$ python ILSVRC_main.py --data /path/to/ILSVRC2012/ --nBlocks 6 16 72 6 --nStrides 2 2 2 2 --nChannels 24 96 384 1536 --init_ds 2
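
For reference, the same ILSVRC configuration can be instantiated directly in Python (a sketch mirroring the command-line arguments above; assumes the repository's models package is importable):

import torch
from models.iRevNet import iRevNet

# Same hyperparameters as the ILSVRC training command above.
model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                nChannels=[24, 96, 384, 1536], nClasses=1000,
                init_ds=2, dropout_rate=0., affineBN=True,
                in_shape=[3, 224, 224], mult=4)

# Forward pass on a dummy ImageNet-sized input; depending on the repository
# version, this returns either the logits or a (logits, bijective_output) pair.
with torch.no_grad():
    out = model(torch.randn(1, 3, 224, 224))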

Evaluate the pre-trained model on the ImageNet validation set; yields 74.018% top-1 accuracy

$ bash scripts/evaluate_ilsvrc-2012.sh
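
Under the hood, evaluation amounts to a standard top-1 accuracy loop over the labeled validation folders; a stand-alone sketch of such a loop (the data path and model handling are assumptions, not the script's actual contents):

import torch
import torchvision.transforms as transforms
from torchvision.datasets import ImageFolder
from torch.utils.data import DataLoader

# Standard ImageNet validation preprocessing.
val_tf = transforms.Compose([
    transforms.Resize(256),
    transforms.CenterCrop(224),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406],
                         std=[0.229, 0.224, 0.225]),
])

# Hypothetical dataset path; point it at the labeled validation subfolders.
val_set = ImageFolder("/path/to/ILSVRC2012/val", transform=val_tf)
val_loader = DataLoader(val_set, batch_size=64, shuffle=False, num_workers=4)

def top1_accuracy(model, loader, device="cuda"):
    model.eval().to(device)
    correct, total = 0, 0
    with torch.no_grad():
        for images, targets in loader:
            out = model(images.to(device))
            # Some versions of the forward pass return (logits, bijective_output).
            logits = out[0] if isinstance(out, tuple) else out
            correct += (logits.argmax(dim=1).cpu() == targets).sum().item()
            total += targets.size(0)
    return 100.0 * correct / total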

Invert the output of the last layer on the ImageNet validation set and save example images

$ bash scripts/invert_ilsvrc-2012.sh

ImageNet ILSVRC-2012 Results

i-RevNets perform on par with baseline RevNet and ResNet.

Model             ResNet   RevNet   i-RevNet (a)   i-RevNet (b)
Val Top-1 Error   24.7     25.2     24.7           26.0

Reconstructions from the ILSVRC-2012 validation set. Top row: original images; bottom row: reconstructions from the final representation.


Contribute

Contributions are very welcome.

Cite

@inproceedings{
jacobsen2018irevnet,
title={i-RevNet: Deep Invertible Networks},
author={Jörn-Henrik Jacobsen and Arnold W.M. Smeulders and Edouard Oyallon},
booktitle={International Conference on Learning Representations},
year={2018},
url={https://openreview.net/forum?id=HJsjkMb0Z},
}
Comments
  • Consider more efficient implementation of class psi


    I found that replacing the original implementation of models.model_utils.psi with the following implementation gave me about an order of magnitude speed-up, both in forward() and inverse(), both on the GPU and CPU:

    class psi_suggested(psi):

        def inverse(self, inpt):
            # Depth-to-space: fold the block_size**2 channel groups back into spatial blocks.
            bl, bl_sq = self.block_size, self.block_size_sq
            bs, new_d, h, w = inpt.shape[0], inpt.shape[1] // bl_sq, inpt.shape[2], inpt.shape[3]
            return inpt.view(bs, bl, bl, new_d, h, w).permute(0, 3, 4, 1, 5, 2).reshape(bs, new_d, h * bl, w * bl)

        def forward(self, inpt):
            # Space-to-depth: move each block_size x block_size spatial block into the channel dimension.
            bl, bl_sq = self.block_size, self.block_size_sq
            bs, d, new_h, new_w = inpt.shape[0], inpt.shape[1], inpt.shape[2] // bl, inpt.shape[3] // bl
            return inpt.view(bs, d, new_h, bl, new_w, bl).permute(0, 3, 5, 1, 2, 4).reshape(bs, d * bl_sq, new_h, new_w)
    

    I timed it as follows:

    import timeit
    import torch
    from models.model_utils import psi  # original implementation, for comparison

    device = torch.device("cpu")
    
    # Forward
    block_size=3
    t = torch.randn(64, 5, 192, 192, dtype=torch.float32).to(device)
    psi_instance = psi(block_size)
    psi_callable = lambda: psi_instance.forward(t)
    psi_suggested_instance = psi_suggested(block_size)
    psi_suggested_callable = lambda: psi_suggested_instance.forward(t)
    print("Timing forward, suggested psi:", timeit.Timer(psi_suggested_callable).timeit(100))
    print("Timing forward, original psi:", timeit.Timer(psi_callable).timeit(100))
    print("Same result in forward?", (psi_callable() == psi_suggested_callable()).all().item())
    
    # Inverse
    block_size=3
    t = torch.randn(64, 45, 64, 64, dtype=torch.float32).to(device)
    psi_instance = psi(block_size)
    psi_callable = lambda: psi_instance.inverse(t)
    psi_suggested_instance = psi_suggested(block_size)
    psi_suggested_callable = lambda: psi_suggested_instance.inverse(t)
    print("Timing inverse, suggested psi:", timeit.Timer(psi_suggested_callable).timeit(100))
    print("Timing inverse, original psi:", timeit.Timer(psi_callable).timeit(100))
    print("Same result in inverse?", (psi_callable() == psi_suggested_callable()).all().item())
    

    Which gave me on the CPU:

    Timing forward, suggested psi: 1.7428924000000734
    Timing forward, original psi: 11.567388500000106
    Same result in forward? True
    Timing inverse, suggested psi: 2.7421231000000716
    Timing inverse, original psi: 10.155058100000133
    Same result in inverse? True
    

    And on the GPU:

    Timing forward, suggested psi: 0.010811799999828509
    Timing forward, original psi: 0.5521851000000879
    Same result in forward? True
    Timing inverse, suggested psi: 0.010437099999990096
    Timing inverse, original psi: 0.07210720000011861
    Same result in inverse? True
    

    If you can reproduce these results, you might consider reimplementing psi.

    opened by spezold 4
  • iRevNet.py's test code doesn't work


    Hi,

    I tried running the test script you provide in models/iRevNet.py:

    import torch
    from torch.autograd import Variable
    from models.iRevNet import iRevNet

    model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2],
                    nChannels=None, nClasses=1000, init_ds=2,
                    dropout_rate=0., affineBN=True, in_shape=[3, 224, 224],
                    mult=4)
    y = model(Variable(torch.randn(1, 3, 224, 224)))
    print(y.size())
    

    However, this seems to raise an error:

     == Building iRevNet 301 == 
    Traceback (most recent call last):
      File "iRevNet.py", line 158, in <module>
        y = model(Variable(torch.randn(1, 3, 224, 224)))
      File "/anaconda/lib/python3.6/site-packages/torch/nn/modules/module.py", line 477, in __call__
        result = self.forward(*input, **kwargs)
      File "iRevNet.py", line 132, in forward
        out = block.forward(out)
      File "iRevNet.py", line 61, in forward
        y1 = Fx2 + x1
    RuntimeError: The size of tensor a (6) must match the size of tensor b (24) at non-singleton dimension 1
    

    Am I doing something wrong?

    Thanks!

    opened by Selim78 4
  • Pretrained models cannot be loaded


    Hi, I have a question about the pretrained models.

    Does it have the same architecture as the ILSVRC example? I can't load it into your example: model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], nChannels=[24, 96, 384, 1536], nClasses=1000, init_ds=2, dropout_rate=0., affineBN=True, in_shape=[3, 224, 224], mult=4)

    Also, when I checked the key "arch" in your saved state, it says "resnet18". If the architecture is different, could you explain it so that I can use your pretrained model?

    Thanks for your help.

    opened by effendijohanes 3
  • About model size of i-RevNet (b)


    Hi:

    As illustrated in Table 1 of your ICLR 2018 paper, the number of parameters of i-RevNet (b) is 29M. However, when I took the released iRevNet.py from this repository and computed the model size, the result differs from the paper: the number of parameters I get for this model is about 125.12MB, which is significantly larger than the expected size. The tool I used for computing the model size is https://github.com/Lyken17/pytorch-OpCounter. I am puzzled about the model size; could you help me clarify this?

    Best Regards, Jiajun Deng

    opened by djiajunustc 2
  • Is it fully invertible? The last layer is pooling + linear


    It seems that the network is only 90% invertible: the last layer relies on pooling and a linear layer. Can we replace the pooling + ReLU + linear layers with more reversible convnet downsampling and a 1x1 reversible convolution? @jhjacobsen, did you test this?

    opened by tsauri 2
  • Cifar exception in tensor size at beginning of training


    In utils_cifar.py line 105 (and 133): train_loss += loss.data[0]

    gives an error: IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
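
    A minimal sketch of the fix the error message itself suggests (replace the 0-dim indexing with .item() at those two lines):

    import torch

    loss = torch.tensor(0.5)   # stand-in for a 0-dim loss tensor
    train_loss = 0.0
    # Old: train_loss += loss.data[0]   -> IndexError on 0-dim tensors in recent PyTorch
    train_loss += loss.item()  # version-safe replacement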

    opened by david-klindt 1
  • Hi, I don't understand the permute function


    Hello, @jhjacobsen

    I am a newcomer to deep learning. I read your source code and there are a few things I don't understand. Why is the permute function used this way?

    class injective_pad(nn.Module):
        def __init__(self, pad_size):
            super(injective_pad, self).__init__()
            self.pad_size = pad_size
            self.pad = nn.ZeroPad2d((0, 0, 0, pad_size))

        def forward(self, x):
            # Zero-pad the channel dimension: swap channels into a spatial position
            # so that ZeroPad2d can pad it, then swap back.
            x = x.permute(0, 2, 1, 3)
            x = self.pad(x)
            return x.permute(0, 2, 1, 3)

        def inverse(self, x):
            # Drop the padded channels again.
            return x[:, :x.size(1) - self.pad_size, :, :]
    
    

    And this,

    class psi(nn.Module):
        def __init__(self, block_size):
            super(psi, self).__init__()
            self.block_size = block_size
            self.block_size_sq = block_size*block_size
    
        def inverse(self, input):
            output = input.permute(0, 2, 3, 1)
            (batch_size, d_height, d_width, d_depth) = output.size()
            s_depth = int(d_depth / self.block_size_sq)
            s_width = int(d_width * self.block_size)
            s_height = int(d_height * self.block_size)
            t_1 = output.contiguous().view(batch_size, d_height, d_width, self.block_size_sq, s_depth)
            spl = t_1.split(self.block_size, 3)
            stack = [t_t.contiguous().view(batch_size, d_height, s_width, s_depth) for t_t in spl]
            output = torch.stack(stack, 0).transpose(0, 1).permute(0, 2, 1, 3, 4).contiguous().view(batch_size, s_height, s_width, s_depth)
            output = output.permute(0, 3, 1, 2)
            return output.contiguous()
    
        def forward(self, input):
            output = input.permute(0, 2, 3, 1)
            (batch_size, s_height, s_width, s_depth) = output.size()
            d_depth = s_depth * self.block_size_sq
            d_height = int(s_height / self.block_size)
            t_1 = output.split(self.block_size, 2)
            stack = [t_t.contiguous().view(batch_size, d_height, d_depth) for t_t in t_1]
            output = torch.stack(stack, 1)
            output = output.permute(0, 2, 1, 3)
            output = output.permute(0, 3, 1, 2)
            return output.contiguous()
    

    What is the role of these two class functions? Thank you.
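
    A minimal round-trip sketch (assuming both classes are importable from models.model_utils, as the other issues here suggest) shows what the two modules do to tensor shapes:

    import torch
    from models.model_utils import injective_pad, psi

    x = torch.randn(1, 4, 8, 8)

    # injective_pad appends pad_size zero-filled channels; inverse slices them off again.
    pad = injective_pad(3)
    assert pad.forward(x).shape == (1, 7, 8, 8)
    assert torch.equal(pad.inverse(pad.forward(x)), x)

    # psi is an invertible space-to-depth transform: each block_size x block_size
    # spatial block is moved into the channel dimension, and inverse undoes it.
    down = psi(2)
    y = down.forward(x)                      # shape (1, 16, 4, 4)
    assert torch.equal(down.inverse(y), x)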

    opened by MrLinNing 1
  • Inverse function has zero block size for final psi function initialization


    The issue refers to this line in the code. https://github.com/jhjacobsen/pytorch-i-revnet/blob/c21afaebca0c7dd81c17c0c2ddf1e19979fa5448/models/iRevNet.py#L146

    When using the code as is for CIFAR10, I get a ZeroDivisionError when computing the inverse.

    This is due to a call to a psi function that has an invalid initialisation, i.e. self.init_ds = 0.

    If I comment out this line the inverse seems to be computed correctly.

    Really cool work by the way!

    opened by ThorJonsson 1
  • Potential bug in model_utils.py


    Thanks for the great work. I think I found a potential bug in your code

    https://github.com/jhjacobsen/pytorch-i-revnet/blob/master/models/model_utils.py#L36

    Shouldn't it be return x[:, :x.size(1) - self.pad_size, :, :]?

    opened by danfeiX 1
  • The sizes of the tensors do not match


    When I run the code from the GitHub repository, for example:

    import torch
    from torch.autograd import Variable
    from models.iRevNet import iRevNet

    model = iRevNet(nBlocks=[6, 16, 72, 6], nStrides=[2, 2, 2, 2], nChannels=None,
                    nClasses=1000, init_ds=2, dropout_rate=0., affineBN=True,
                    in_shape=[3, 224, 224], mult=4)
    y = model(Variable(torch.randn(1, 3, 32, 32)))
    print(y.size())

    I get the following error:

    y1 = Fx2 + x1
    RuntimeError: The size of tensor a (6) must match the size of tensor b (24) at non-singleton dimension 1

    When debugging it, I found that the shape of Fx2 is 1x6x8x8 while the shape of x1 is 1x24x8x8. How can I fix this? Please help.

    opened by chouqin3 1
  • Question: memory saving


    Thanks for open sourcing your code and congrats on the paper!

    From the paper: "For the same reasons as in Gomez et al. (2017), our scheme also allows avoiding storing any intermediate activations at training time, making memory consumption for very deep i-RevNets not an issue in practice."

    I was wondering where in the code this happens; I was expecting some backward functions implementing it.

    Thank you, Ignacio

    opened by tartavull 1
  • Some confusion for case 'stride = 2'


    Dear authors:

    The forward procedure for $i$-RevNet described in the paper (Eq.(1)) is:

    $$ \tilde{x}_{j+1} = x_{j} + F_{j+1} \tilde{x}_{j} $$

    However, the code for case 'stride = 2' leads to the following form:

    class irevnet_block(nn.Module):
    ...
        def forward(self, x):
            """ bijective or injective block forward """
            if self.pad != 0 and self.stride == 1:
                x = merge(x[0], x[1])
                x = self.inj_pad.forward(x)
                x1, x2 = split(x)
                x = (x1, x2)
            x1 = x[0]
            x2 = x[1]
            Fx2 = self.bottleneck_block(x2)
            if self.stride == 2:
                x1 = self.psi.forward(x1)
                x2 = self.psi.forward(x2)
            y1 = Fx2 + x1
            return (x2, y1)
    

    which means

    $$ \tilde{x}_{j+1} = S_{j+1} x_{j} + F_{j+1} \tilde{x}_{j} $$

    Do I understand this correctly? I would appreciate an answer when you have time.

    opened by shuizidesu 0
  • question about --nChannels


    Dear authors,

    Thank you for your great work. The output channels in ResNet-50 are [256, 512, 1024, 2048]. Could you please explain why you set --nChannels to [24, 96, 384, 1536] instead of [128, 256, 512, 1024] to match ResNet-50? Thanks!

    opened by eefnn 0
  • Questions about using the i-RevNet for other applications


    Dear authors,

    Thank you for your great work. Currently I'm working on semantic segmentation, so I wonder whether the i-RevNet can be directly applied to semantic segmentation by simply changing the final classification layer (i.e., modifying the following four lines). Or do you have other suggestions?

    https://github.com/jhjacobsen/pytorch-i-revnet/blob/307413043e33540cbe9c3746ef420261f8138315/models/iRevNet.py#L134-L137

    Thank you very much for your help in advance.

    opened by xyIsHere 1
Owner
Jörn Jacobsen
j.jacobsen [at] vectorinstitute.ai