Code for our ALiBi method for transformer language models.


Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation

This repository contains the code and models for our paper Train Short, Test Long. This file explains how to run our experiments on the WikiText-103 dataset. Read the paper here.

Attention with Linear Biases (ALiBi) is very simple! Instead of adding position embeddings at the bottom of the transformer stack (we don't use any), we add a linear bias to each attention score, as depicted in the figure above. The 'm' hyperparameter is head-specific and is not learned; it is set at the beginning of training. We have a function that automatically generates these m values given the number of heads in the model.
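As a rough illustration of what that function computes, here is a standalone Python sketch that mirrors the get_slopes code quoted in the encoder-attention issue further down this page (power_of_2_slopes is a hypothetical helper name). For n heads, the slopes form a geometric sequence that starts at 2^(-8/n) and uses that same value as its ratio.

```python
import math


def get_slopes(n_heads: int) -> list[float]:
    """Per-head ALiBi slopes (the m values), as a geometric sequence."""

    def power_of_2_slopes(n: int) -> list[float]:
        # First slope is 2^(-8/n); each later slope is multiplied by the same ratio,
        # e.g. 1/2, 1/4, ..., 1/256 for n = 8.
        start = 2 ** (-(2 ** -(math.log2(n) - 3)))
        return [start * start**i for i in range(n)]

    if math.log2(n_heads).is_integer():
        return power_of_2_slopes(n_heads)
    # Not a power of 2: take all slopes for the closest smaller power of 2,
    # then fill the remaining heads with every other slope of the next power of 2.
    closest = 2 ** math.floor(math.log2(n_heads))
    extra = power_of_2_slopes(2 * closest)[0::2][: n_heads - closest]
    return power_of_2_slopes(closest) + extra


print(get_slopes(8))  # [0.5, 0.25, 0.125, 0.0625, 0.03125, 0.015625, 0.0078125, 0.00390625]
```

For a head count that is not a power of 2 (for example 12), the sketch interleaves slopes from the next power of 2, following the workaround described in the comments of the quoted code.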

ALiBi allows the model to be trained on, for example, 1024 tokens, and then do inference on 2048 (or much more) tokens without any finetuning. It's also able to improve performance, even when not extrapolating, in lower resource language modeling settings.

The implementation is very simple.

  1. Remove the position embeddings from the model:
  2. Set up the relative bias matrix, here:
  3. Add the bias matrix to the mask, which is then added in each attention score computation:
  4. (This might not be necessary in other frameworks.) Move the mask computation to before the layer loop, to make the transformer a tiny bit faster:

That's it!
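The four steps can be sketched in plain PyTorch as follows. This is a minimal illustration, not the repository's fairseq code: it assumes a power-of-2 head count, uses hypothetical function names (alibi_slopes, build_alibi_causal_mask), and replaces the real per-layer q/k/v projections with the identity.

```python
import math

import torch


def alibi_slopes(n_heads: int) -> torch.Tensor:
    # Power-of-2 head counts only, for brevity: m_h = 2^(-8h/n) for h = 1..n.
    assert math.log2(n_heads).is_integer(), "sketch only handles powers of 2"
    start = 2 ** (-8 / n_heads)
    return torch.tensor([start ** (h + 1) for h in range(n_heads)])


def build_alibi_causal_mask(n_heads: int, seq_len: int) -> torch.Tensor:
    """Steps 2 and 3: causal mask with the ALiBi bias folded in, shape (heads, L, L)."""
    slopes = alibi_slopes(n_heads)
    positions = torch.arange(seq_len, dtype=torch.float32)
    # m_h * j depends only on the key index j. Within row i it differs from the
    # paper's -m_h * (i - j) by the constant m_h * i, which softmax ignores.
    bias = slopes[:, None, None] * positions[None, None, :]  # (heads, 1, L)
    causal = torch.triu(torch.full((seq_len, seq_len), float("-inf")), diagonal=1)
    return causal + bias  # broadcasts to (heads, L, L)


# Step 4: build the mask once, before the layer loop, and reuse it in every layer.
heads, seq_len, head_dim, n_layers = 8, 6, 16, 2
mask = build_alibi_causal_mask(heads, seq_len)
x = torch.randn(heads, seq_len, head_dim)
for _ in range(n_layers):
    q = k = v = x  # stand-in for the per-layer q/k/v projections
    scores = q @ k.transpose(-1, -2) / math.sqrt(head_dim)  # (heads, L, L)
    x = torch.softmax(scores + mask, dim=-1) @ v
```

Because the bias lives inside the mask, each layer only pays for one addition to its attention scores, which is why hoisting the mask out of the layer loop saves a little time.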


Citation

Ofir Press, Noah A. Smith, and Mike Lewis. Train Short, Test Long: Attention with Linear Biases Enables Input Length Extrapolation.


Requirements and Installation

This repository is a fork of the Fairseq repository and so has the same requirements.

Once you've installed the dependencies, you can install this repository by running:

pip install --editable .

Preparing the data

To download and preprocess the data, run:

cd examples/language_model/
cd ../..

python \
    --only-source \
    --trainpref $TEXT/wiki.train.tokens \
    --validpref $TEXT/wiki.valid.tokens \
    --testpref $TEXT/wiki.test.tokens \
    --destdir data-bin/wikitext-103 \
    --workers 20

Training and Inference

To train a language model with attention with linear biases (ALiBi) on input sequences of 512 tokens, run:

python --task language_modeling data-bin/wikitext-103 \
    --save-dir wt103/ --arch transformer_lm_wiki103 \
    --max-update 286000 --max-lr 1.0 --t-mult 2 --lr-period-updates 270000 --lr-scheduler cosine --lr-shrink 0.75 \
    --warmup-updates 16000 --warmup-init-lr 1e-07 --min-lr 1e-09 --optimizer nag --lr 0.0001 --clip-norm 0.1 \
    --criterion adaptive_loss --seed 1 --fp16 \
    --sample-break-mode none --skip-invalid-size-inputs-valid-test --ddp-backend=no_c10d --no-epoch-checkpoints \
    --tokens-per-sample 512 --max-tokens 9216 --update-freq 1

For input sequences longer than 512 (and up to 2048) tokens, just change the value of --tokens-per-sample.

To train the model with inputs of 3072 tokens, the --update-freq parameter must be changed to 3 and the --max-tokens parameter must be reduced to 3072.
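Assuming fairseq's usual semantics (--max-tokens is a per-GPU batch budget, and --update-freq accumulates gradients over that many batches before each optimizer step), the two settings keep the number of tokens per update the same. A small arithmetic check, with the GPU count left as a hypothetical parameter:

```python
def tokens_per_update(max_tokens: int, update_freq: int, num_gpus: int = 1) -> int:
    # Tokens seen per optimizer step: per-GPU batch budget, times the number of
    # accumulated batches, times the number of GPUs (hypothetical parameter).
    return max_tokens * update_freq * num_gpus


short_ctx = tokens_per_update(max_tokens=9216, update_freq=1)  # --tokens-per-sample 512 command
long_ctx = tokens_per_update(max_tokens=3072, update_freq=3)   # 3072-token setting
print(short_ctx, long_ctx)  # 9216 9216
```

Roughly speaking, a 9216-token batch holds 18 sequences of 512 tokens, while --max-tokens 3072 fits exactly one 3072-token sequence, so the extra accumulation steps make up the difference.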

Saved Checkpoints

If you'd like to download our trained models on WikiText-103, they are available here:

Rename the file you downloaded to if you'd like to follow the directions below.


For nonoverlapping evaluation of the validation set, run:

l=1024; fairseq-eval-lm data-bin/wikitext-103/ \
    --path wt103/ --sample-break-mode none --gen-subset valid --max-sentences 1 \
    --model-overrides "{'max_tokens':$l, 'tokens_per_sample':$l, 'max_target_positions':$l}" \
    --tokens-per-sample $l --max-tokens $l --max-target-positions $l --context-window 0

where l is set to the length of input subsequences during validation (l=1024 in the above example).
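Non-overlapping evaluation can be pictured as cutting the validation token stream into consecutive blocks of l tokens and scoring each block on its own (assuming --context-window 0 adds no extra context), so tokens near the start of a block see little context. A toy Python sketch of that splitting; the helper name is hypothetical and not part of fairseq:

```python
def nonoverlapping_windows(tokens: list[int], l: int) -> list[list[int]]:
    """Split a token stream into consecutive, non-overlapping windows of length l.

    Each window is scored independently: its first token is predicted with no
    context and its last token with l - 1 tokens of context.
    """
    return [tokens[start:start + l] for start in range(0, len(tokens), l)]


windows = nonoverlapping_windows(list(range(10)), l=4)
print(windows)  # [[0, 1, 2, 3], [4, 5, 6, 7], [8, 9]]
```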

  • Explanation regarding multiplying linear biases with q.k^T

    From the paper:


    But the README recommends multiplying the linear biases with the mask:

    What I'm struggling to understand is how that translates to multiplying the linear biases before softmaxing the output of q.k^T.

    Thanks in advance.

    opened by sayakpaul 4
  • How to perform sliding window evaluation?

    Apologies if this was already stated somewhere, but could you explain how to perform the sliding window evaluation described in Appendix B and Table 7 of the paper? The current example in the README seems to support only non-overlapping evaluation.


    opened by chijames 2
  • ALiBi in Parallel Attention

    Hi @ofirpress,

    I am working on implementing ALiBi in a Parallel Attention Transformer. I have removed the positional embeddings from the model. I set up a relative alibi bias matrix and calculated the slopes. I then add the alibi attention bias to the causal mask. Unfortunately, I am unable to get the correct number of trainable parameters. Is it possible to take a quick look and see if there is anything noticeably wrong in the code implementation below?


    Function for slopes:

    from math import floor, log2

    import torch
    import torch.nn.functional as F
    from einops import rearrange, repeat
    from torch import einsum, nn

    def get_alibi_slopes(heads):
        def get_slopes_power_of_2(n):
            start = (2 ** (-2 ** -(log2(n) - 3)))
            ratio = start
            return [start*ratio**i for i in range(n)]
        if log2(heads).is_integer():
            return get_slopes_power_of_2(heads)
        closest_power_of_2 = 2 ** floor(log2(heads))
        return get_slopes_power_of_2(closest_power_of_2) + get_slopes_power_of_2(2 * closest_power_of_2)[0::2][:heads-closest_power_of_2]

    Calculate the alibi bias:

    def calc_alibi_bias(seq_len, heads, device=None):
        # create the bias on the same device as the attention scores
        slopes = torch.tensor(get_alibi_slopes(heads), device=device)
        slopes = rearrange(slopes, 'h -> h 1 1')
        bias = rearrange(torch.arange(seq_len, device=device), 'j -> 1 1 j')
        return slopes * bias

    Build the Parallel Attention Block:

    class RMSNorm(nn.Module):
        # minimal RMSNorm with a learned gain
        def __init__(self, dim, eps=1e-8):
            super().__init__()
            self.eps = eps
            self.gamma = nn.Parameter(torch.ones(dim))
        def forward(self, x):
            rms_inv = x.pow(2).mean(dim=-1, keepdim=True).add(self.eps).rsqrt()
            return x * rms_inv * self.gamma

    class SwiGLU(nn.Module):
        # turns the (ff_inner_dim * 2) projection into ff_inner_dim features: silu(gate) * x
        def forward(self, x):
            x, gate = x.chunk(2, dim=-1)
            return F.silu(gate) * x

    class ParallelAttentionBlock(nn.Module):
        def __init__(self, dim, dim_head=64, heads=8, ff_mult=4):
            super().__init__()
            self.norm = RMSNorm(dim)
            attn_inner_dim = dim_head * heads
            ff_inner_dim = dim * ff_mult
            # queries for every head, one shared key and value (multi-query), SwiGLU input
            self.fused_dims = (attn_inner_dim, dim_head, dim_head, (ff_inner_dim * 2))
            self.heads = heads
            self.scale = dim_head**-0.5
            self.fused_attn_ff_proj = nn.Linear(dim, sum(self.fused_dims), bias=False)
            self.attn_out = nn.Linear(attn_inner_dim, dim, bias=False)
            self.ff_out = nn.Sequential(
                SwiGLU(),
                nn.Linear(ff_inner_dim, dim, bias=False)
            )
            # for caching causal mask
            self.register_buffer("mask", None, persistent=False)
        def get_mask(self, n, device):
            if self.mask is not None and self.mask.shape[-1] >= n:
                return self.mask[:n, :n]
            mask = torch.triu(torch.ones((n, n), device=device, dtype=torch.bool), 1)
            self.register_buffer("mask", mask, persistent=False)
            return mask
        def forward(self, x):
            n, device, h = x.shape[1], x.device, self.heads
            # pre layernorm
            x = self.norm(x)
            # attention queries, keys, values, and feedforward inner
            q, k, v, ff = self.fused_attn_ff_proj(x).split(self.fused_dims, dim=-1)
            # split heads
            q = rearrange(q, "b n (h d) -> b h n d", h = h)
            # scale
            q = q * self.scale
            # similarity
            sim = einsum("b h i d, b j d -> b h i j", q, k)
            # causal mask
            causal_mask = self.get_mask(n, device)
            sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)
            # alibi bias (no trainable parameters), built on the input's device
            alibi_bias = calc_alibi_bias(n, heads = h, device = device)
            attn_bias = repeat(alibi_bias, 'h 1 j -> h i j', i = n)
            attn_bias = attn_bias[..., :n, :n]
            # add the bias matrix to the mask
            sim = sim + attn_bias
            # attention
            attn = sim.softmax(dim=-1)
            # aggregate values
            out = einsum("b h i j, b j d -> b h i d", attn, v)
            # merge heads
            out = rearrange(out, "b h n d -> b n (h d)")
            merge_heads = self.attn_out(out) + self.ff_out(ff)
            return merge_heads

    Any help would be greatly appreciated!

    Thank you,


    opened by conceptofmind 2
  • ALiBi in self-Attention

    What is your question?

    First of all, nice work, and thank you for sharing the code! I noticed that the code uses ALiBi in encoder-decoder attention but not in the transformer's self-attention. Have you tried ALiBi in transformer self-attention? And is there a reason you didn't use it for the self-attention layer?

    opened by Ldoun 2
  • The numerical value of ALiBi attn_mask

    I really like the elegant idea of ALiBi and thanks a lot for open-sourcing the codebase! I have a small question however, which is the actual numerical value of ALiBi attn_mask applied to the causal self attention matrix. In particular, I printed out the attn_mask in fairseq/modules/ ln 170 and found that the attn_mask is not symmetric wrt the diagonal line, which seems to be different from the description of Figure 3 in the paper.

    For example, I got something like: [[0, -inf, -inf, -inf], [0, 0.0039, -inf, -inf], [0, 0.0039, 0.0078, -inf], [0, 0.0039, 0.0078, 0.0117]] for one attention head.

    Any help is greatly appreciated! Thanks!
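    One way to reconcile the printout with Figure 3 is the following sketch, which assumes the printed head uses slope m = 1/256 (matching the 0.0039 step). It compares the printed form (bias m * j, depending only on the key index) with a Figure-3-style bias -m * (i - j), and checks that softmax gives the same attention weights for both.

```python
import torch

m = 1 / 256  # slope that matches the 0.0039 steps in the printout
n = 4
i = torch.arange(n).unsqueeze(1)  # query positions (rows)
j = torch.arange(n).unsqueeze(0)  # key positions (columns)
causal = torch.triu(torch.full((n, n), float("-inf")), diagonal=1)

printed = m * j.float() + causal       # row 1: [0, 0.0039, -inf, -inf]
paper = -m * (i - j).float() + causal  # Figure 3 style: 0 on the diagonal, -m per step back

# Row i of `printed` equals row i of `paper` plus the constant m * i, and softmax
# is unchanged by adding the same constant to every entry of a row.
scores = torch.randn(n, n)
same = torch.allclose(torch.softmax(scores + printed, dim=-1),
                      torch.softmax(scores + paper, dim=-1), atol=1e-6)
print(same)  # True
```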

    opened by chijames 1
  • Integration with `transformers`

    Amazing work! I'm sure it will open up doors for researchers to think about ways to better extrapolate during inference time.

    I'd like to know whether you are aware of any integrations that use ALiBi with transformers from Hugging Face.

    opened by sayakpaul 1
  • ALiBi on LongformerEncoderDecoder

    Super interesting work, and thank you for sharing your code. I am currently working on a text summarisation task which requires long input sequences (longer than 16k, which requires a large GPU), so I am thinking of applying ALiBi to the LongformerEncoderDecoder. Any thoughts on that?

    opened by beaupranisaa 1
  • Modifying ALiBi for Encoder-Attention or Cross-Attention

    In our paper we only showed results on causal language models, which use causally masked (decoder) self-attention.

    If you'd like to use ALiBi for seq2seq tasks such as translation, speech or T5, or if you'd like to use ALiBi for masked language models such as BERT, some modifications are required.


    Encoder-Attention is the non-masked self-attention that is performed in the encoder of seq2seq models such as translation models or T5. This is also the same kind of attention used in MLM models such as BERT.

    We can't naively copy and paste the ALiBi code for these models because it won't work. We use a trick to quickly calculate the bias matrix for causal language modeling, but this bias matrix is only correct for values on or below the main diagonal (since that's all that's used in causal language modeling).

                maxpos = args.tokens_per_sample
                attn_heads = args.encoder_attention_heads  
                context_position = torch.arange(maxpos)[:, None].cuda()
                memory_position = torch.arange(maxpos)[None, :].cuda()
                relative_position = memory_position - context_position 
                relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads, -1,-1)

    This code correctly generates the full bias matrix. Note that the bias matrix is symmetric around the diagonal, since it computes the absolute distance between the query and key (so all distances are positive).
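    To see why the causal trick breaks for encoder-attention, the sketch below (toy sizes and a single hypothetical slope m) compares the key-index-only bias m * j with the intended -m * |i - j|. On and below the diagonal the two differ by a per-row constant that softmax ignores, but above the diagonal they do not. It also confirms that the absolute-distance version is symmetric.

```python
import torch

n, m = 5, 0.5  # toy sequence length and slope
i = torch.arange(n).unsqueeze(1)  # query positions
j = torch.arange(n).unsqueeze(0)  # key positions

causal_trick = (m * j.float()).expand(n, n)  # depends only on the key index j
wanted = -m * (j - i).abs().float()          # -m * |i - j|, symmetric

diff = causal_trick - wanted
# On and below the diagonal, row r differs from the target by the constant m * r ...
lower_ok = all(torch.unique(diff[r, : r + 1]).numel() == 1 for r in range(n))
# ... but above the diagonal the difference, m * (2j - r), changes along the row.
upper_bad = all(torch.unique(diff[r]).numel() > 1 for r in range(n - 1))
print(lower_ok, upper_bad, torch.equal(wanted, wanted.T))  # True True True
```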

    We're also going to need the code for generating the ALiBi slopes:

                def get_slopes(n):
                    def get_slopes_power_of_2(n):
                        start = (2**(-2**-(math.log2(n)-3)))
                        ratio = start
                        return [start*ratio**i for i in range(n)]
                    if math.log2(n).is_integer():
                        return get_slopes_power_of_2(n)                   #In the paper, we only train models that have 2^a heads for some a. This function has
                    else:                                                 #some good properties that only occur when the input is a power of 2. To maintain that even
                        closest_power_of_2 = 2**math.floor(math.log2(n))  #when the number of heads is not a power of 2, we use this workaround. 
                        return get_slopes_power_of_2(closest_power_of_2) + get_slopes(2*closest_power_of_2)[0::2][:n-closest_power_of_2]

    There are 3 options for implementing encoder-attention ALiBi:

    1. Symmetric: In this option, the bias we assign to query/key pairs that are +N or -N tokens apart will be the same.
                    self.slopes = torch.Tensor(get_slopes(attn_heads)).cuda()*-1
                    self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
                    self.alibi = self.alibi.view(1, attn_heads, maxpos, maxpos)

    Now just pass self.alibi to the attention function and add it after the query*key computation.

    In fairseq for example, the query*key computation is done as such: attn_weights = torch.bmm(q, k.transpose(1, 2)), and then to add the ALiBi values use:

    attn_weights = attn_weights.view(bsz, self.num_heads, tgt_len, src_len)
    attn_weights += alibi[:,:,:tgt_len,:src_len].to(attn_weights)
    attn_weights = attn_weights.view(bsz*self.num_heads, tgt_len, src_len)

    2. Nonsymmetric: Here we are going to make the model nonsymmetric by using the same ALiBi bias as in (1), but this time we're going to let the first half of the heads only look left and the second half only look right. We'll do this by adding a mask to our attention. Note: This code hasn't been fully tested yet and might contain bugs.
                    self._future_mask_right = torch.triu(utils.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), 1).unsqueeze(0).repeat(attn_heads//2, 1, 1)
                    self._future_mask_left = torch.tril(utils.fill_with_neg_inf(torch.zeros([maxpos, maxpos])), -1).unsqueeze(0).repeat(attn_heads//2, 1, 1)
                    self.nonsym_mask = torch.cat((self._future_mask_right, self._future_mask_left), dim = 0).unsqueeze(0).cuda()
                    self.slopes = torch.Tensor(get_slopes(attn_heads//2)).cuda()*-1
                    context_position = torch.arange(maxpos)[:, None].cuda()
                    memory_position = torch.arange(maxpos)[None, :].cuda()
                    relative_position = memory_position - context_position
                    relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads//2, -1,-1)
                    self.alibi = self.slopes.unsqueeze(1).unsqueeze(1) * relative_position
                    self.alibi = self.alibi.view(1, attn_heads//2, maxpos, maxpos)
                    self.alibi = self.alibi.repeat(1, 2, 1, 1).cuda()

    Again, as before, add self.alibi to the attn-weights, but this time also add the nonsym_mask tensor. (In fairseq attn_weights += nonsym_mask[:,:,:tgt_len,:src_len].to(attn_weights))

    3. Nonsymmetric with no mask: In this approach, we don't use any masking; instead, we make the positioning nonsymmetric by using different ALiBi slopes depending on whether the key is to the left or right of the query. Here, we use learned slopes, but you can also do this with non-learned slopes. Note: I haven't tested this code so it might contain bugs!
    slopes_left = torch.nn.parameter.Parameter(torch.Tensor( attn_heads))
    nn.init.normal_(slopes_left, -2,1)
    slopes_right = torch.nn.parameter.Parameter(torch.Tensor( attn_heads))
    nn.init.normal_(slopes_right, -2,1)
    slopes_left = -torch.sigmoid(slopes_left)
    slopes_right = -torch.sigmoid(slopes_right)
    context_position = torch.arange(maxpos)[:, None]
    memory_position = torch.arange(maxpos)[None, :]
    relative_position = memory_position - context_position
    relative_position = torch.abs(relative_position).unsqueeze(0).expand(attn_heads, -1,-1)
    alibi_left = slopes_left.unsqueeze(1).unsqueeze(1) * relative_position
    alibi_right = slopes_right.unsqueeze(1).unsqueeze(1) * relative_position
    self.alibi = torch.triu(alibi_right) + torch.tril(alibi_left)

    4. Check out the variation on option 3 from the LittleBird paper.
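    For experimenting outside fairseq, options 1 and 2 can be condensed into a CPU-only PyTorch sketch. It is not the code above verbatim: it assumes power-of-2 head counts, uses torch.full in place of fairseq's utils.fill_with_neg_inf, uses hypothetical function names, keeps heads as a separate dimension rather than fairseq's (bsz * heads) layout, and applies the 1/sqrt(head_dim) scaling inside the attention function.

```python
import math

import torch


def power_of_2_slopes(n_heads: int) -> torch.Tensor:
    # Same geometric sequence as get_slopes, for a power-of-2 head count.
    assert math.log2(n_heads).is_integer(), "sketch only handles powers of 2"
    start = 2 ** (-8 / n_heads)
    return torch.tensor([start ** (h + 1) for h in range(n_heads)])


def abs_distance(maxpos: int) -> torch.Tensor:
    pos = torch.arange(maxpos)
    return (pos[None, :] - pos[:, None]).abs().float()  # |j - i|, symmetric


def symmetric_alibi(heads: int, maxpos: int) -> torch.Tensor:
    """Option 1: bias -m_h * |i - j|, shape (1, heads, maxpos, maxpos)."""
    return (-power_of_2_slopes(heads)[:, None, None] * abs_distance(maxpos)).unsqueeze(0)


def nonsymmetric_alibi(heads: int, maxpos: int) -> torch.Tensor:
    """Option 2: first half of the heads look left only, second half right only."""
    neg_inf = torch.full((maxpos, maxpos), float("-inf"))
    look_left = torch.triu(neg_inf, 1).expand(heads // 2, -1, -1)    # hides keys j > i
    look_right = torch.tril(neg_inf, -1).expand(heads // 2, -1, -1)  # hides keys j < i
    mask = torch.cat((look_left, look_right), dim=0)
    half = -power_of_2_slopes(heads // 2)[:, None, None] * abs_distance(maxpos)
    return (half.repeat(2, 1, 1) + mask).unsqueeze(0)


def encoder_attention(q, k, v, bias):
    """q, k, v: (bsz, heads, len, head_dim); bias: (1, heads, maxpos, maxpos)."""
    tgt_len, src_len = q.shape[2], k.shape[2]
    scores = q @ k.transpose(-1, -2) / math.sqrt(q.shape[-1])
    scores = scores + bias[:, :, :tgt_len, :src_len].to(scores)
    return torch.softmax(scores, dim=-1) @ v


bsz, heads, seq_len, head_dim = 2, 8, 6, 16
q, k, v = (torch.randn(bsz, heads, seq_len, head_dim) for _ in range(3))
out_sym = encoder_attention(q, k, v, symmetric_alibi(heads, maxpos=32))
out_nonsym = encoder_attention(q, k, v, nonsymmetric_alibi(heads, maxpos=32))
```

    The bias is added to the query*key scores before the softmax, at the same point where the fairseq snippet above adds self.alibi; in option 2 the -inf entries zero out the hidden side for each half of the heads.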


    For translation models and models like T5 you will also need to implement cross-attention, which is the attention from the decoder to the encoder. The T5 model uses no positional information in cross-attention and I would recommend doing the same thing.


    NEW: lucidrains has implemented some of the above ideas in the x-transformers repo.

    opened by ofirpress 24
