[ICLR'19] Trellis Networks for Sequence Modeling

Overview

TrellisNet for Sequence Modeling

This repository contains the experiments done in the paper Trellis Networks for Sequence Modeling by Shaojie Bai, J. Zico Kolter and Vladlen Koltun.

On the one hand, a trellis network is a temporal convolutional network with special structure, characterized by weight tying across depth and direct injection of the input into deep layers. On the other hand, we show that truncated recurrent networks are equivalent to trellis networks with special sparsity structure in their weight matrices. Thus trellis networks with general weight matrices generalize truncated recurrent networks. This allows trellis networks to serve as a bridge between recurrent and convolutional architectures, benefiting from algorithmic and architectural techniques developed in either context. We leverage these relationships to design high-performing trellis networks that absorb ideas from both architectural families. Experiments demonstrate that trellis networks outperform the current state of the art on a variety of challenging benchmarks, including word-level language modeling on Penn Treebank and WikiText-103 (UPDATE: recently surpassed by Transformer-XL), character-level language modeling on Penn Treebank, and stress tests designed to evaluate long-term memory retention.
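
To make the weight-tying and input-injection structure concrete, here is a minimal, self-contained PyTorch sketch of the idea. This is not the repository's implementation (which uses a gated, LSTM-style activation and several optimizations); the class and parameter names below are illustrative only:

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TinyTrellisNet(nn.Module):
        """Conceptual sketch: a causal temporal convolution whose weights are
        shared across all levels (depth), with the raw input injected into
        every level. Names and sizes here are hypothetical."""
        def __init__(self, n_inp, n_hid, n_levels, kernel_size=2):
            super().__init__()
            self.n_levels = n_levels
            self.kernel_size = kernel_size
            # A single convolution reused at every level: it mixes the injected
            # input (n_inp channels) with the previous level's output (n_hid channels).
            self.conv = nn.Conv1d(n_inp + n_hid, n_hid, kernel_size)

        def forward(self, x):
            # x: (batch, n_inp, seq_len)
            z = x.new_zeros(x.size(0), self.conv.out_channels, x.size(2))
            for _ in range(self.n_levels):
                inp = torch.cat([x, z], dim=1)               # input injection at every depth
                inp = F.pad(inp, (self.kernel_size - 1, 0))  # left-pad => causal convolution
                z = torch.tanh(self.conv(inp))               # same (tied) weights at every level
            return z

    # Example: y = TinyTrellisNet(n_inp=8, n_hid=32, n_levels=4)(torch.randn(2, 8, 100))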

Our experiments were done in PyTorch. If you find our work or this repository helpful, please consider citing it:

@inproceedings{bai2018trellis,
  author    = {Shaojie Bai and J. Zico Kolter and Vladlen Koltun},
  title     = {Trellis Networks for Sequence Modeling},
  booktitle = {International Conference on Learning Representations (ICLR)},
  year      = {2019},
}

Datasets

The code should be directly runnable with PyTorch 1.0.0 or above. This repository contains the training scripts for the following tasks:

  • Sequential MNIST handwritten digit classification
  • Permuted Sequential MNIST that randomly permutes the pixel order in sequential MNIST
  • Sequential CIFAR-10 classification (more challenging, due to more intra-class variations, channel complexities and larger images)
  • Penn Treebank (PTB) word-level language modeling (with and without the mixture of softmax); vocabulary size 10K
  • WikiText-103 (WT103) large-scale word-level language modeling; vocabulary size 268K
  • Penn Treebank medium-scale character-level language modeling

Note that these tasks are on very different scales, with unique properties that challenge sequence models in different ways. For example, word-level PTB is a small dataset that a typical model easily overfits, so judicious regularization is essential. WT103 is a hundred times larger, with less danger of overfitting, but its vocabulary size of 268K makes training more challenging (due to the large embedding size).

Pre-trained Model(s)

We provide some reasonably good pre-trained weights here so that users don't need to train from scratch. We'll update the table from time to time. (Note: if you train from scratch using different seeds, you will likely get better results :-))

Description   | Task                              | Dataset             | Model
TrellisNet-LM | Word-Level Language Modeling      | Penn Treebank (PTB) | download (.pkl)
TrellisNet-LM | Character-Level Language Modeling | Penn Treebank (PTB) | download (.pkl)

To use the pre-trained weights, pass the flag --load_weight [.pkl PATH] when starting the training script (e.g., you can just use the default arg parameters). You can use the flag --eval to turn on evaluation mode only.
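
For example, a typical invocation might look like the following (the .pkl file name is a placeholder for wherever you saved the downloaded weights; [TASK_NAME].py follows the per-task layout described in the Usage section below):

    python [TASK_NAME].py --load_weight pretrained_model.pkl --eval   # pretrained_model.pkl is a hypothetical path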

Usage

All tasks share the same underlying TrellisNet model, which is in the file trellisnet.py (the full task models, including components such as the embedding layer, are in model.py). As discussed in the paper, TrellisNet is able to benefit significantly from techniques originally developed for RNNs as well as for temporal convolutional networks (TCNs). Some of these techniques are also included in this repository. Each task is organized in the following structure:

[TASK_NAME] /
    data/
    logs/
    [TASK_NAME].py
    model.py
    utils.py
    data.py

where [TASK_NAME].py is the training script for the task (with argument flags; use -h to see the details), as in the example below.
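
For instance, assuming a word-level PTB task directory named word_PTB (the exact directory and script names may differ), training could be started roughly as follows:

    cd word_PTB              # hypothetical task directory name; substitute the actual [TASK_NAME]
    python word_PTB.py -h    # list all available argument flags
    python word_PTB.py       # train with the default hyper-parameters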

Comments
  • Question about pre-computed

    Hello,

    I was reading your paper and have a question about https://github.com/locuslab/trellisnet/blob/ec1de0a5ee09ef5a4c5bca4c83456dec8cbdf4c8/TrellisNet/trellisnet.py#L55. I don't see how it can reduce the computation; how does it work? Is it true that there is only one cycle, so only one calculation is saved, and only in layer 1?

    opened by Sample-design-alt 8
  • Possible mistake in LSTM cell

    Hello,

    I was reading your paper and noticed that the code in trellisnet.py (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py), lines 124-129, does not correspond to the formula in the paper, Section 5.1, Equation (12). Could you clarify whether this is the case, or whether I am misunderstanding something?

    Thank you

    opened by JurijsNazarovs 4
  • GPU specifications for character-level PTB

    Hi, I'm wondering what GPU you used to train the character-level PTB model? Just so I know how much VRAM the GPU needs for training given your hyper-parameters.

    Thanks

    opened by arvieFrydenlund 3
  • question about dilation

    Hi,

    I have a question about the implementation of dilation in the TrellisNet class in trellisnet.py:

    # Feed-forward layers
    for i in range(0, self.nlevels):
        d = dilation_cycle[i % len(dilation_cycle)]
        Z = self.step(Z, dilation=d, hc=hc)
        if aux and i % self.aux_frequency == (self.aux_frequency-1):
            aux_outs.append(Z[:, -nout:].unsqueeze(0))
    

    I noticed that you cycle the dilation parameter. What is the reason behind the cycling?

    Thanks

    opened by junhyk 2
  • Questions to port model in Keras

    I'm trying to port TrellisNet to Keras and have some questions.

    1. Why do you use a custom weight normalization (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/optimizations.py#L190) instead of the common one (https://pytorch.org/docs/stable/_modules/torch/nn/utils/weight_norm.html)? What is wrong with the PyTorch core implementation?

    2. Why do you manually recompute the weight norm in the upper layer (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py#L145) instead of putting it inside the forward method of WeightNorm itself?

    3. Why do you use a dilation + device key here (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/trellisnet.py#L55) instead of just dilation?

    4. It seems that VariationalDropout here (https://github.com/locuslab/trellisnet/blob/master/TrellisNet/optimizations.py#L119) is not a true variational dropout but simply the equivalent of https://www.tensorflow.org/versions/r2.0/api_docs/python/tf/keras/layers/SpatialDropout1D.

    opened by shkarupa-alex 1
  • "index out of bounds"`

    Weight normalization applied
    /word_WT103/splitcross.py:130: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
      softmaxed_all_head_res = torch.nn.functional.log_softmax(all_head_res)
    /TrellisNet/word_WT103/splitcross.py:67: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
      tail_entropy = torch.nn.functional.log_softmax(tail_res)
    /TrellisNet/word_WT103/splitcross.py:164: UserWarning: Implicit dimension choice for log_softmax has been deprecated. Change the call to include dim=X as an argument.
      tail_entropy = torch.gather(torch.nn.functional.log_softmax(tail_res), dim=1, index=indices).squeeze()
    /opt/conda/conda-bld/pytorch_1614378124864/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:115: operator(): block: [2,0,0], thread: [32,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
    /opt/conda/conda-bld/pytorch_1614378124864/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:115: operator(): block: [2,0,0], thread: [34,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.
    /opt/conda/conda-bld/pytorch_1614378124864/work/aten/src/ATen/native/cuda/ScatterGatherKernel.cu:115: operator(): block: [2,0,0], thread: [35,0,0] Assertion idx_dim >= 0 && idx_dim < index_size && "index out of bounds" failed.

    opened by Chen-PengF 1
Owner
CMU Locus Lab
Zico Kolter's Research Group