Understanding the Difficulty of Training Transformers

Overview


Guided by our analyses, we propose Adaptive Model Initialization (Admin), which successfully stabilizes previously-diverged Transformer training and achieves better performance, without introducing additional hyper-parameters. Admin is adapted for better half-precision stability and can be reparameterized into the original Transformer.

We are in an early-release beta. Expect some adventures and rough edges.

Table of Contents

  • Introduction
  • Quick Start Guide
  • Citation

Introduction

What complicates Transformer training?

In our study, we go beyond gradient vanishing and identify an amplification effect that substantially influences Transformer training. Specifically, for each layer in a multi-layer Transformer, heavy dependency on its residual branch makes training unstable, yet light dependency leads to sub-optimal performance.

Dependency and Amplification Effect

Our analysis starts from the observation that Pre-LN is more robust than Post-LN, whereas Post-LN typically achieves better performance. As shown in Figure 1, these two variants exhibit different layer-dependency patterns.

With further exploration, we find that for an N-layer residual network, after updating its parameters from W to W*, the resulting change in its output is proportional to its dependency on the residual branches.

Intuitively, since a larger output change indicates a rougher loss surface, heavy dependency on the residual branches complicates training. Motivated by this, we propose Admin (adaptive model initialization), which starts training from an area with a smoother loss surface. More details can be found in our paper.
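To make the intuition concrete, the sketch below shows one way an Admin-style residual connection could be written: each sub-layer output is added to a rescaled shortcut, y = LayerNorm(omega * x + f(x)), and omega is set from statistics collected in a profiling pass so that no layer depends too heavily on its residual branch at initialization. This is a minimal illustration, not the repository's implementation; the class AdminResidual, the helper admin_init, and the recorded statistic branch_var are assumed names.

# Minimal, illustrative sketch of an Admin-style residual connection; the class,
# helper, and recorded statistic are assumed names, not the repository's code.
import torch
import torch.nn as nn

class AdminResidual(nn.Module):
    """Post-LN residual connection y = LayerNorm(omega * x + f(x)); the shortcut
    scaling omega controls how strongly the output depends on the branch f(x)."""

    def __init__(self, sublayer: nn.Module, d_model: int):
        super().__init__()
        self.sublayer = sublayer
        self.norm = nn.LayerNorm(d_model)
        self.omega = nn.Parameter(torch.ones(d_model))   # vanilla Post-LN at first
        self.register_buffer("branch_var", torch.ones(1))

    def forward(self, x: torch.Tensor, profiling: bool = False) -> torch.Tensor:
        branch = self.sublayer(x)
        if profiling:
            # Profiling pass: record how large this residual branch's output is
            # under the default initialization.
            self.branch_var = branch.detach().var().reshape(1)
        return self.norm(self.omega * x + branch)

def admin_init(layers) -> None:
    """Adaptive initialization: rescale each shortcut by (roughly) the accumulated
    output variance of the preceding layers, so that no single layer's output is
    dominated by its residual branch at the start of training."""
    accumulated = 1.0
    for layer in layers:
        layer.omega.data.fill_(accumulated ** 0.5)
        accumulated += float(layer.branch_var)

In this sketch, one forward pass with profiling=True fills branch_var for every layer, admin_init then fixes omega before regular training starts, and since omega only rescales the shortcut, the trained model can be reparameterized back into the original Transformer architecture, as noted above.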

Quick Start Guide

Our implementation is based on the fairseq package (Python 3.6 and PyTorch 1.5/1.6 are recommended). It can be installed with:

git clone https://github.com/LiyuanLucasLiu/Transformer-Clinic.git
cd Transformer-Clinic/fairseq
pip install --editable .

Guidance for reproducing our results is available in the repository.

Specifically, our implementation requires first setting --init-type adaptive-profiling and using one GPU for this profiling stage, then setting --init-type adaptive and starting training.
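To see why the procedure is split into two stages, here is a hedged sketch of the handoff, reusing the AdminResidual layers sketched above: the single-GPU profiling run dumps per-layer statistics to a file (the profiling stage produces a profile_ratio.init, referenced in the issues below), and the subsequent multi-GPU run reads them back to initialize the shortcut scalings before training. The helper names and the file format here are illustrative assumptions, not the repository's exact logic.

# Illustrative handoff between the two stages; write_profile/load_profile and the
# JSON format are assumptions, not the repository's exact logic.
import json

def write_profile(layers, path="profile_ratio.init"):
    # Stage 1 (--init-type adaptive-profiling, single GPU): after a profiling
    # forward pass, save each layer's recorded branch variance.
    ratios = [float(layer.branch_var) for layer in layers]
    with open(path, "w") as f:
        json.dump(ratios, f)

def load_profile(layers, path="profile_ratio.init"):
    # Stage 2 (--init-type adaptive, full GPU setup): read the statistics back and
    # set each shortcut scaling from the accumulated variance before training.
    with open(path) as f:
        ratios = json.load(f)
    accumulated = 1.0
    for layer, var in zip(layers, ratios):
        layer.omega.data.fill_(accumulated ** 0.5)
        accumulated += var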

Citation

Please cite the following papers if you found our model useful. Thanks!

Liyuan Liu, Xiaodong Liu, Jianfeng Gao, Weizhu Chen, and Jiawei Han (2020). Understanding the Difficulty of Training Transformers. Proc. 2020 Conf. on Empirical Methods in Natural Language Processing (EMNLP'20).

@inproceedings{liu2020admin,
  title={Understanding the Difficulty of Training Transformers},
  author = {Liu, Liyuan and Liu, Xiaodong and Gao, Jianfeng and Chen, Weizhu and Han, Jiawei},
  booktitle = {Proceedings of the 2020 Conference on Empirical Methods in Natural Language Processing (EMNLP 2020)},
  year={2020}
}

Xiaodong Liu, Kevin Duh, Liyuan Liu, and Jianfeng Gao (2020). Very Deep Transformers for Neural Machine Translation. arXiv preprint arXiv:2008.07772.

@article{liu_deep_2020,
  title = {Very Deep Transformers for Neural Machine Translation},
  author = {Liu, Xiaodong and Duh, Kevin and Liu, Liyuan and Gao, Jianfeng},
  journal = {arXiv preprint arXiv:2008.07772},
  year = {2020}
}
Comments
  • Post-LN with 12-12 trains OK, but 12-3 diverges

    Hi, as expected, models with more Transformer layers are more prone to divergence during training. However, we find that the model with 12 encoder layers and 12 decoder layers trains fine, while the model with 12 encoder layers and 3 decoder layers diverges. Have you observed this in your experiments? Thank you.

    opened by ZhenYangIACAS 9
  • wmt_en_de admin: Function 'SoftmaxBackward' returned nan values in its 0th output.

    I was wondering if you ever encountered NaN gradients during Admin training. I'm on torch 1.6 / CUDA 10.1 with no modifications to the code:

    Command

    export dd=data-bin/wmt14_en_de_joined_dict
    GPUS=0,1,2,3
    GPUID=1
    TOKEN_NUMBER=8192
    UPDATE_FREQUENCE=1
    for lnum in 18
    do
      CUDA_VISIBLE_DEVICES=$GPUID fairseq-train \
        $dd -s en -t de \
        --arch transformer_wmt_en_de --share-all-embeddings \
        --optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
        --lr-scheduler inverse_sqrt --max-update 500000 \
        --warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
        --weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
        --max-tokens $TOKEN_NUMBER --update-freq $UPDATE_FREQUENCE \
        --save-dir wmt14ende/wmt-admin-${lnum}l --restore-file x.pt --seed 1111 \
        --user-dir ../radam_fairseq --log-format simple --log-interval 500 \
        --init-type adaptive-profiling --fp16 --fp16-scale-window 256 \
        --encoder-layers $lnum --decoder-layers $lnum \
        --threshold-loss-scale 0.03125 
    
      CUDA_VISIBLE_DEVICES=$GPUS fairseq-train \
        $dd -s en -t de \
        --arch transformer_wmt_en_de --share-all-embeddings \
        --optimizer radam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
        --lr-scheduler inverse_sqrt --max-update 500000 \
        --warmup-init-lr 1e-07 --warmup-updates 8000 --lr 0.001 --min-lr 1e-09  \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
        --weight-decay 0.0 --attention-dropout 0.1 --relu-dropout 0.1 \
        --max-tokens $TOKEN_NUMBER --update-freq $UPDATE_FREQUENCE \
        --save-dir wmt14ende/wmt-admin-${lnum}l --restore-file x.pt --seed 1111 \
        --user-dir ../radam_fairseq --log-format simple --log-interval 500 \
        --init-type adaptive --fp16 --fp16-scale-window 256 \
        --encoder-layers $lnum --decoder-layers $lnum \
        --threshold-loss-scale 0.03125 | tee ./wmt14ende/log/loss_admin-${lnum}l.log
    
      bash eval_wmt_en-de.sh wmt14ende/wmt-admin-${lnum}l $GPUID 
    done
    
    

    The profiling command works fine, but the second command raises:

    Traceback

    | WARNING: overflow detected, setting loss scale to: 32.0
    | epoch 002 | loss 4.937 | nll_loss 3.371 | ppl 10.34 | wps 24011 | ups 1 | wpb 28913.466 | bsz 942.984 | num_updates 9352 | lr 0.000924896 | gnorm 0.368 | clip 0.000 | oom 0.000 | loss_scale 32.000 | wall 228 | train_wall 226
    Traceback (most recent call last):
      File "/private/home/sshleifer/.conda/envs/clinic/lib/python3.8/site-packages/torch/multiprocessing/spawn.py", line 20, in _wrap
        fn(i, *args)
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq_cli/train.py", line 307, in distributed_main
        main(args, init_distributed=True)
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq_cli/train.py", line 90, in main
        train(args, trainer, task, epoch_itr)
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq_cli/train.py", line 139, in train
        log_output = trainer.train_step(samples)
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq/trainer.py", line 349, in train_step
        raise e
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq/trainer.py", line 311, in train_step
        loss, sample_size, logging_output = self.task.train_step(
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq/tasks/fairseq_task.py", line 264, in train_step
        optimizer.backward(loss)
      File "/private/home/sshleifer/Transformer-Clinic/fairseq/fairseq/optim/fp16_optimizer.py", line 103, in backward
        loss.backward()
      File "/private/home/sshleifer/.conda/envs/clinic/lib/python3.8/site-packages/torch/tensor.py", line 185, in backward
        torch.autograd.backward(self, gradient, retain_graph, create_graph)
      File "/private/home/sshleifer/.conda/envs/clinic/lib/python3.8/site-packages/torch/autograd/__init__.py", line 125, in backward
        Variable._execution_engine.run_backward(
    RuntimeError: Function 'SoftmaxBackward' returned nan values in its 0th output.
    

    contents of profile_ratio.init: https://gist.github.com/sshleifer/b615558499b9b10bd5bee8ddf2db030a

    Data directory: [screenshot omitted]

    opened by sshleifer 8
  • `RuntimeError: expected scalar type Float but found Half` during the eval step

    I was running the given script for Admin on the en-de dataset. It throws an error at the last step, which evaluates the model using the averaged checkpoint.

    Traceback (most recent call last):
      File "/home/.../bin/fairseq-generate", line 11, in <module>
        load_entry_point('fairseq', 'console_scripts', 'fairseq-generate')()
      File "/home/.../fairseq/fairseq_cli/generate.py", line 197, in cli_main
        main(args)
      File "/home/.../fairseq/fairseq_cli/generate.py", line 111, in main
        hypos = task.inference_step(generator, models, sample, prefix_tokens)
      File "/home/.../fairseq/fairseq/tasks/fairseq_task.py", line 277, in inference_step
        return generator.generate(models, sample, prefix_tokens=prefix_tokens)
      File "/home/.../lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
        return func(*args, **kwargs)
      File "/home/.../fairseq/fairseq/sequence_generator.py", line 113, in generate
        return self._generate(model, sample, **kwargs)
      File "/home/.../lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
        return func(*args, **kwargs)
      File "/home/.../fairseq/fairseq/sequence_generator.py", line 152, in _generate
        encoder_outs = model.forward_encoder(encoder_input)
      File "/home/.../lib/python3.7/site-packages/torch/autograd/grad_mode.py", line 15, in decorate_context
        return func(*args, **kwargs)
      File "/home/.../fairseq/fairseq/sequence_generator.py", line 540, in forward_encoder
        return [model.encoder(**encoder_input) for model in self.models]
      File "/home/.../fairseq/fairseq/sequence_generator.py", line 540, in <listcomp>
        return [model.encoder(**encoder_input) for model in self.models]
      File "/home/.../lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/.../fairseq/fairseq/models/transformer.py", line 369, in forward
        x = layer(x, encoder_padding_mask)
      File "/home/.../lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/.../fairseq/fairseq/modules/transformer_layer.py", line 163, in forward
        x, _ = self.self_attn(query=x, key=x, value=x, key_padding_mask=encoder_padding_mask)
      File "/home/.../lib/python3.7/site-packages/torch/nn/modules/module.py", line 722, in _call_impl
        result = self.forward(*input, **kwargs)
      File "/home/.../fairseq/fairseq/modules/multihead_attention.py", line 141, in forward
        q, k, v = self.in_proj_qkv(query)
      File "/home/.../fairseq/fairseq/modules/multihead_attention.py", line 269, in in_proj_qkv
        return self._in_proj(query).chunk(3, dim=-1)
      File "/home/.../fairseq/fairseq/modules/multihead_attention.py", line 306, in _in_proj
        return F.linear(input, weight, bias)
      File "/home/.../lib/python3.7/site-packages/torch/nn/functional.py", line 1676, in linear
        output = input.matmul(weight.t())
    RuntimeError: expected scalar type Float but found Half

    (Part of the path info is replaced with ... for privacy reasons; it is not useful for debugging anyway.)

    opened by ruiningh 5
  • Is "tmp_weight" in transformer_layer.py useless?

    Great work! I have two questions:

    1. Is "tmp_weight" in transformer_layer.py useless? Can I delete it?
    2. In the paper, you say w_i is fixed during training, while in the code it appears to be trainable. Am I right?

    Thanks.

    opened by zherowolf 3
  • Admin for 100L-100L model?

    The paper mentions that 8 A100 GPUs were used to train the model. How long did training take, and how many epochs were run? What is the final model's BLEU score?

    opened by Vincent131499 1
  • How to add RAdam to fairseq?

    Following the process described in your README, I encountered an error: fairseq-train: error: argument --optimizer: invalid choice: 'radam'.

    I wonder how I can add RAdam to fairseq.

    Thanks.

    opened by KelleyYin 1
  • argdict

    I get an error while pre-processing the data with the wmt14en-de.sh command. Upon investigation, the error is caused by the srcdict argument not being passed to the python command.

    Traceback (most recent call last):
      File "preprocess.py", line 359, in <module>
        cli_main()
      File "preprocess.py", line 355, in cli_main
        main(args)
      File "preprocess.py", line 64, in main
        raise FileExistsError(dict_path(args.source_lang))
    FileExistsError: ../data-bin/wmt14_en_de_joined_dict/dict.en.txt

    if not args.srcdict and os.path.exists(dict_path(args.source_lang)):
        raise FileExistsError(dict_path(args.source_lang))
    

    Where is srcdict? Is it necessary? Should I change the code for it to work?

    EDIT: I changed the command from:

    python preprocess.py --source-lang en --target-lang de \
      --trainpref $prep/train --validpref $prep/valid --testpref $prep/test \
      --destdir ../data-bin/wmt14_en_de_joined_dict \
      --joined-dictionary

    to

    fairseq-preprocess --source-lang en --target-lang de \
      --trainpref $prep/train --validpref $prep/valid --testpref $prep/test \
      --destdir ../data-bin/wmt14_en_de_joined_dict \
      --srcdict ../data-bin/wmt14_en_de_joined_dict/dict.en.txt --tgtdict ../data-bin/wmt14_en_de_joined_dict/dict.de.txt

    was that right?

    opened by riosempre 1
  • Question about the adaptive optimizer

    Thanks for this great work!

    I could not find more details about the adaptive optimizer mentioned in the paper. Could you point me to a reference or GitHub link for this adaptive optimizer?

    Thank you!

    opened by chenwydj 1
  • Difference of implementation from the original paper

    Hello, I really liked your paper and am trying to reproduce the results. Meanwhile, I am curious about the implementation of Admin.

    1. In your implementation (https://github.com/LiyuanLucasLiu/Transformer-Clinic/blob/60abd666fd18d25108636fc82ea3ac7f518df773/fairseq/fairseq/modules/transformer_layer.py#L170-L178), the variance appears to be a scalar, whereas the paper describes it as a D-dimensional vector, so there seems to be a mismatch. Is it okay to use a scalar?

    2. Furthermore, https://github.com/LiyuanLucasLiu/Transformer-Clinic/blob/60abd666fd18d25108636fc82ea3ac7f518df773/fairseq/fairseq/modules/transformer_layer.py#L176-L177 shows that the code uses the variance of both input and output, which is quite different from the original paper. The calculation of w_i at the initialization stage also seems to differ.

    opened by wade3han 1
  • Scripts for Post-LN in Figure 10?

    Hi Liyuan,

    Thank you for sharing the code! The current version already includes lots of details.

    Could you share the script for training the Post-LN Transformer without learning rate warmup (or Admin) in Figure 10? I tried lr=3e-4, beta2=0.999, warmup-updates=1 using RAdam and the inverse square root schedule, but the model does not seem to converge.

    opened by zhuchen03 1
  • IWSLT'14 Results

    Hi, very nice paper, I like it! I have a question regarding your results on IWSLT'14. Are these values tokenized BLEU on dev or test set?

    Dataset: IWSLT'14 De-En, Enc#–Dec#: 6L–6L (small)
    Post-LN: 35.64 ± 0.23 | Pre-LN: 35.50 ± 0.04 | Admin: 35.67 ± 0.15

    opened by villmow 1
  • How to get the beta_{i,j} for each residual branch?

    Figure 7 in your paper plots the weights of each residual branch for Post-LN and Pre-LN. I am wondering how exactly to obtain these weights. I notice there is a parameter called plot_variance, which looks like it outputs sqrt(Var[a_j]). How can I then get beta_{i,j}?

    Thanks in advance!

    opened by SefaZeng 0
Owner
Liyuan Liu
Ph.D. Student @ DMG, UIUC