Code for the Shortformer model, from the paper by Ofir Press, Noah A. Smith and Mike Lewis.



This repository contains the code and the final checkpoint of the Shortformer model. This file explains how to run our experiments on the WikiText-103 dataset. Read the full paper here.

The Shortformer is a combination of two methods:

  1. Staged Training: We first train the model on short input subsequences and then train it on longer ones. This improves both training speed and evaluation perplexity.
  2. Position-Infused Attention + Caching: We cache previously computed subsequence representations and attend to them using Position-Infused Attention. Position-Infused Attention modifies the model so that position embeddings are not added to the word embeddings at the bottom of the network, but instead, they are added to the keys and queries in the attention sublayer (but not to the values). We show that PIA + caching vastly speeds up generation and also improves perplexity.
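The PIA + caching idea above can be sketched in a few lines of plain Python. This is a minimal single-layer, single-head illustration under our own assumptions, not the code in this repository: the function names (`pia_attention`, `run_stream`), the sinusoidal embeddings, the identity query/key/value projections, and the choice to give cached tokens positions 0..C-1 and current tokens C..C+L-1 in every window are all hypothetical. What it does show is the key property: position embeddings enter only where queries and keys are formed, so the cached representations carry no absolute position and can be reused unchanged when the window moves on.

```python
import math


def position_embedding(pos, dim):
    """Sinusoidal embedding for a single position (an illustrative choice)."""
    emb = []
    for i in range(dim):
        angle = pos / (10000 ** (2 * (i // 2) / dim))
        emb.append(math.sin(angle) if i % 2 == 0 else math.cos(angle))
    return emb


def add(u, v):
    return [a + b for a, b in zip(u, v)]


def dot(u, v):
    return sum(a * b for a, b in zip(u, v))


def pia_attention(current, cache):
    """Single-head causal attention with Position-Infused Attention (sketch).

    `current` and `cache` hold position-free token representations.
    Positions are assigned per window (cache: 0..C-1, current: C..C+L-1)
    and added only where queries and keys are formed; values stay
    position-free. Projections are omitted (identity) for brevity.
    """
    memory = cache + current
    dim = len(current[0])
    offset = len(cache)
    outputs = []
    for i, x in enumerate(current):
        q_pos = offset + i
        query = add(x, position_embedding(q_pos, dim))
        visible = memory[: q_pos + 1]  # causal mask: positions 0..q_pos
        scores = [
            dot(query, add(y, position_embedding(j, dim))) / math.sqrt(dim)
            for j, y in enumerate(visible)
        ]
        top = max(scores)
        weights = [math.exp(s - top) for s in scores]  # numerically stable softmax
        total = sum(weights)
        # Values are the raw representations: no position embedding here.
        outputs.append(
            [sum(w * y[d] for w, y in zip(weights, visible)) / total for d in range(dim)]
        )
    return outputs


def run_stream(embeddings, subseq_len):
    """Process a stream in non-overlapping subsequences, reusing the previous
    subsequence's representations as a cache instead of recomputing them."""
    cache, outputs = [], []
    for start in range(0, len(embeddings), subseq_len):
        chunk = embeddings[start:start + subseq_len]
        outputs.extend(pia_attention(chunk, cache))
        cache = chunk  # this layer's input, reused as-is in the next window
    return outputs


if __name__ == "__main__":
    stream = [[math.sin(t), math.cos(t)] for t in range(6)]
    print(len(run_stream(stream, subseq_len=3)))  # 6
```

In a multi-layer model, each layer would cache its own input for the previous subsequence; the sketch has a single layer, so its cache is simply the previous chunk of embeddings.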

Staged training requires no modification to the original code. To see how we implemented the Position-Infused Attention and caching, click here. Implementing PIA and caching is very easy, and we've provided detailed comments in the code to explain how we did it.

If you use this code or results from our paper, please cite:

      title={Shortformer: Better Language Modeling using Shorter Inputs}, 
      author={Ofir Press and Noah A. Smith and Mike Lewis},

Requirements and Installation

This repository is a fork of the Fairseq repository and so has the same requirements.

Once you've installed the dependencies, you can install this repository by running:

pip install --editable .

Preparing the data

To download and preprocess the data, run:

cd examples/language_model/
cd ../..

python \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Train/Inference for the different models


Our Shortformer model takes the baseline and adds caching, Position-Infused Attention, and Staged Training.

To train the first stage:

python --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints128e100/ \
    --arch transformer_lm_wiki103 \
    --max-update 140100 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 128 --max-tokens-valid 128 --tokens-from-prev 128 --curriculum 1000 \
    --required-batch-size-multiple 1 --save-interval 100

If your GPUs don't have enough memory to execute that command, you can set --update-freq to 2 and --max-tokens to 4608, or, if memory is even tighter, set --update-freq to 3 and --max-tokens to 3072. This splits each batch into 2 or 3 parts and processes them one after another (instead of in parallel), so it uses less memory but runs more slowly.
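The arithmetic behind this trade-off is simple: each pass holds at most --max-tokens tokens per GPU, and --update-freq passes are accumulated into one update, so the three settings keep the same number of tokens per update. The sketch below checks this; the helper names are hypothetical, `num_gpus` is an assumed parameter (the README does not state a GPU count), and `sequences_per_pass` assumes every sample is exactly `--tokens-per-sample` (here 128, from the stage-1 command) tokens long.

```python
def tokens_per_update(max_tokens, update_freq, num_gpus=1):
    """Upper bound on tokens contributing to one optimizer step.

    Each forward/backward pass holds at most `max_tokens` tokens per GPU,
    and gradients from `update_freq` passes are accumulated before the update.
    """
    return max_tokens * update_freq * num_gpus


def sequences_per_pass(max_tokens, tokens_per_sample):
    """Full training sequences that fit in one pass on one GPU."""
    return max_tokens // tokens_per_sample


settings = [(9216, 1), (4608, 2), (3072, 3)]
for max_tokens, update_freq in settings:
    print(max_tokens, update_freq,
          tokens_per_update(max_tokens, update_freq),
          sequences_per_pass(max_tokens, 128))
# 9216 1 9216 72
# 4608 2 9216 36
# 3072 3 9216 24
```

Peak memory scales with the tokens held in a single pass, which is why the smaller --max-tokens settings fit on smaller GPUs while leaving the effective batch unchanged.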

After that, to train the model with the second (and final) stage:

python --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir shortformer/ --restore-file checkpoints128e100/ \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 \
    --required-batch-size-multiple 1 --no-epoch-checkpoints
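The two stages amount to a schedule that maps the update number to an input length. The sketch below encodes the --tokens-per-sample and --max-update values from the two Shortformer commands, and from the Baevski & Auli staged-training commands later in this README, for comparison. It assumes update numbers are counted globally across both stages, as the stage-2 --max-update values suggest; in practice stage 2 resumes from whichever checkpoint stage 1 saved, so the exact switch point may differ slightly. `Stage` and `subsequence_length` are hypothetical names.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class Stage:
    tokens_per_sample: int  # --tokens-per-sample used in this stage
    last_update: int        # --max-update of this stage's command


# Values copied from the training commands in this README.
SHORTFORMER = [Stage(128, 140_100), Stage(512, 286_000)]
BASELINE_ST = [Stage(128, 70_050), Stage(3072, 286_000)]


def subsequence_length(update, stages):
    """Input length used at a given (global) update number."""
    for stage in stages:
        if update <= stage.last_update:
            return stage.tokens_per_sample
    raise ValueError(f"update {update} is past the final stage")


for update in (1, 140_100, 140_101, 286_000):
    print(update, subsequence_length(update, SHORTFORMER))
# 1 128
# 140100 128
# 140101 512
# 286000 512
```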

Again, you can use the update-freq/max-tokens method from above if you run out of memory.

Saved Checkpoint

If you'd like to download the Shortformer instead of training it, it is available here. Rename that file to if you'd like to follow the directions below.


For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103     --path shortformer/  --sample-break-mode none --gen-subset valid   --max-sentences 1
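"Nonoverlapping" means the validation stream is cut into consecutive, disjoint blocks, so every token is scored exactly once, and perplexity is the exponential of the mean per-token negative log-likelihood. The sketch below illustrates only that bookkeeping: the 512-token block size is assumed from the stage-2 --tokens-per-sample, the 50-word uniform "model" is purely illustrative, and it does not model whether the Shortformer's evaluation also attends to a cached previous block.

```python
import math


def nonoverlapping_blocks(num_tokens, block_size):
    """Split a token stream into consecutive, disjoint [start, end) blocks."""
    return [(start, min(start + block_size, num_tokens))
            for start in range(0, num_tokens, block_size)]


def perplexity(token_log_probs):
    """exp of the mean negative log-likelihood (natural log) per scored token."""
    return math.exp(-sum(token_log_probs) / len(token_log_probs))


print(nonoverlapping_blocks(1030, 512))  # [(0, 512), (512, 1024), (1024, 1030)]
# A model that is uniform over a hypothetical 50-word vocabulary has perplexity 50.
print(round(perplexity([math.log(1 / 50)] * 1030), 6))
```

Because tokens near the start of each block see little context, nonoverlapping evaluation is fast but typically gives a higher perplexity than evaluating with overlapping windows.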

For token-by-token generation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103     --path shortformer/  --sample-break-mode none --gen-subset valid   --max-sentences 1 --sliding-inf 1 --context-window 511 --max-tokens 512

(Note that --context-window is a fairseq command and doesn't have the exact meaning that the term "context window" has in our paper.)
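One way to read the token-by-token command is that each forward pass covers the model's training length (512 tokens, assumed here to come from the checkpoint) and scores only the tokens not reserved as context, so 512 - 511 = 1 new token per pass. The sketch below enumerates the windows under that assumption. It reflects our reading of how fairseq's --context-window interacts with the training length, ignores this fork's --sliding-inf flag and the cache, and `scored_windows` is a hypothetical helper, not a fairseq function.

```python
def scored_windows(num_tokens, tokens_per_sample, context_window):
    """Enumerate (context_start, score_start, score_end) for each forward pass.

    Each pass scores `tokens_per_sample - context_window` new tokens and sees
    up to `context_window` preceding tokens as unscored context.
    """
    stride = tokens_per_sample - context_window
    if stride <= 0:
        raise ValueError("context_window must be smaller than tokens_per_sample")
    windows = []
    for start in range(0, num_tokens, stride):
        end = min(start + stride, num_tokens)
        windows.append((max(0, start - context_window), start, end))
    return windows


# Token-by-token setting: 512-token training length, --context-window 511.
w = scored_windows(1000, tokens_per_sample=512, context_window=511)
print(len(w), w[-1])  # 1000 (488, 999, 1000)
```

Each window spans at most 512 tokens, which is consistent with the --max-tokens 512 setting in the command above.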

Shortformer (without Staged Training)

Staged training improves the model's perplexity and makes training faster, so there's no reason not to use it. If you would nonetheless like to train the Shortformer without it, run:

python --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir shortformer-no-st/ \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 512 --max-tokens-valid 512 --tokens-from-prev 512 --curriculum 1000 \
    --required-batch-size-multiple 1 --no-epoch-checkpoints

For inference, use the same commands as the ones for the Shortformer (above).

Baseline with Staged Training

Our Shortformer model is fast both to train and for token-by-token generation, but if speed is not a concern, slightly better perplexity can be achieved by simply applying Staged Training to the Baevski & Auli baseline LM. This model is very slow but achieves the best perplexity.

To train the first stage, download the unmodified fairseq repository and then run:

python --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir checkpoints-st-128e50/ \
    --arch transformer_lm_wiki103 \
    --max-update 70050 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 9216 --update-freq 1 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 128 --required-batch-size-multiple 1 --save-interval 50

After that, to train the model with the second (and final) stage:

python --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir st/ --restore-file checkpoints-st-128e50/ \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 3072 --no-epoch-checkpoints


For nonoverlapping evaluation of the validation set, run:

fairseq-eval-lm data-bin/wikitext-103     --path st/  --sample-break-mode none --gen-subset valid   --max-sentences 1

For sliding-window evaluation of the validation set, in which each forward pass sees 2,560 tokens of context and scores the next 512 tokens (a stride of 512), run:

fairseq-eval-lm data-bin/wikitext-103     --path st/  --sample-break-mode none --gen-subset valid   --max-sentences 1 --context-window 2560

Baseline - Baevski & Auli

To train the baseline, download the unmodified fairseq repository and then run:

python --task language_modeling \
    data-bin/wikitext-103 \
    --save-dir baseline/ \
    --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --max-tokens 3072 --update-freq 3 --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d \
    --tokens-per-sample 3072 --no-epoch-checkpoints


Use the same commands as in the 'Baseline with Staged Training' inference subsection.
