DeLighT: Very Deep and Light-Weight Transformers

This repository contains the source code of our work on building efficient sequence models: DeFINE (ICLR'20) and DeLighT (preprint).

Table of contents

  1. Overview
  2. Requirements and installation
  3. Training, evaluation, and results
  4. Multiplication-addition operations
  5. Citation
  6. Acknowledgements
  7. Issues

Overview

In this repository, we share the source code of our paper DeLighT, which delivers similar or better performance than transformer-based models with significantly fewer parameters. DeLighT allocates parameters more efficiently both (1) within each Transformer block, using DExTra, a deep and light-weight transformation, and (2) across blocks, using block-wise scaling, which allows for shallower and narrower DeLighT blocks near the input and wider and deeper DeLighT blocks near the output. Overall, DeLighT networks are 2.5 to 4 times deeper than standard transformer models, yet have fewer parameters and operations. For details, see our papers: DeFINE and DeLighT.
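
To make block-wise scaling concrete, here is a minimal sketch (our illustration, not the repository's code; the linear-interpolation rule and the names are assumptions based on the paper's description) of how per-block depths can grow from the input toward the output:

from typing import List

def blockwise_depths(num_blocks: int, min_depth: int, max_depth: int) -> List[int]:
    """Assign each block a depth growing linearly from min_depth (near the
    input) to max_depth (near the output). Illustrative only."""
    if num_blocks == 1:
        return [max_depth]
    step = (max_depth - min_depth) / (num_blocks - 1)
    return [round(min_depth + b * step) for b in range(num_blocks)]

print(blockwise_depths(num_blocks=6, min_depth=4, max_depth=8))
# [4, 5, 6, 6, 7, 8]

With min_depth=4, max_depth=8, and 6 blocks, this yields depths like [4, 5, 6, 6, 7, 8]: shallower blocks near the input and deeper blocks near the output.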

DeLighT unit

Requirements and Installation

  • PyTorch version >= 1.4.0
  • Python version >= 3.6
  • For training new models, you'll also need an NVIDIA GPU and NCCL
  • To use DeLighT, you need to install fairseq and develop locally:
git clone https://github.com/sacmehta/delight
cd delight
pip install --editable ./
  • For faster training, install NVIDIA's apex library:
git clone https://github.com/NVIDIA/apex
cd apex
pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" \
  --global-option="--deprecated_fused_adam" --global-option="--xentropy" \
  --global-option="--fast_multihead_attn" ./

Training, Evaluation, and Results

For training, evaluation, and results, see the links below. To ease reproduction of our results, we also provide links to training logs.

Neural Machine Translation

Language Modeling

Multiplication-Addition Operations

We have added module profiling for both Transformer and DeLighT networks, which can be enabled with the --print-stats argument. A model summary is then printed (by default for 20 tokens), similar to the screenshot below. To profile with longer source and target sequence lengths, use the --src-len-ps and --tgt-len-ps flags.

Model statistics
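
For example, an illustrative invocation (angle-bracketed values are placeholders; the flags themselves come from this repository's train.py) could look like:

python train.py data-bin/<dataset> --arch <model_arch> --print-stats --src-len-ps 30 --tgt-len-ps 30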

Citation

If you find our work useful, please consider citing the following works:

@misc{mehta2020delight,
    title={DeLighT: Very Deep and Light-weight Transformer},
    author={Sachin Mehta and Marjan Ghazvininejad and Srinivasan Iyer and Luke Zettlemoyer and Hannaneh Hajishirzi},
    year={2020},
    eprint={2008.00623},
    archivePrefix={arXiv},
    primaryClass={cs.LG}
}
@inproceedings{mehta2019define,
  title={DeFINE: Deep Factorized Input Token Embeddings for Neural Sequence Modeling},
  author={Mehta, Sachin and Koncel-Kedziorski, Rik and Rastegari, Mohammad and Hajishirzi, Hannaneh},
  booktitle={International Conference on Learning Representations},
  year={2019}
}

Acknowledgements

We would like to thank the Fairseq team for building an easy-to-use sequence modeling library.

Issues

Thanks for your interest in our work. For any problems, please open an issue.

Comments
  • Question about WMT EN-RO

    Hi, I have the following three questions about the experiment of EN-RO.

    • In the paper, you mention that the batch size is 64k and the number of training updates is 100K. But WMT16 EN-RO is a small dataset, consisting of only 0.6M training examples. I wonder whether training with such a large batch size for so many updates will overfit.
    • Why is wpb=21k in https://gist.github.com/sacmehta/57c12358434f12bf15939311469c7173#file-delight_wmt16_en2ro_dm_384-txt?
    • Also, can you provide the evaluation script you used to compute the BLEU score for WMT EN-RO?
    opened by luofuli 6
  • Unable to do fp16 training.

    The README mentions installing apex, but no fp16 option is given in the training command. I tried fairseq's default --fp16 flag but am getting the error below.

    RuntimeError: CUDA error: CUBLAS_STATUS_EXECUTION_FAILED when calling `cublasGemmEx( handle, opa, opb, m, n, k, &falpha, a, CUDA_R_16F, lda, b, CUDA_R_16F, ldb, &fbeta, c, CUDA_R_16F, ldc, CUDA_R_32F, CUBLAS_GEMM_DFALT_TENSOR_OP)`

    I want to train with --fp16. Please suggest a fix. Thanks.

    opened by sugeeth14 4
  • no setup.py in local fairseq directory

    Hi,

    Thank you for your interesting work. I was trying to reproduce some of the translation tasks and understand the model better, but it seems there is no setup.py in fairseq, so I cannot install fairseq from this source.

    ERROR: File "setup.py" not found. Directory cannot be installed in editable mode: /home/qingyu.tan/projects/delight/fairseq (base)

    Should I copy a setup.py from fairseq, or do you have a modified version of setup.py?

    opened by tonytan48 4
  • apex installation error

    Hi delight authors,

    I am using a Linux system with cuda=10.1 and pytorch=1.5.0 to reproduce your great work. I encountered an error when I ran the following command:

    pip install -v --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" --global-option="--deprecated_fused_adam" --global-option="--xentropy" --global-option="--fast_multihead_attn" ./

    The error seems to be that an ATen header is not found:

        In file included from apex/contrib/csrc/multihead_attn/additive_masked_softmax_dropout_cuda.cu:14:0:
        apex/contrib/csrc/multihead_attn/softmax.h:2:10: fatal error: ATen/CUDAGeneratorImpl.h: No such file or directory
         #include <ATen/CUDAGeneratorImpl.h>
                  ^~~~~~~~~~~~~~~~~~~~~~~~~~
        compilation terminated.
        error: command '/home/qian/anaconda3/envs/home/pkgs/cuda-toolkit/bin/nvcc' failed with exit status 1
        Running setup.py install for apex ... error
    ERROR: Command errored out with exit status 1: /home/qian/anaconda3/envs/home/bin/python -u -c 'import sys, setuptools, tokenize; sys.argv[0] = '"'"'/tmp/pip-req-build-ioaif4t7/setup.py'"'"'; __file__='"'"'/tmp/pip-req-build-ioaif4t7/setup.py'"'"';f=getattr(tokenize, '"'"'open'"'"', open)(__file__);code=f.read().replace('"'"'\r\n'"'"', '"'"'\n'"'"');f.close();exec(compile(code, __file__, '"'"'exec'"'"'))' --cpp_ext --cuda_ext --deprecated_fused_adam --xentropy --fast_multihead_attn install --record /tmp/pip-record-2v5qna5p/install-record.txt --single-version-externally-managed --compile --install-headers /home/qian/anaconda3/envs/home/include/python3.8/apex Check the logs for full command output.
    Exception information:
    Traceback (most recent call last):
      File "/home/qian/anaconda3/envs/home/lib/python3.8/site-packages/pip/_internal/req/req_install.py", line 812, in install
        success = install_legacy(
      File "/home/qian/anaconda3/envs/home/lib/python3.8/site-packages/pip/_internal/operations/install/legacy.py", line 86, in install
        raise LegacyInstallFailure
    pip._internal.operations.install.legacy.LegacyInstallFailure
    

    Did you run into similar issues? Thanks very much for any response!

    opened by qianlou 1
  • Fixed typo to reflect consistency

    The file nmt_wmt14_en2de.py has a default of data-bin/wmt14_en_de, but the file prepare_nmt_dataset.sh creates data-bin/wmt17_en_de. Changed to maintain consistency.

    opened by sugeeth14 1
  • train.py: error: unrecognized arguments: --t-mult 1

    Dear authors: when I run the script python nmt_wmt16_en2ro.py --d-m 384, the following error is raised: train.py: error: unrecognized arguments: --t-mult 1

    What's more, when I read the code in detail, I can't find the argument '--t-mult'. Below is my error log:

    $ python nmt_wmt16_en2ro.py --d-m 384
    2022-01-10 15:23:03 - LOGS - Training command: python train.py data-bin/wmt14_en_ro --arch delight_transformer_wmt16_en_ro --no-progress-bar --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 --weight-decay 0.0 --criterion label_smoothed_cross_entropy --label-smoothing 0.1 --min-lr 1e-09 --update-freq 1 --keep-last-epochs 10 --ddp-backend=no_c10d --max-tokens 4096 --max-update 100000 --warmup-updates 10000 --lr-scheduler linear --warmup-init-lr 1e-7 --lr 0.0009 --min-lr 1e-9 --t-mult 1 --save-dir ./results_wmt16_en2ro/delight_out_384 --distributed-world-size 8 --distributed-port 50786 --delight-emb-map-dim 128 --delight-emb-out-dim 384 --delight-enc-min-depth 4 --delight-enc-max-depth 8 --delight-enc-width-mult 2 --delight-dec-min-depth 4 --delight-dec-max-depth 8 --delight-dec-width-mult 2 | tee -a ./results_wmt16_en2ro/delight_out_384/logs.txt
    usage: train.py [-h] [--no-progress-bar] ... [full train.py option listing omitted; see the paste link below]
    train.py: error: unrecognized arguments: --t-mult 1

    Thank you.

    The formatting was distorted by GitHub, so I pasted the full log here:

    https://paste.ofcode.org/fkBdqtjQdEFr6QeymGY49F

    Please take a look if needed.

    opened by fkjslee 0
  • Using this architecture to ASR system.

    I have one more question.

    I want to use the DeLighT architecture for an ASR system rather than for translation. Is it possible to do this without using fairseq? Have you ever applied DeLighT to ASR? I need some help with this.

    Thank you.

    opened by miziworld 0
  • Training by using other language data

    Hello, Thank you for your research.

    I want to train this project on ko-en language data. However, I think my data format is a little different from the NMT datasets used here.

    I have a ko-en dataset in an xlsx file with the format (utter, Ko-sentence, En-sentence).

    How can I use this dataset with your project?
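
    One possible first step (a sketch assuming the xlsx columns are named utter, ko, and en, and that pandas with openpyxl is installed; fairseq expects line-aligned plain-text parallel files) is to export the two sentence columns and then run the usual fairseq preprocessing:

    # Hypothetical conversion sketch: export an xlsx with columns
    # (utter, ko, en) to the line-aligned parallel text files fairseq expects.
    # Column names and file paths are assumptions; adjust to your data.
    import pandas as pd

    df = pd.read_excel("ko_en_data.xlsx")  # reading .xlsx requires openpyxl
    with open("train.ko", "w", encoding="utf-8") as ko_f, \
         open("train.en", "w", encoding="utf-8") as en_f:
        for _, row in df.iterrows():
            ko_f.write(str(row["ko"]).strip() + "\n")
            en_f.write(str(row["en"]).strip() + "\n")
    # Afterwards: tokenization/BPE, then
    # fairseq-preprocess --source-lang ko --target-lang en ...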

    opened by miziworld 0
  • Naive question about residual connection

    https://github.com/sacmehta/delight/blob/cc499c53087cd248ee7a0d0b0e70c507e670cba3/fairseq/modules/delight_transformer_layer.py#L120

    Thank you for the interesting work. I have a very naive query: in the code above, when a tensor is assigned to another tensor, they share the same data, so doesn't residual have the same value as x once x is modified in the subsequent code? Later in the code, residual is added back to x to make the skip connection.
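
    For reference, the usual PyTorch behavior here (a minimal standalone sketch, not the repository's code) is that residual = x only aliases the tensor; as long as the subsequent code rebinds x with out-of-place ops (x = f(x)) rather than mutating it in place (e.g. x += ...), residual keeps the pre-transformation values:

    # Standalone illustration of tensor aliasing vs. rebinding in PyTorch
    # (not the repository's code).
    import torch

    x = torch.ones(3)
    residual = x       # residual aliases the same storage as x

    x = x * 2          # out-of-place: x is rebound to a NEW tensor
    print(residual)    # tensor([1., 1., 1.]) -> unchanged, skip connection is safe

    y = torch.ones(3)
    alias = y
    y.mul_(2)          # in-place: mutates the shared storage
    print(alias)       # tensor([2., 2., 2.]) -> the alias changed too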

    opened by gopi231091 1
  • Failed to reimplement the exps on iwslt'14 de-en

    Hi, I got some issues reimplementing the models trained on IWSLT'14 de-en. The hyper-parameters for DeLighT (d_m=512) were set following https://github.com/pytorch/fairseq/blob/master/examples/translation/README.md, like:

    CUDA_VISIBLE_DEVICES=0 fairseq-train \
        data-bin/iwslt14.tokenized.de-en \
        --arch transformer_iwslt_de_en --share-decoder-input-output-embed \
        --optimizer adam --adam-betas '(0.9, 0.98)' --clip-norm 0.0 \
        --lr 5e-4 --lr-scheduler inverse_sqrt --warmup-updates 4000 \
        --dropout 0.3 --weight-decay 0.0001 \
        --criterion label_smoothed_cross_entropy --label-smoothing 0.1 \
        --max-tokens 4096 \
        --eval-bleu \
        --eval-bleu-args '{"beam": 5, "max_len_a": 1.2, "max_len_b": 10}' \
        --eval-bleu-detok moses \
        --eval-bleu-remove-bpe \
        --eval-bleu-print-samples \
        --best-checkpoint-metric bleu --maximize-best-checkpoint-metric

    together with the following settings for the DeLighT blocks: --delight-enc-min-depth 3 --delight-enc-max-depth 9 --delight-enc-width-mult 1 --delight-dec-min-depth 3 --delight-dec-max-depth 9 --delight-dec-width-mult 1

    However, the resulting model got 31.2 BLEU, much worse than the 35.3 reported in the manuscript. Furthermore, its parameter count did not match the reported 30M; it is 33M in total.

    I guess the hyper-parameters need to be corrected. Has anyone else run into the same issue?

    opened by CheerM 10