Fully featured implementation of Routing Transformer

Routing Transformer

A fully featured implementation of Routing Transformer. The paper proposes using k-means to route similar queries / keys into the same cluster for attention.
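
As a rough illustration of the routing idea (not this library's implementation), the sketch below assigns each query/key vector to its nearest centroid and groups token positions by cluster. Attention would then be computed only inside each group instead of across the whole sequence. The helper names `nearest_centroid` and `route_to_clusters` and the toy vectors are hypothetical, and the sketch uses plain squared Euclidean distance for simplicity.

```python
from collections import defaultdict


def nearest_centroid(vec, centroids):
    """Index of the centroid with the smallest squared Euclidean distance to vec."""
    best_idx, best_dist = 0, float("inf")
    for idx, c in enumerate(centroids):
        dist = sum((a - b) ** 2 for a, b in zip(vec, c))
        if dist < best_dist:
            best_idx, best_dist = idx, dist
    return best_idx


def route_to_clusters(vectors, centroids):
    """Group token positions by their nearest centroid (hypothetical helper)."""
    clusters = defaultdict(list)
    for pos, vec in enumerate(vectors):
        clusters[nearest_centroid(vec, centroids)].append(pos)
    return dict(clusters)


# Toy example: four 2-d query/key vectors and two centroids.
vectors = [[0.9, 0.1], [0.1, 1.0], [1.0, 0.0], [0.0, 0.8]]
centroids = [[1.0, 0.0], [0.0, 1.0]]
print(route_to_clusters(vectors, centroids))  # {0: [0, 2], 1: [1, 3]}
```

Positions 0 and 2 end up in one cluster and positions 1 and 3 in the other, so each would only attend within its own group.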

Install

$ pip install routing_transformer

Usage

A simple language model

import torch
from routing_transformer import RoutingTransformerLM

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 512,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    causal = True,           # auto-regressive or not
    emb_dim = 128,           # embedding factorization, from Albert
    weight_tie = False,      # weight tie layers, from Albert
    tie_embedding = False,   # multiply final embeddings with token weights for logits
    dim_head = 64,           # be able to fix the dimension of each head, making it independent of the embedding dimension and the number of heads
    attn_dropout = 0.1,      # dropout after attention
    attn_layer_dropout = 0., # dropout after self attention layer
    ff_dropout = 0.1,        # feedforward dropout
    layer_dropout = 0.,      # layer dropout
    window_size = 128,       # target window size of each cluster
    n_local_attn_heads = 4,  # number of local attention heads
    reversible = True,       # reversible networks for memory savings, from Reformer paper
    ff_chunks = 10,          # feed forward chunking, from Reformer paper
    ff_glu = True,           # use GLU variant in feedforward
    pkm_layers = (4, 7),     # specify layers to use product key memory. paper shows 1 or 2 modules near the middle of the transformer work best
    pkm_num_keys = 128,      # defaults to 128, but can be increased to 256 or 512 as memory allows
    moe_layers = (3, 6),     # specify which layers to use mixture of experts
    moe_num_experts = 4,     # number of experts in the mixture of experts layer, defaults to 4. increase for adding more parameters to model
    moe_loss_coef = 1e-2,    # the weight for the auxiliary loss in mixture of experts to keep expert usage balanced
    num_mem_kv = 8,          # number of memory key/values to append to each cluster of each head, from the 'All-Attention' paper. defaults to 1 in the causal case for unshared QK to work
    use_scale_norm = False,  # use scale norm, simplified normalization from 'Transformers without Tears' paper
    use_rezero = False,      # use Rezero with no normalization
    shift_tokens = True      # shift tokens by one along sequence dimension, for a slight improvement in convergence
).cuda()

x = torch.randint(0, 20000, (1, 8192)).long().cuda()
input_mask = torch.ones_like(x).bool().cuda()

y, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 20000)
aux_loss.backward() # add auxiliary loss to main loss before backprop
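
The `emb_dim` option factorizes the token embedding in the ALBERT style: tokens are embedded at `emb_dim` and then projected up to `dim`. Assuming a plain embedding table plus one projection matrix (bias terms ignored, and `embedding_params` is a hypothetical helper), the parameter savings for the settings above can be checked with a bit of arithmetic:

```python
def embedding_params(num_tokens, dim, emb_dim=None):
    """Parameter count of a token embedding, optionally factorized (biases ignored)."""
    if emb_dim is None or emb_dim == dim:
        return num_tokens * dim                  # one num_tokens x dim table
    return num_tokens * emb_dim + emb_dim * dim  # small table + projection up to dim


full = embedding_params(20000, 512)
factorized = embedding_params(20000, 512, emb_dim=128)
print(full, factorized)  # 10240000 2625536
```

With `num_tokens = 20000`, `dim = 512` and `emb_dim = 128`, the factorized embedding needs roughly a quarter of the parameters of a full table.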

A simple transformer

import torch
from routing_transformer import RoutingTransformer

model = RoutingTransformer(
    dim = 512,
    heads = 8,
    depth = 12,
    max_seq_len = 8192,
    window_size = 128,
    n_local_attn_heads = 4
).cuda()

x = torch.randn(1, 8192, 512).cuda()
input_mask = torch.ones(1, 8192).bool().cuda()

y, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 512)
aux_loss.backward() # add auxiliary loss to main loss before backprop

Encoder Decoder

To use a full encoder-decoder, simply import the RoutingTransformerEncDec class. Apart from the dim keyword, every keyword is prefixed with either enc_ or dec_ and is passed to the encoder or decoder RoutingTransformerLM, respectively.
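
The prefix convention can be pictured as a simple keyword split. The helper below, `split_enc_dec_kwargs`, is hypothetical and only illustrates how `enc_`/`dec_` keywords would be routed to the two sub-models while the shared `dim` goes to both. The library's own plumbing may differ.

```python
def split_enc_dec_kwargs(kwargs, shared=("dim",)):
    """Route enc_*/dec_* keywords to the encoder/decoder, copying shared keys to both."""
    enc, dec = {}, {}
    for key, value in kwargs.items():
        if key in shared:
            enc[key] = dec[key] = value
        elif key.startswith("enc_"):
            enc[key[len("enc_"):]] = value
        elif key.startswith("dec_"):
            dec[key[len("dec_"):]] = value
        else:
            raise TypeError(f"unexpected keyword: {key}")
    return enc, dec


enc_kwargs, dec_kwargs = split_enc_dec_kwargs(
    {"dim": 512, "enc_depth": 4, "enc_window_size": 128, "dec_depth": 4, "dec_reversible": True}
)
print(enc_kwargs)  # {'dim': 512, 'depth': 4, 'window_size': 128}
print(dec_kwargs)  # {'dim': 512, 'depth': 4, 'reversible': True}
```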

import torch
from routing_transformer import RoutingTransformerEncDec

model = RoutingTransformerEncDec(
    dim=512,
    enc_num_tokens = 20000,
    enc_depth = 4,
    enc_heads = 8,
    enc_max_seq_len = 4096,
    enc_window_size = 128,
    dec_num_tokens = 20000,
    dec_depth = 4,
    dec_heads = 8,
    dec_max_seq_len = 4096,
    dec_window_size = 128,
    dec_reversible = True
).cuda()

src = torch.randint(0, 20000, (1, 4096)).cuda()
tgt = torch.randint(0, 20000, (1, 4096)).cuda()
src_mask = torch.ones_like(src).bool().cuda()
tgt_mask = torch.ones_like(tgt).bool().cuda()

loss, aux_loss = model(src, tgt, enc_input_mask = src_mask, dec_input_mask = tgt_mask, return_loss = True, randomly_truncate_sequence = True)
(loss + aux_loss).backward() # sum both losses so backprop runs once through the shared graph

# after training, sample up to 2048 tokens conditioned on the source sequence
src = torch.randint(0, 20000, (1, 4096)).cuda()
start_tokens = torch.ones(1, 1).long().cuda() # assume the start token is 1

sample = model.generate(src, start_tokens, seq_len = 2048, eos_token = 2) # sampled token ids, (1, <= 2048)

Product Key Memory

To see the benefits of using PKM, the learning rate of the values must be set higher than that of the rest of the parameters (1e-2 is recommended).

You can follow the instructions here to set it correctly: https://github.com/lucidrains/product-key-memory#learning-rates
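
One way to give the PKM values their own learning rate is to pass two parameter groups to the optimizer. PyTorch optimizers accept a list of dicts, each with `params` and an optional per-group `lr`. The sketch below builds such groups from `(name, parameter)` pairs. The helper `build_param_groups`, the naming rule in `looks_like_pkm_value` and the base learning rate of 3e-4 are all assumptions for illustration. The linked instructions describe the supported way.

```python
def build_param_groups(named_params, is_pkm_value, base_lr=3e-4, pkm_value_lr=1e-2):
    """Split (name, param) pairs into optimizer parameter groups.

    The result has the shape torch.optim optimizers accept:
    [{'params': [...], 'lr': ...}, ...]
    """
    pkm_values, rest = [], []
    for name, param in named_params:
        (pkm_values if is_pkm_value(name) else rest).append(param)
    groups = [{"params": rest, "lr": base_lr}]
    if pkm_values:
        groups.append({"params": pkm_values, "lr": pkm_value_lr})
    return groups


def looks_like_pkm_value(name):
    """Hypothetical naming rule for PKM value tables."""
    return name.endswith("values.weight")


# Strings stand in for tensors; with PyTorch you would pass model.named_parameters().
named = [("layers.0.attn.to_q.weight", "p0"), ("layers.4.pkm.values.weight", "p1")]
groups = build_param_groups(named, looks_like_pkm_value)
print(groups)  # [{'params': ['p0'], 'lr': 0.0003}, {'params': ['p1'], 'lr': 0.01}]
```

With PyTorch, this would become something like `torch.optim.Adam(build_param_groups(model.named_parameters(), looks_like_pkm_value))`.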

Kmeans Hyperparameters

  1. kmeans_ema_decay = {defaults to 0.999}

This is the exponential moving average decay used to update the k-means centroids. The lower it is, the faster the means adjust, but at the cost of stability.

  2. commitment_factor = {defaults to 1e-4}

The weight of the auxiliary loss that encourages tokens to move closer to (commit to) the k-means centroids that were chosen for them.
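
To make the two hyperparameters concrete, here is a small sketch assuming the usual formulations. Each centroid moves toward the mean of the vectors routed to it with an exponential moving average. The commitment loss is the mean squared distance between each vector and its assigned centroid, scaled by `commitment_factor`. The function names are hypothetical, this is not the library's code, and the library may do extra steps (for example normalizing the centroids).

```python
def ema_update(centroid, assigned, decay=0.999):
    """Move a centroid toward the mean of the vectors assigned to it."""
    if not assigned:
        return list(centroid)  # nothing routed here: leave the centroid unchanged
    dim = len(centroid)
    batch_mean = [sum(v[i] for v in assigned) / len(assigned) for i in range(dim)]
    return [decay * c + (1 - decay) * m for c, m in zip(centroid, batch_mean)]


def commitment_loss(vectors, assigned_centroids, commitment_factor=1e-4):
    """Mean squared distance between vectors and their chosen centroids, scaled."""
    total, count = 0.0, 0
    for vec, cen in zip(vectors, assigned_centroids):
        for a, b in zip(vec, cen):
            total += (a - b) ** 2
            count += 1
    return commitment_factor * total / count


# A lower decay lets the centroid adapt faster (here: approximately [0.2, 0.1]).
print(ema_update([0.0, 0.0], [[1.0, 1.0], [3.0, 1.0]], decay=0.9))
print(commitment_loss([[1.0, 0.0]], [[0.0, 0.0]], commitment_factor=1.0))  # 0.5
```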

Updating kmeans manually

By default, the k-means centroids are updated automatically on every backward pass. The following instructions show how to disable this and update them manually.

import torch
from routing_transformer import RoutingTransformerLM, AutoregressiveWrapper

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 6,
    window_size = 256,
    max_seq_len = 8192,
    causal = True,
    _register_kmeans_update = False # set to False to disable auto-updating
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True)
loss.backward()

# update kmeans with this call
model.update_kmeans()

Issues

This architecture has trouble generalizing to shorter sequence lengths when decoding tokens from 1 up to the maximum sequence length. The simplest and surest solution is to randomly truncate the sequence during training. This helps the network and the k-means generalize to a variable number of tokens, at the cost of longer training.

If you prime the network with the full sequence length at the start, you will not face this problem and can skip this training procedure.
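
The idea behind random truncation can be sketched without the library: for each training example, keep a prefix of random length so the network and the k-means see many sequence lengths. `randomly_truncate` is a hypothetical helper for illustration. The library's `randomly_truncate_sequence` option may choose lengths differently.

```python
import random


def randomly_truncate(tokens, min_len=1, rng=random):
    """Return a prefix of `tokens` whose length is drawn uniformly from [min_len, len(tokens)]."""
    if not tokens:
        return []
    length = rng.randint(min(min_len, len(tokens)), len(tokens))  # inclusive bounds
    return tokens[:length]


rng = random.Random(0)
seq = list(range(8192))
lengths = [len(randomly_truncate(seq, rng=rng)) for _ in range(5)]
print(lengths)  # five lengths between 1 and 8192
```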

import torch
from routing_transformer import RoutingTransformerLM, AutoregressiveWrapper

model = RoutingTransformerLM(
    num_tokens = 20000,
    dim = 1024,
    heads = 8,
    depth = 12,
    window_size = 256,
    max_seq_len = 8192,
    causal = True
)

model = AutoregressiveWrapper(model)

x = torch.randint(0, 20000, (1, 8192))
loss = model(x, return_loss = True, randomly_truncate_sequence = True) # scalar loss
loss.backward()

Appreciation

Special thanks to Aran Komatsuzaki for bootstrapping the initial implementation in PyTorch that evolved into this library.

Citation

@misc{roy*2020efficient,
    title   = {Efficient Content-Based Sparse Attention with Routing Transformers},
    author  = {Aurko Roy* and Mohammad Taghi Saffar* and David Grangier and Ashish Vaswani},
    year    = {2020},
    url     = {https://arxiv.org/pdf/2003.05997.pdf}
}
@misc{shazeer2020glu,
    title   = {GLU Variants Improve Transformer},
    author  = {Noam Shazeer},
    year    = {2020},
    url     = {https://arxiv.org/abs/2002.05202}    
}
@inproceedings{kitaev2020reformer,
    title       = {Reformer: The Efficient Transformer},
    author      = {Nikita Kitaev and Lukasz Kaiser and Anselm Levskaya},
    booktitle   = {International Conference on Learning Representations},
    year        = {2020},
    url         = {https://openreview.net/forum?id=rkgNKkHtvB}
}
@inproceedings{fan2020reducing,
    title     = {Reducing Transformer Depth on Demand with Structured Dropout},
    author    = {Angela Fan and Edouard Grave and Armand Joulin},
    booktitle = {International Conference on Learning Representations},
    year      = {2020},
    url       = {https://openreview.net/forum?id=SylO2yStDr}
}
@misc{lan2019albert,
    title       = {ALBERT: A Lite BERT for Self-supervised Learning of Language Representations},
    author      = {Zhenzhong Lan and Mingda Chen and Sebastian Goodman and Kevin Gimpel and Piyush Sharma and Radu Soricut},
    year        = {2019},
    url         = {https://arxiv.org/abs/1909.11942}
}
@misc{lample2019large,
    title   = {Large Memory Layers with Product Keys},
    author  = {Guillaume Lample and Alexandre Sablayrolles and Marc'Aurelio Ranzato and Ludovic Denoyer and Hervé Jégou},
    year    = {2019},
    eprint  = {1907.05242},
    archivePrefix = {arXiv}
}
@article{DBLP:journals/corr/abs-1907-01470,
    author    = {Sainbayar Sukhbaatar and
               Edouard Grave and
               Guillaume Lample and
               Herv{\'{e}} J{\'{e}}gou and
               Armand Joulin},
    title     = {Augmenting Self-attention with Persistent Memory},
    journal   = {CoRR},
    volume    = {abs/1907.01470},
    year      = {2019},
    url       = {http://arxiv.org/abs/1907.01470}
}
@misc{bhojanapalli2020lowrank,
    title   = {Low-Rank Bottleneck in Multi-head Attention Models},
    author  = {Srinadh Bhojanapalli and Chulhee Yun and Ankit Singh Rawat and Sashank J. Reddi and Sanjiv Kumar},
    year    = {2020},
    eprint  = {2002.07028}
}
@article{1910.05895,
    author  = {Toan Q. Nguyen and Julian Salazar},
    title   = {Transformers without Tears: Improving the Normalization of Self-Attention},
    year    = {2019},
    eprint  = {arXiv:1910.05895},
    doi     = {10.5281/zenodo.3525484},
}
@misc{bachlechner2020rezero,
    title   = {ReZero is All You Need: Fast Convergence at Large Depth},
    author  = {Thomas Bachlechner and Bodhisattwa Prasad Majumder and Huanru Henry Mao and Garrison W. Cottrell and Julian McAuley},
    year    = {2020},
    url     = {https://arxiv.org/abs/2003.04887}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@software{peng_bo_2021_5196578,
    author       = {PENG Bo},
    title        = {BlinkDL/RWKV-LM: 0.01},
    month        = {aug},
    year         = {2021},
    publisher    = {Zenodo},
    version      = {0.01},
    doi          = {10.5281/zenodo.5196578},
    url          = {https://doi.org/10.5281/zenodo.5196578}
}
Comments
  • Encoder-decoder fails at KMeans attention

    I haven't been able to dig into the root cause here yet, but I'm getting the following error when trying to run an encoder-decoder:

     File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
        return func(*args, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/encoder_decoder.py", line 77, in generate
        return self.dec.generate(seq_out_start, max_seq_len, context = context, **{**dec_kwargs, **kwargs})
      File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
        return func(*args, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autoregressive_wrapper.py", line 71, in generate
        logits, _ = self.net(x, input_mask=input_mask, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/autopadder.py", line 33, in forward
        return self.net(x, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 614, in forward
        x, loss = self.routing_transformer(x, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 592, in forward
        x, loss = self.layers(x, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 200, in forward
        out, f_loss, g_loss =  _ReversibleFunction.apply(x, blocks, args)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 137, in forward
        x, f_loss, g_loss = block(x, **kwarg)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 80, in forward
        f_out, f_loss = cast_return(self.f(x2, record_rng=self.training, **f_args), requires_grad = False)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/reversible.py", line 53, in forward
        return self.net(*args, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 121, in forward
        return self.fn(x, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 524, in forward
        global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 390, in forward
        dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 550, in __call__
        result = self.forward(*input, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 339, in forward
        self.init(x)
      File "/home/tom/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
        return func(*args, **kwargs)
      File "/home/tom/.local/lib/python3.8/site-packages/routing_transformer/routing_transformer.py", line 325, in init
        self.means.data.copy_(means)
    RuntimeError: The size of tensor a (64) must match the size of tensor b (2) at non-singleton dimension 1
    

    Here are my model params:

    model = RoutingTransformerEncDec(
        enc_num_tokens=7000,
        dec_num_tokens=7000,
        dim=512,
        enc_ff_mult=4,
        dec_ff_mult=4,
        enc_depth=16,
        dec_depth=16,
        enc_heads=8,
        dec_heads=8,
        enc_max_seq_len=8192,
        dec_max_seq_len=8192,
        enc_window_size=128,
        dec_window_size=128,
        enc_causal=False,
        #dec_causal=True,  # decoder is always set to causal,
        enc_ff_dropout=0.05,
        dec_ff_dropout=0.05,
        enc_reversible=True,
        dec_reversible=True,
    )
    
    opened by tomweingarten 16
  • Sequence length limited

    I tried this model, but the sequence length that the Routing Transformer can process seemed limited. I set the batch size to 16 and the sequence length to 1024, but it ran out of GPU memory.

    opened by guohanyang1994 14
  • Why doesn't AutoregressiveWrapper sum the encoder aux loss?

    Sorry if this is a dumb question, but I couldn't find a good explanation. The auxiliary loss of the decoder is summed with the cross-entropy loss and returned for back-propagation. The auxiliary loss of the encoder is just thrown away. What's the rationale for that? Thanks!

    opened by tomweingarten 8
  • Building and training  a RoutingTransformerEncDec from pre-trained RoutingTransformerLMs

    I am trying to build and train an encoder-decoder from pretrained routing transformer LMs. The way I approached it was to replace the encoder and decoder in a RoutingTransformerEncDec with the pre-trained RoutingTransformerLMs as follows:

    enc_dec.enc=pretrained_lm
    enc_dec.dec=AutoregressiveWrapper(pretrained_lm)
    

    When I then try to train enc_dec as normal, I get the following error:

    RuntimeError                              Traceback (most recent call last)
    <ipython-input-9-681d3315d6dc> in <module>
        147         grad_accum_steps=1,
        148         temperature=1,
    --> 149         model_suffix=''
        150 
        151     )
    
    ~/projects/trlabs_routing_transformer/routing_sum/train_and_eval.py in train_routing_single(epoch, model, tokenizer, train_chunk_bucket, val_data_bucket, model_dir, optimizer, lr, max_seq_len, pred_target_len, src_pad_len, tgt_pad_len, max_src_len, max_tgt_len, log_interval, eval_interval, save_interval, train_logger, global_step, grad_accum_steps, temperature, model_suffix)
        469         train_seq_out = padded_target[:, :max_seq_len].to(device)
        470         loss, aux_loss = model(train_seq_in, train_seq_out, return_loss=True)
    --> 471         loss.backward()
        472         aux_loss.backward()
        473         train_loss += loss.item()
    
    ~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/tensor.py in backward(self, gradient, retain_graph, create_graph)
        196                 products. Defaults to ``False``.
        197         """
    --> 198         torch.autograd.backward(self, gradient, retain_graph, create_graph)
        199 
        200     def register_hook(self, hook):
    
    ~/anaconda3/envs/routing/lib/python3.6/site-packages/torch/autograd/__init__.py in backward(tensors, grad_tensors, retain_graph, create_graph, grad_variables)
         98     Variable._execution_engine.run_backward(
         99         tensors, grad_tensors, retain_graph, create_graph,
    --> 100         allow_unreachable=True)  # allow_unreachable flag
        101 
        102 
    
    RuntimeError: new kmeans has not been supplied
    
    

    I would appreciate any feedback on what may be the problem or what is the best way to build an enc_dec from pretrained LM checkpoints.

    opened by AliOskooeiTR 7
  • top_p returns wrong values and re-orders the data

    from routing_transformer.autoregressive_wrapper import top_p, top_k
    import torch
    import torch.nn.functional as F
    
    test_tensor = torch.tensor([[0.1, 0.2, 0.15, 0.4, 0.3, 0.001, 0.01]])
    threshold=0.3
    
    print(test_tensor)
    print(F.softmax(test_tensor, dim=-1))
    
    print("Top K")
    print(top_k(test_tensor, thres=threshold))
    print(F.softmax(top_k(test_tensor, thres=threshold), dim=-1))
    
    print("Top P")
    print(top_p(test_tensor, thres=threshold))
    print(F.softmax(top_p(test_tensor, thres=threshold), dim=-1))
    

    Output:

    tensor([[0.1000, 0.2000, 0.1500, 0.4000, 0.3000, 0.0010, 0.0100]])
    tensor([[0.1325, 0.1464, 0.1393, 0.1789, 0.1618, 0.1200, 0.1211]])
    Top K
    tensor([[  -inf, 0.2000, 0.1500, 0.4000, 0.3000,   -inf,   -inf]])
    tensor([[0.0000, 0.2338, 0.2224, 0.2855, 0.2584, 0.0000, 0.0000]])
    Top P
    tensor([[  -inf,   -inf, 0.3000,   -inf, 0.4000,   -inf,   -inf]])
    tensor([[0.0000, 0.0000, 0.4750, 0.0000, 0.5250, 0.0000, 0.0000]])
    

    Thanks for writing this library! I think there is a bug in top_p, with two symptoms:

    1. The wrong results are filtered. It defines the threshold in the opposite way from top_k, so setting thres=0.9 results in everything being returned until a cumulative probability of 0.1 is reached.
    2. The results themselves are gathered twice (instead of gathered and scattered) and as a result the returned tensor's values are distributed in a nearly random order.

    I'll send over a PR in a few, hope it's helpful!
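
For reference, standard nucleus (top-p) filtering keeps the smallest set of highest-probability logits whose cumulative probability reaches `p`. It then writes the survivors back to their original positions, which is the scatter step the report says is missing. The pure-Python sketch below is not the library's `top_p`, and `top_p_filter` is a hypothetical name. Here `p` is the probability mass to keep.

```python
import math


def top_p_filter(logits, p=0.9):
    """Nucleus filtering: keep the top logits covering probability mass p, in original order."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    order = sorted(range(len(logits)), key=lambda i: logits[i], reverse=True)
    keep, cumulative = set(), 0.0
    for i in order:
        keep.add(i)  # the highest logit is always kept
        cumulative += probs[i]
        if cumulative >= p:
            break
    # "Scatter" back: survivors stay at their original index, the rest become -inf.
    return [x if i in keep else float("-inf") for i, x in enumerate(logits)]


logits = [0.1, 0.2, 0.15, 0.4, 0.3, 0.001, 0.01]
print(top_p_filter(logits, p=0.5))  # [-inf, 0.2, 0.15, 0.4, 0.3, -inf, -inf]
```

With `p = 0.5`, the four highest logits survive in their original positions, which happens to match the top_k output in the report.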

    opened by tomweingarten 7
  • One-hot encoded input?

    I'm looking through the code, and I'm not seeing the token IDs being converted to one-hot encoded vectors. Is the input to the language model with autoregressive wrapper the token IDs?

    opened by matthew-jurewicz 4
  • Missing key(s) in state_dict

    Greetings, I was previously able to save and load checkpoints, but today I get: RuntimeError: Error(s) in loading state_dict for AutoregressiveWrapper: Missing key(s) in state_dict: "net.net.routing_transformer.layers.blocks.0.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.1.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.2.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.3.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.4.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.5.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.6.f.net.fn.local_attn.rel_pos.weights", "net.net.routing_transformer.layers.blocks.7.f.net.fn.local_attn.rel_pos.weights".

    Help please, Thanks

    opened by epetros 4
  • AutoregressiveWrapper expects different input lengths based on type

    When the AutoregressiveWrapper receives a tensor input, it shrinks the size of the input by one. When it receives non-tensor input it applies padding. This is a bit confusing, since it means you need to provide different size inputs depending on type. Normally this wouldn't matter, but with axial position encoding it expects an exact input length, so it can fail for an input length difference of 1.

            if isinstance(x, torch.Tensor):
                xi, xo = x[:, :-1], x[:, 1:]
                annotations = annotations[:, :-1]
            else:
                xi = pad(list(map(lambda t: t[:-1], x)))
                xo = pad(list(map(lambda t: t[1:], x)))
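
The effect of the tensor branch is easy to see on a toy batch. The input and target are each one token shorter than what was passed in, so to feed the network exactly `max_seq_len` tokens, the caller has to supply `max_seq_len + 1`. The helper `shift_for_lm` is hypothetical and mirrors only the tensor branch quoted above, using plain lists.

```python
def shift_for_lm(batch):
    """Split each sequence into (input, target) offset by one position."""
    inputs = [seq[:-1] for seq in batch]
    targets = [seq[1:] for seq in batch]
    return inputs, targets


max_seq_len = 4
batch = [[10, 11, 12, 13, 14]]  # max_seq_len + 1 tokens
xi, xo = shift_for_lm(batch)
print(xi, xo)  # [[10, 11, 12, 13]] [[11, 12, 13, 14]]
assert len(xi[0]) == max_seq_len
```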
    
    opened by tomweingarten 3
  • MoE doesn't work with reversible layers

    When reversible layers are on, the call to MoE fails because it gets the unexpected keyword "_reversible". Happy to provide a stack trace if that's helpful.

    opened by tomweingarten 2
  • ONNX export hangs

    I’m trying to export the model to ONNX, and the export simply hangs. Before diving into debugging, I wanted to check whether anyone has had success with an ONNX export.

    I have torch version: 1.13.0.dev20220616

    Export command:

    import torch.utils.model_zoo as model_zoo
    import torch.onnx
    import netron
    device =  torch.device('cpu')
    torch.onnx.enable_log() 
    torch.set_num_threads(1)
    #The tuple should contain model inputs such that model(*args) is a valid invocation of the model
    tinput=(torch.tensor(firstbatch['input_ids']).to(device)
           )
    model.cpu()
    model.eval()
    torch.onnx.export(model,                     # model being run  debug: export_to_pretty_string
                      tinput,                    # model input (or a tuple for multiple inputs)
                      dataset+"_routing.onnx",   # where to save the model (can be a file or file-like object)
                      export_params=True,        # store the trained parameter weights inside the model file
                      opset_version=16,          # the ONNX version to export the model to
                      do_constant_folding=True,  # whether to execute constant folding for optimization
                      input_names  = ['input_ids'],   # the model's input names
                      output_names = ['output'],  # the model's output names
                      verbose=True
                      #,operator_export_type=torch.onnx.OperatorExportTypes.ONNX_ATEN_FALLBACK
                     )
    
    opened by genolve 1
  • receives_context causes tensor mismatch error

    Hi all,

    I am sorry if this is a dumb question but I am running into an issue that I can't seem to solve. My model is a generative encoder/decoder model where the encoder is a vision transformer and the decoder is a routing transformer (this repo 😄 ). The output is continuous valued so I cannot use the autoregressive wrapper.

    For the longest time, I used this without passing in "receives_context", which was obviously silly and kind of circumvented the whole idea of having the ViT head. When I use the flag, though, I get the error below.

    Traceback (most recent call last):
      File "main.py", line 327, in <module>
        main(args)
      File "main.py", line 196, in main
        layer_loss, color_loss, position_loss, aux_loss = model(feature, pad_label, mask=pad_mask, use_activations=use_activations)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
        return forward_call(*input, **kwargs)
      File "C:\Users\jnew2\source\repos\NSA\model\model.py", line 205, in forward
        x, aux_loss = self.decoder(y, context=context, input_mask=feature_mask)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
        return forward_call(*input, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 645, in forward
        x, loss = self.layers(x, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
        return forward_call(*input, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\reversible.py", line 171, in forward
        res, loss = cast_return(f(x, **f_args))
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
        return forward_call(*input, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 134, in forward
        return self.fn(x, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
        return forward_call(*input, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 558, in forward
        global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
        return forward_call(*input, **kwargs)
      File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 374, in forward
        dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
    RuntimeError: Tensors must have same number of dimensions: got 3 and 4
    

    I ran through the code and I can definitely see where it is breaking but I really have no idea where to even start with alleviating that. For what it is worth the dims of everything are consistent:

    x = torch.Size([8, 230, 64]) context = torch.Size([8, 64]) input_mask = torch.Size([8, 230])

    and parameters I am initializing with for the routing transformer are:

    RoutingTransformer(dim = 64, depth = 2, max_seq_len = 256, heads = 16, ff_glu = True, use_scale_norm = True, causal = True, receives_context=True)

    Once again, this is probably a large issue with me misunderstanding the code but it works with other transformer architectures and I am not sure where to go.

    Cheers!

    opened by WeForgot 1
  • How to reconstruct the full attention matrix?

    Hello,

    The implementation of the Reformer model allows for the reconstruction of the full attention matrix (https://github.com/lucidrains/reformer-pytorch#research). There, the Recorder class can expand the attention matrix to its original form. How can one get this full attention matrix for the Routing Transformer? The Recorder class is only compatible with the Reformer. The full attention matrix is needed for Transformer interpretability/explanation methods, such as the one described here: https://github.com/hila-chefer/Transformer-Explainability

    I believe it would involve the lines here: https://github.com/lucidrains/routing-transformer/blob/3f6c461a036e98dbae7e70c623d1c0e0616ef82a/routing_transformer/routing_transformer.py#L407-L417
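
Conceptually, routed attention produces one small attention matrix per cluster, indexed by the query and key positions routed to that cluster. A full `n x n` matrix can therefore be rebuilt by scattering each block back into a zero matrix. The sketch below assumes that each query position belongs to exactly one cluster. It ignores local-attention heads, memory key/values and causal masking. `query_positions`, `key_positions` and `cluster_attn` are hypothetical inputs, not tensors the library exposes under those names.

```python
def expand_cluster_attention(n, query_positions, key_positions, cluster_attn):
    """Scatter per-cluster attention blocks into a dense n x n matrix (zeros elsewhere).

    For cluster c, cluster_attn[c][i][j] is the weight from the query at position
    query_positions[c][i] to the key at position key_positions[c][j].
    """
    full = [[0.0] * n for _ in range(n)]
    for q_list, k_list, attn in zip(query_positions, key_positions, cluster_attn):
        for i, q_pos in enumerate(q_list):
            for j, k_pos in enumerate(k_list):
                full[q_pos][k_pos] = attn[i][j]
    return full


# Two clusters over a length-4 sequence; queries and keys share positions here.
positions = [[0, 2], [1, 3]]
attn = [[[0.7, 0.3], [0.4, 0.6]], [[1.0, 0.0], [0.5, 0.5]]]
for row in expand_cluster_attention(4, positions, positions, attn):
    print(row)
# [0.7, 0.0, 0.3, 0.0]
# [0.0, 1.0, 0.0, 0.0]
# [0.4, 0.0, 0.6, 0.0]
# [0.0, 0.5, 0.0, 0.5]
```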

    opened by FarzanT 0
  • Enquiry about BPC Calculation

    Hi, I have two questions:

    1. How do I calculate the BPC score for the RoutingTransformerLM?
    2. What is the difference between the simple enwik-8 and the deepspeed enwik-8 in the folder 'examples'?
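
On question 1: bits per character is the average cross-entropy per character, expressed in base 2. PyTorch's `F.cross_entropy` returns the mean negative log-likelihood in nats (natural log) by default. For a byte-level model such as the enwik8 examples, where each token is one character, the conversion is therefore a division by ln 2. The sketch below assumes the loss is already averaged over character tokens. The helper names are hypothetical.

```python
import math


def nats_to_bpc(mean_nats_per_char):
    """Convert a mean cross-entropy in nats per character to bits per character."""
    return mean_nats_per_char / math.log(2)


def corpus_bpc(batch_mean_losses, batch_char_counts):
    """Weight each batch's mean loss by its character count, then convert to BPC."""
    total_nats = sum(loss * n for loss, n in zip(batch_mean_losses, batch_char_counts))
    return nats_to_bpc(total_nats / sum(batch_char_counts))


print(nats_to_bpc(math.log(2)))     # 1.0
print(round(nats_to_bpc(0.75), 3))  # 1.082
```

Weighting by character count matters when the last evaluation batch is shorter than the others.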

    opened by ShiweiLiuFdu 0
  • input_mask behavior

    I have a question about how the input_mask works in RoutingTransformerLM. I have been using a random mask (with causal = False), as used in MLM, and playing with the masking ratio, but the ratio does not appear to really affect how the model learns. I even went to the extreme and masked 90% of the inputs, and yet the model continued to learn rapidly. I am training the LM with the HuggingFace Trainer. I am copying my compute_loss method below for reference. I have tested the mask itself and the input data, and they're fine.

    def compute_loss(self, model, inputs):
    
          model_dim = self.args.model_dim
          model_seq_len = self.args.model_seq_len
    
          source = inputs["input_ids"].to(self.args.device)
          input_mask = torch.ones_like(source).bool().to(self.args.device)
          masked_tokens = random.sample(
              range(source.shape[1]),
              int(self.args.mask_ratio*source.shape[1])
          )
          input_mask[0, masked_tokens] = torch.tensor(False).to(self.args.device)
              
          output, aux_loss = model(
              source,
              input_mask=input_mask,
              return_loss=True
          )
          loss = F.cross_entropy(
              output.transpose(1, 2),
              source
          ) + aux_loss
    
          return loss.squeeze()
    
    opened by AliOskooeiTR 0
  • Music Routing Transformer Colab

    Hey guys,

    I had a great experience with this implementation of the Routing Transformer thanks to the efforts of @lucidrains so I wanted to share with you my creation which is based on this code/repo/implementation.

    Here is the link:

    https://github.com/asigalov61/Music-Transformers-Library/tree/main/Routing-Transformer

    I really liked that this RT trains well and quickly. And I really enjoyed the fast generation speeds and the quality of the output.

    Thank you and I hope you may find my Colab interesting and useful.

    Thanks.

    Alex

    P.S. @lucidrains, GitHub has a new feature, Discussion Boards for repos, so I would suggest enabling it everywhere so that things like creations can be shared separately from the Issues. Just my humble suggestion. Thanks.

    opened by asigalov61 1
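For the BPC question in the first issue above: BPC is the average negative log-likelihood per character, measured in bits. A cross-entropy computed with natural logarithms, such as the default mean reduction of torch.nn.functional.cross_entropy, converts to BPC by dividing by ln 2.

The minimal sketch below rests on three assumptions:
- There is one token per character or byte, as in the byte-level enwik8 setups.
- Only the language-model cross-entropy is used, not the auxiliary k-means loss (`aux_loss`) that the model also returns.
- The helper names are hypothetical and are not part of routing_transformer.

```python
import math


def bits_per_character(mean_loss_nats: float) -> float:
    """Convert a mean cross-entropy in nats per token to bits per token.

    Assumes one token per character (byte-level data), so bits per token == BPC.
    """
    return mean_loss_nats / math.log(2)


def bpc_from_probs(target_probs):
    """BPC computed directly as the average of -log2 p(target) over characters."""
    if not target_probs:
        raise ValueError("need at least one probability")
    return -sum(math.log2(p) for p in target_probs) / len(target_probs)


def bpc_over_batches(batch_mean_losses, batch_token_counts):
    """Token-weighted BPC over several evaluation batches.

    Each batch contributes its mean loss (nats) times its token count, so a
    shorter final batch does not get the same weight as a full one.
    """
    total_tokens = sum(batch_token_counts)
    if total_tokens == 0:
        raise ValueError("no tokens to evaluate")
    total_nats = sum(loss * n for loss, n in zip(batch_mean_losses, batch_token_counts))
    return total_nats / total_tokens / math.log(2)
```

For example, a mean validation cross-entropy of 0.7 nats per byte corresponds to about 1.01 BPC (0.7 / ln 2).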
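The input_mask issue above reports that the mask ratio barely affects learning. One possible explanation, offered as an assumption rather than a confirmed diagnosis, comes from the posted `compute_loss`. There, the targets are the same `source` ids the model receives as input. Each position presumably still carries its own token embedding through the residual path, so the model can learn to copy its input whatever `input_mask` hides. The posted code also masks only the first sequence of each batch (`input_mask[0, ...]`).

A conventional MLM setup instead replaces the hidden tokens in the input and computes the loss only at those positions. The sketch below is pure Python, with these assumptions:
- `mlm_corrupt` and `mask_token_id` are hypothetical names.
- The mask id is a spare id reserved in the vocabulary.
- `-100` is used because it is the default `ignore_index` of `torch.nn.functional.cross_entropy`.

```python
import random

# default ignore_index of torch.nn.functional.cross_entropy
IGNORE_INDEX = -100


def mlm_corrupt(tokens, mask_ratio, mask_token_id, rng=None):
    """Hide a random subset of positions and keep labels only for them.

    Returns (corrupted, labels):
    - corrupted: tokens with the chosen positions replaced by mask_token_id
      (a hypothetical reserved id), so the model cannot see them;
    - labels: the original token at masked positions, IGNORE_INDEX elsewhere,
      so the loss is computed only on tokens hidden from the model.
    """
    rng = rng or random.Random()
    n_mask = int(mask_ratio * len(tokens))
    masked = set(rng.sample(range(len(tokens)), n_mask))
    corrupted = [mask_token_id if i in masked else t for i, t in enumerate(tokens)]
    labels = [t if i in masked else IGNORE_INDEX for i, t in enumerate(tokens)]
    return corrupted, labels
```

Calling this once per row gives every sequence in the batch its own mask. The corrupted ids would then be the model input, and the labels would be the target of `F.cross_entropy(output.transpose(1, 2), labels)`, which skips the `-100` positions by default.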
Releases: 1.6.1
Owner: Phil Wang (Working with Attention. It's all we need)