Source code for NAACL 2021 paper "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference"

Overview

TR-BERT

Source code and dataset for "TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference".

[Figure: model architecture]

The code is based on Hugging Face's transformers. Thanks to them! We will release all the source code in the future.

Requirements

Install dependencies and apex:

pip3 install -r requirement.txt
pip3 install --editable transformers
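apex itself is not on PyPI; if you need it for mixed-precision training, one way is to build it from NVIDIA's repository (a Python-only build shown here; see the apex README for CUDA-extension options):

git clone https://github.com/NVIDIA/apex
cd apex
pip3 install -v --no-cache-dir ./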

Pretrained models

Download the DistilBERT-3layer and BERT-1024 from Google Drive/Tsinghua Cloud.

Classification

Download the IMDB, Yelp, 20News datasets from Google Drive/Tsinghua Cloud.

Download the Hyperpartisan dataset, and randomly split it into train/dev/test sets: python3 split_hyperpartisan.py
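split_hyperpartisan.py performs the random split. For reference, a minimal sketch of such a split (the 80/10/10 ratio and file names here are assumptions, not the script's actual ones):

import random

random.seed(42)
with open("hyperpartisan.jsonl") as f:  # hypothetical input file
    lines = f.readlines()
random.shuffle(lines)
n = len(lines)
parts = {"train": lines[: int(0.8 * n)],
         "dev": lines[int(0.8 * n): int(0.9 * n)],
         "test": lines[int(0.9 * n):]}
for split, rows in parts.items():
    with open(f"hyperpartisan_{split}.jsonl", "w") as f:  # hypothetical output files
        f.writelines(rows)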

Train BERT/DistilBERT Model

Use flag --do_train:

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path bert-base-uncased --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 16 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 5  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval  --evaluate_during_training  --do_train

where task_name can be set as imdb/yelp_f/20news/hyperpartisan for different tasks, and model_type can be set as bert/distilbert for different models.
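For example, a DistilBERT run on Hyperpartisan uses the same flags (model and output paths here are illustrative):

python3 run_classification.py  --task_name hyperpartisan  --model_type distilbert  --model_name_or_path distilbert-base-uncased --data_dir hyperpartisan --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 16 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 5  --output_dir hyperpartisan_models/distilbert  --do_lower_case  --do_eval  --evaluate_during_training  --do_train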

Compute Gradient for Residual Strategy

Use flag --do_eval_grad.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_grad

This step doesn't currently support DataParallel or DistributedDataParallel and should be run on a single GPU.
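For intuition, gradient-based token importance can be computed roughly as follows; this is a minimal sketch using a standard Hugging Face model, not the repository's exact implementation:

import torch
from transformers import BertTokenizer, BertForSequenceClassification

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
model = BertForSequenceClassification.from_pretrained("bert-base-uncased")
model.eval()

inputs = tokenizer("an example review to score", return_tensors="pt")
# Run the forward pass from the embedding layer so gradients reach each token.
embeds = model.bert.embeddings.word_embeddings(inputs["input_ids"])
embeds.retain_grad()
logits = model(inputs_embeds=embeds, attention_mask=inputs["attention_mask"]).logits
logits.max().backward()
# The L2 norm of the gradient at each position serves as a token importance score.
importance = embeds.grad.norm(dim=-1).squeeze(0)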

Train the policy network solely

Start from the checkpoint of the task-specific fine-tuned model, change model_type from bert to autobert, and run with flags --do_train --train_rl:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1  --do_lower_case  --do_train --train_rl --alpha 1 --guide_rate 0.5

where alpha is the harmonic coefficient for the length penalty and guide_rate is the proportion of imitation learning steps. model_type can be set as autobert/distilautobert to apply token reduction to BERT/DistilBERT.
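The reward seen by the policy can be pictured as task performance minus an alpha-weighted length term. A toy sketch of this trade-off (the repository's actual reward shaping may differ):

def reward(gold_log_prob, num_kept, num_total, alpha=1.0):
    # gold_log_prob: log-probability of the gold label from the pruned model.
    # alpha: the --alpha flag, trading accuracy against kept-sequence length.
    length_penalty = num_kept / num_total  # fraction of tokens kept
    return gold_log_prob - alpha * length_penalty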

Compute Logits for Knowledge Distillation

Use flag --do_eval_logits.

python3 run_classification.py  --task_name imdb  --model_type bert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8  --output_dir imdb_models/bert_base  --do_lower_case  --do_eval_logits

This step doesn't currently support DataParallel or DistributedDataParallel and should be run on a single GPU.
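The exported logits serve as soft targets for distillation. A standard temperature-scaled distillation loss, shown as a sketch (the temperature value is illustrative):

import torch.nn.functional as F

def distillation_loss(student_logits, teacher_logits, temperature=2.0):
    # KL divergence between temperature-softened distributions, scaled by T^2
    # so gradient magnitudes stay comparable across temperatures.
    t = temperature
    log_p_student = F.log_softmax(student_logits / t, dim=-1)
    p_teacher = F.softmax(teacher_logits / t, dim=-1)
    return F.kl_div(log_p_student, p_teacher, reduction="batchmean") * (t * t)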

Train the whole network with both the task-specific objective and the RL objective

Start from the checkpoint of the --train_rl model and run with flags --do_train --train_both --train_teacher:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1 --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_train --train_both --train_teacher --alpha 1

Evaluate

Use flag --do_eval:

python3 run_classification.py  --task_name imdb  --model_type autobert  --model_name_or_path imdb_models/auto_1_both  --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 1  --output_dir imdb_models/auto_1_both  --do_lower_case  --do_eval --eval_all_checkpoints

When the evaluation batch size is larger than 1, the same number of tokens is retained for every instance in a batch.
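Keeping a shared token count per batch keeps tensors rectangular, so selection reduces to a top-k over the selector's scores. A toy illustration (shapes are illustrative):

import torch

scores = torch.rand(4, 512)      # selector scores for a batch of 4 sequences
k = scores.size(1) // 2          # the same k is used for every instance
keep_idx = scores.topk(k, dim=1).indices.sort(dim=1).values
# keep_idx can now gather a [4, k] slice of hidden states in one tensor op.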

Initialize

For the IMDB dataset, we find that directly initializing the selector with a heuristic objective before training the policy network solely yields slightly better performance. For other datasets, this step makes little difference. Run this step with flags --do_train --train_init:

python3 trans_imdb_rank.py
python3 run_classification.py  --task_name imdb  --model_type initbert  --model_name_or_path imdb_models/bert_base --data_dir imdb --max_seq_length 512  --per_gpu_train_batch_size 8  --per_gpu_eval_batch_size 8 --gradient_accumulation_steps 4 --learning_rate 3e-5 --save_steps 2000  --num_train_epochs 3  --output_dir imdb_models/bert_init  --do_lower_case  --do_train --train_init 
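Conceptually, this step trains the selector to reproduce a precomputed heuristic keep/prune decision per token. A toy sketch of such an objective (the actual heuristic and loss used by trans_imdb_rank.py and --train_init may differ):

import torch.nn.functional as F

def init_loss(selector_logits, heuristic_keep_mask):
    # Binary cross-entropy pulling the selector toward the heuristic
    # ranking (1 = keep the token, 0 = prune it).
    return F.binary_cross_entropy_with_logits(selector_logits, heuristic_keep_mask.float())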

Question Answering

Download the SQuAD 2.0 dataset.

Download the MRQA dataset with our split from Google Drive/Tsinghua Cloud.

Download the HotpotQA dataset from the Transformer-XH repository, where paragraphs are retrieved for each question according to TF-IDF, entity linking and hyperlinks, and re-ranked by a BERT re-ranker.

Download the TriviaQA dataset, where paragraphs are re-ranked by the linear passage re-ranker in DocQA.

Download the WikiHop dataset.

The whole training process for question answering models is similar to that for text classification models, with flags --do_train, --do_train --train_rl, and --do_train --train_both --train_teacher in turn. The code for each dataset:

SQuAD: run_squad.py with flag --version_2_with_negative (see the example command after this list)

NewsQA / NaturalQA: run_mrqa.py

RACE: run_race_classify.py

HotpotQA: run_hotpotqa.py

TriviaQA: run_triviaqa.py

WikiHop: run_wikihop.py
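For reference, an illustrative first-stage SQuAD command, assuming the standard run_squad.py interface (hyperparameters and paths here are illustrative, not the paper's):

python3 run_squad.py  --model_type bert  --model_name_or_path bert-base-uncased  --version_2_with_negative  --train_file train-v2.0.json  --predict_file dev-v2.0.json  --max_seq_length 384  --per_gpu_train_batch_size 8  --gradient_accumulation_steps 4  --learning_rate 3e-5  --num_train_epochs 2  --output_dir squad_models/bert_base  --do_lower_case  --do_train  --do_eval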

Harmonic Coefficient Lambda

The example harmonic coefficients are shown as follows:

Dataset         train_rl   train_both
SQuAD 2.0       5          5
NewsQA          3          5
NaturalQA       2          2
RACE            0.5        0.1
Yelp.F          2          0.5
20News          1          1
IMDB            1          1
HotpotQA        0.1        4
TriviaQA        0.5        1
Hyperpartisan   0.01       0.01

Cite

If you use the code, please cite this paper:

@inproceedings{ye2021trbert,
  title={TR-BERT: Dynamic Token Reduction for Accelerating BERT Inference},
  author={Deming Ye and Yankai Lin and Yufei Huang and Maosong Sun},
  booktitle={Proceedings of NAACL 2021},
  year={2021}
}