RevLib

Simple and efficient RevNet library with DeepSpeed support

Overview

Features

  • Half the constant memory usage of, and faster than, other RevNet libraries
  • Less memory than gradient checkpointing (1 * output_size instead of n_layers * output_size)
  • Same speed as activation checkpointing
  • Extensible
  • Trivial code (<100 lines)

Getting started

Installation

python3 -m pip install revlib

Examples

iRevNet

iRevNet is not just partially reversible but a fully invertible model. Its source code looks complex at first glance, and it doesn't exploit all the memory savings it could, because RevNet requires custom autograd functions that are hard to maintain. Using revlib, an iRevNet can be implemented like this:

import torch
from torch import nn
import revlib

channels = 64
channel_multiplier = 4
depth = 3
classes = 1000


# Create a basic function that's reversibly executed multiple times. (Like f() in ResNet)
def conv(in_channels, out_channels):
    return nn.Conv2d(in_channels, out_channels, (3, 3), padding=1)


def block_conv(in_channels, out_channels):
    return nn.Sequential(conv(in_channels, out_channels),
                         nn.Dropout(0.2),
                         nn.BatchNorm2d(out_channels),
                         nn.ReLU())


def block():
    return nn.Sequential(block_conv(channels, channels * channel_multiplier),
                         block_conv(channels * channel_multiplier, channels),
                         nn.Conv2d(channels, channels, (3, 3), padding=1))


# Create a reversible model. f() is invoked depth times, each time with different weights.
rev_model = revlib.ReversibleSequential(*[block() for _ in range(depth)])

# Wrap reversible model with non-reversible layers
model = nn.Sequential(conv(3, 2 * channels), rev_model, conv(2 * channels, classes))

# Use it like you would a regular PyTorch model
inp = torch.randn((1, 3, 224, 224))
out = model(inp)
out.mean().backward()
assert out.size() == (1, 1000, 224, 224)
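
To see the constant-memory claim in practice, here is a rough sketch that continues the example above (assuming a CUDA device; peak_memory is a hypothetical helper, not part of revlib) and compares the reversible core against a plain nn.Sequential of the same blocks:

if torch.cuda.is_available():
    def peak_memory(module: nn.Module, width: int) -> int:
        # One forward/backward pass; report the peak CUDA allocation.
        module = module.cuda()
        torch.cuda.reset_peak_memory_stats()
        x = torch.randn((1, width, 224, 224), device="cuda", requires_grad=True)
        module(x).mean().backward()
        return torch.cuda.max_memory_allocated()

    reversible = revlib.ReversibleSequential(*[block() for _ in range(depth)])
    baseline = nn.Sequential(*[block() for _ in range(depth)])
    # The reversible core's peak should stay roughly flat as depth grows,
    # while the baseline's grows linearly with depth.
    print(peak_memory(reversible, 2 * channels), peak_memory(baseline, channels))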

MomentumNet

MomentumNet is another recent paper that makes significant advancements in the area of memory-efficient networks. Instead of a second model output, it propagates a momentum stream (see the illustration in the MomentumNet paper). Implementing this with revlib requires a custom coupling operation (the functional analogue of MemCNN) that merges the input and output streams:

import torch
from torch import nn
import revlib

channels = 64
depth = 16
momentum_ema_beta = 0.99


# Compute y2 from x2 and f(x1) by merging x2 and f(x1) in the forward pass.
def momentum_coupling_forward(other_stream: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return other_stream * momentum_ema_beta + fn_out * (1 - momentum_ema_beta)


# Calculate x2 from y2 and f(x1) by manually computing the inverse of momentum_coupling_forward.
def momentum_coupling_inverse(output: torch.Tensor, fn_out: torch.Tensor) -> torch.Tensor:
    return (output - fn_out * (1 - momentum_ema_beta)) / momentum_ema_beta


# Pass in coupling functions which will be used instead of x2 + f(x1) and y2 - f(x1)
rev_model = revlib.ReversibleSequential(*[layer for _ in range(depth)
                                          for layer in [nn.Conv2d(channels, channels, (3, 3), padding=1),
                                                        nn.Identity()]],
                                        coupling_forward=[momentum_coupling_forward, revlib.additive_coupling_forward],
                                        coupling_inverse=[momentum_coupling_inverse, revlib.additive_coupling_inverse])

inp = torch.randn((16, channels * 2, 224, 224))
out = rev_model(inp)
assert out.size() == (16, channels * 2, 224, 224)
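
As a quick sanity check (illustrative only, with random stand-ins for x2 and f(x1)), the inverse coupling reconstructs the input stream up to floating-point error:

x2 = torch.randn(4, channels, 16, 16)
f_x1 = torch.randn_like(x2)  # stand-in for f(x1)
y2 = momentum_coupling_forward(x2, f_x1)
# Recover x2 from y2 and f(x1); a small atol absorbs rounding error.
assert torch.allclose(momentum_coupling_inverse(y2, f_x1), x2, atol=1e-5)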

Reformer

Reformer uses RevNet together with chunking and LSH attention to train a transformer efficiently. Using revlib, standard implementations such as lucidrains' Reformer can be improved upon to use less memory. Below, we still use the basic building blocks from lucidrains' code to keep the model comparable.

import torch
from torch import nn
from reformer_pytorch.reformer_pytorch import LSHSelfAttention, Chunk, FeedForward, AbsolutePositionalEmbedding
import revlib


class Reformer(torch.nn.Module):
    def __init__(self, sequence_length: int, features: int, depth: int, heads: int, bucket_size: int = 64,
                 lsh_hash_count: int = 8, ff_chunks: int = 16, input_classes: int = 256, output_classes: int = 256):
        super(Reformer, self).__init__()
        self.token_embd = nn.Embedding(input_classes, features * 2)
        self.pos_embd = AbsolutePositionalEmbedding(features * 2, sequence_length)

        self.core = revlib.ReversibleSequential(*[nn.Sequential(nn.LayerNorm(features), layer) for _ in range(depth)
                                                 for layer in
                                                 [LSHSelfAttention(features, heads, bucket_size, lsh_hash_count),
                                                  Chunk(ff_chunks, FeedForward(features, activation=nn.GELU), 
                                                        along_dim=-2)]],
                                                split_dim=-1)
        self.out_norm = nn.LayerNorm(features * 2)
        self.out_linear = nn.Linear(features * 2, output_classes)

    def forward(self, inp: torch.Tensor) -> torch.Tensor:
        return self.out_linear(self.out_norm(self.core(self.token_embd(inp) + self.pos_embd(inp))))


sequence = 1024
classes = 16
model = Reformer(sequence, 256, 6, 8, output_classes=classes)
out = model(torch.ones((16, sequence), dtype=torch.long))
assert out.size() == (16, sequence, classes)
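
The reversible core behaves like any other PyTorch module during training. As a toy illustration (with random labels, purely hypothetical), a language-modelling step on the output above could look like this:

labels = torch.randint(0, classes, (16, sequence))
# cross_entropy expects (batch, classes, sequence), so move the class dimension.
loss = torch.nn.functional.cross_entropy(out.transpose(1, 2), labels)
loss.backward()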

Explanation

Most other RevNet libraries, such as MemCNN and RevTorch, compute both f() and g() in one go, creating one large computation graph. RevLib, on the other hand, brings Mesh TensorFlow's "reversible half residual and swap" to PyTorch: reversible_half_residual_and_swap computes only one of f() and g() per step and swaps the inputs and gradients. This way, the library only has to store one output, since it can recover the other one during the backward pass.
Following Mesh TensorFlow's example, revlib also uses separate x1 and x2 tensors instead of concatenating and splitting at every step to reduce the cost of memory-bound operations.
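
The core idea can be sketched in a few lines. The snippet below is a standalone illustration of additive coupling with a stream swap, not RevLib's internal API; forward_step and inverse_step are hypothetical names:

import torch

def forward_step(x1, x2, f):
    # Evaluate only f(); g() runs in the next step, on the swapped streams.
    y1 = x2 + f(x1)
    return y1, x1  # swap the streams

def inverse_step(y1, y2, f):
    x1 = y2          # undo the swap
    x2 = y1 - f(x1)  # recover the activation that was never stored
    return x1, x2

f = torch.nn.Linear(8, 8)
x1, x2 = torch.randn(2, 8), torch.randn(2, 8)
y1, y2 = forward_step(x1, x2, f)
r1, r2 = inverse_step(y1, y2, f)
assert torch.allclose(x1, r1) and torch.allclose(x2, r2, atol=1e-6)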

RevNet's memory consumption doesn't scale with its depth, so it's significantly more memory-efficient for deep models. One problem in most implementations was that two tensors needed to be stored with the output, quadrupling the required memory. The high memory consumption rendered RevNet nearly useless for shallow networks such as BERT.
RevLib works around this problem by storing only one output and two inputs for each forward pass, giving a model as small as BERT a >2x memory improvement!

Even ignoring the dual-path structure of a RevNet, most implementations used to be much slower than gradient checkpointing. RevLib, however, uses minimal coupling functions and has no overhead between sequence items, allowing it to train as fast as a comparable model with gradient checkpointing.

Comments
  • Compute input from output?

    Hey @ClashLuke, nice code! You point to iRevNet as a comparison and explain why this implementation is more memory-efficient. However, iRevNet includes a method to directly invert the pre-pooler feature set of the CNN and compute the input that generated those features. I don't see an obvious way to do that using your library. Is that something that can be done, and do you intend to add a method for it?

    I'm particularly interested in computing the inverse of a transformer given the output hidden states, which iRevNet does not support as far as I can tell, since it's CNN-specific. But if it's implementable within your framework, as seems likely, I would be very interested in leveraging it for my research.

    Thank you! :)

    opened by isaacrob 7
  • Suggestions for constructing a ResNet with the revlib

    Hi

    I would like to use this library to build a ResNet-20 model. I've tried several times, but I still get a mismatched-dimension error. My model is shown below:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F
    import torch.nn.init as init

    from torch.nn.modules.batchnorm import _BatchNorm
    from torch.nn.modules.instancenorm import _InstanceNorm

    import revlib
    
    hidden_size = [16, 32, 64]
    
    class View(nn.Module):
        def forward(self, x):
            batch_size = x.size(0)
            return x.view(batch_size, -1)
    
    class LambdaLayer(nn.Module):
        def __init__(self, lambd):
            super(LambdaLayer, self).__init__()
            self.lambd = lambd
    
        def forward(self, x):
            return self.lambd(x)
    
    
    class BasicBlock(nn.Module):
        expansion = 1
    
        def __init__(self, in_planes, planes, stride, norm_layer, conv_layer, option='A'):
            super(BasicBlock, self).__init__()
            
            self.bn1 = norm_layer(in_planes)
            self.conv1 = conv_layer(in_planes, planes, kernel_size=3, stride=stride, padding=1, bias=False)
            self.bn2 = norm_layer(planes)
            self.conv2 = conv_layer(planes, planes, kernel_size=3, stride=1, padding=1, bias=False)
    
            self.shortcut = nn.Sequential()
            if stride != 1 or in_planes != self.expansion * planes:
                if option == 'A':
                    """
                    For CIFAR10 ResNet paper uses option A.
                    """
                    self.shortcut = LambdaLayer(lambda x:
                                                F.pad(x[:, :, ::2, ::2], (0, 0, 0, 0, planes//4, planes//4), "constant", 0))
                elif option == 'B':
                    self.shortcut = conv_layer(in_planes, self.expansion * planes, kernel_size=1, stride=stride, bias=False)
    
        def forward(self, x):
            out = F.relu(self.bn1(x))
            shortcut = self.shortcut(out) if hasattr(self, 'shortcut') else x
            print(1, shortcut.size())
            out = self.conv1(out)
            out = self.conv2(F.relu(self.bn2(out)))
            print(2, out.size())
            out += shortcut
            return out
    
    class ResNet(nn.Module):
    
        def __init__(self, hidden_size, block, num_blocks, num_classes=10, bn_type='bn',
                     share_affine=False, track_running_stats=True):
            super(ResNet, self).__init__()
            self.in_planes = 16
    
            self.bn_type = bn_type
            if bn_type == 'bn':
                norm_layer = lambda n_ch: nn.BatchNorm2d(n_ch, track_running_stats=track_running_stats)
            elif bn_type == 'gn':
                norm_layer = lambda n_ch: nn.GroupNorm(4, n_ch) # 3 can be changed -- # of groups
            else:
                raise RuntimeError(f"Not support bn_type={bn_type}")
            conv_layer = nn.Conv2d
            first = conv_layer(3, hidden_size[0], kernel_size=3, stride=1, padding=1, bias=False)
            layer1 = self._make_layer(block, hidden_size[0], num_blocks[0], stride=1,
                                           norm_layer=norm_layer, conv_layer=conv_layer)
            layer2 = self._make_layer(block, hidden_size[1], num_blocks[1], stride=2,
                                           norm_layer=norm_layer, conv_layer=conv_layer)
            layer3 = self._make_layer(block, hidden_size[2], num_blocks[2], stride=2,
                                           norm_layer=norm_layer, conv_layer=conv_layer)
            
            self.rev_layers = revlib.ReversibleSequential(*[layer1, layer2, layer3])
            
            norm = norm_layer(hidden_size[2] * block.expansion)
            linear = nn.Linear(hidden_size[2] * block.expansion, num_classes)
            
            self.full_model = nn.Sequential(first, self.rev_layers, nn.ReLU(), norm, \
                                            nn.AdaptiveAvgPool2d((None, 1)), View(), linear)
    
        def _make_layer(self, block, planes, num_blocks, stride, norm_layer, conv_layer):
            strides = [stride] + [1] * (num_blocks - 1)
            layers = []
            for stride in strides:
                layers.append(block(self.in_planes, planes, stride, norm_layer, conv_layer))
                self.in_planes = planes * block.expansion
            return nn.Sequential(*layers)
    
        def forward(self, x):
            out = self.full_model(x)
            return out
    
    def init_param(m):
        """Special init for ResNet"""
        if isinstance(m, (_BatchNorm, _InstanceNorm)):
            m.weight.data.fill_(1)
            m.bias.data.zero_()
        elif isinstance(m, nn.Linear):
            m.bias.data.zero_()
        return m
    
    def resnet20(**kwargs):
        model = ResNet(hidden_size, BasicBlock, [3,3,3], **kwargs)
        model.apply(init_param)
        return model
    

    I've tried modifying self.in_planes = 8 and hidden_size = [8, 16, 32], respectively, but it still doesn't work. Could you provide any hints? Is it possible to build the model in a forward way instead of wrapping the reversible model with non-reversible layers, like model = nn.Sequential(conv, rev_layer, conv)? I appreciate your help.

    opened by taokz 5
  • TorchSummary for revlib

    Hi!

    Thank you for your great work. I was wondering how I can get the model's information, such as the number of parameters and the estimated total model size, via the torchsummary library or other tools.

    I've received RuntimeError: Given groups=1, weight of size [256, 64, 3, 3], expected input[2, 32, 224, 224] to have 64 channels, but got 32 channels instead when I run summary(rev_model, input_size=(64, 224, 224)). I understand that the input to the RevNet is divided into two parts along the channel dimension (64/2 = 32). Any suggestions? I appreciate your help!

    opened by taokz 4
  • convolution network, striding

    If I need to stride the input, how would you recommend doing it with this library?

    class Stride(nn.Module):
        def forward(self, input):
            return input[[slice(None)] * 2 + [slice(0, None, 2) for _ in range(2, input.ndim)]]

    Stride()(torch.randn(1, 1, 11, 16)).shape  # 1, 1, 6, 8
    
    opened by klae01 1
  • Torch1.10

    PyTorch 1.10 will add torch.autograd.graph.saved_tensors_hooks, which I used to significantly improve code style without hurting performance or changing the API.

    I used the improved code to add a memory_savings=True switch to ReversibleSequential without further complicating the code. In addition, the flag allows for easier debugging and tests.

    This branch also adds tests regarding the memory efficiency and correctness of gradients.
    All tests pass on both branches.

    opened by ClashLuke 1
Owner
Lucas Nestler
German AI researcher