TRICE: a task-agnostic transferring framework for multi-source sequence generation
This is the source code of our work Transfer Learning for Sequence Generation: from Single-source to Multi-source (ACL 2021).
We propose TRICE, a task-agnostic Transferring fRamework for multI-sourCe sEquence generation, for transferring pretrained models to multi-source sequence generation tasks (e.g., automatic post-editing, multi-source translation, and multi-document summarization). TRICE achieves new state-of-the-art results on the WMT17 APE task and on the multi-source translation task evaluated with the WMT14 test set. You are also welcome to take a quick look at our blog.
The implementation is on top of the open-source NMT toolkit THUMT.
If you find this work useful, please cite our paper:

@misc{huang2021transfer,
title={Transfer Learning for Sequence Generation: from Single-source to Multi-source},
author={Xuancheng Huang and Jingfang Xu and Maosong Sun and Yang Liu},
year={2021},
eprint={2105.14809},
archivePrefix={arXiv},
primaryClass={cs.CL}
}
Prerequisites
- Python >= 3.6
- tensorflow-cpu >= 2.0
- torch >= 1.7
- transformers >= 3.4
- sentencepiece >= 0.1
Pretrained model
We adopt mbart-large-cc25 in our experiments. Other sequence-to-sequence pretrained models can also be used with only a few modifications.
If your GPUs do not have enough memory, you can prune the original large vocabulary (250k) to a small vocabulary (e.g., 3k) with little performance loss.
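The pruning procedure itself is not spelled out here. As a minimal sketch of one possible approach (an assumption, not the official script), one could collect the mBART subwords that actually occur in the task data with the Hugging Face tokenizer and write them to a reduced vocabulary file; the exact file format expected by THUMT/TRICE may differ.

# Hypothetical sketch: collect the mBART subwords that occur in the task data
# and write them to a reduced vocabulary file (one token per line).
# Not the official pruning script; adjust the output format to what THUMT expects.
import sys
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")

used = set()
for path in sys.argv[1:-1]:          # corpus files (sources and target)
    with open(path, encoding="utf-8") as f:
        for line in f:
            used.update(tokenizer.tokenize(line.strip()))

# Always keep the special tokens and the language codes you need.
used.update(["<pad>", "<s>", "</s>", "<unk>", "en_XX", "de_DE"])

with open(sys.argv[-1], "w", encoding="utf-8") as out:  # pruned vocab file
    for token in sorted(used):
        out.write(token + "\n")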
Finetuning
Single-source finetuning
PYTHONPATH=${path_to_TRICE} \
python ${path_to_TRICE}/thumt/bin/trainer.py \
--input ${train_src1} ${train_src2} ${train_trg} \
--vocabulary ${vocab_joint} ${vocab_joint} \
--validation ${dev_src1} ${dev_src2} \
--references ${dev_ref} \
--model transformer --half --hparam_set big \
--output single_finetuned \
--parameters \
fixed_batch_size=false,batch_size=820,train_steps=120000,update_cycle=5,device_list=[0,1,2,3],\
keep_checkpoint_max=2,save_checkpoint_steps=2000,\
eval_steps=2001,decode_alpha=1.0,decode_batch_size=16,keep_top_checkpoint_max=1,\
attention_dropout=0.1,relu_dropout=0.1,residual_dropout=0.1,learning_rate=5e-05,warmup_steps=4000,initial_learning_rate=5e-8,\
separate_encode=false,separate_cross_att=false,segment_embedding=false,\
input_type="single_random",adapter_type="None",num_fine_encoder_layers=0,normalization="after",\
src_lang_tok="en_XX",hyp_lang_tok="de_DE",tgt_lang_tok="de_DE",mbart_model_code="facebook/mbart-large-cc25",\
spm_path="sentence.bpe.model",pad="<pad>",bos="<s>",eos="</s>",unk="<unk>"
Multi-source finetuning
PYTHONPATH=${path_to_TRICE} \
python ${path_to_TRICE}/thumt/bin/trainer.py \
--input ${train_src1} ${train_src2} ${train_tgt} \
--vocabulary ${vocab_joint} ${vocab_joint} \
--validation ${dev_src1} ${dev_src2} \
--references ${dev_ref} \
--model transformer --half --hparam_set big \
--checkpoint single_finetuned/eval/model-best.pt \
--output multi_finetuned \
--parameters \
fixed_batch_size=false,batch_size=820,train_steps=120000,update_cycle=5,device_list=[0,1,2,3],\
keep_checkpoint_max=2,save_checkpoint_steps=2000,\
eval_steps=2001,decode_alpha=1.0,decode_batch_size=16,keep_top_checkpoint_max=1,\
attention_dropout=0.1,relu_dropout=0.1,residual_dropout=0.1,learning_rate=5e-05,warmup_steps=4000,initial_learning_rate=5e-8,special_learning_rate=5e-04,special_var_name="adapter",\
separate_encode=false,separate_cross_att=true,segment_embedding=true,\
input_type="",adapter_type="Cross-attn",num_fine_encoder_layers=1,normalization="after",\
src_lang_tok="en_XX",hyp_lang_tok="de_DE",tgt_lang_tok="de_DE",mbart_model_code="facebook/mbart-large-cc25",\
spm_path="sentence.bpe.model",pad="<pad>",bos="<s>",eos="</s>",unk="<unk>"
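For reference, the effective batch size of the commands above follows from batch_size, update_cycle, and the number of GPUs in device_list. Assuming THUMT's usual convention that batch_size counts tokens per GPU when fixed_batch_size=false and that update_cycle accumulates gradients before each optimizer step (an assumption about the toolkit, not stated here), the arithmetic is:

# Rough effective batch size for the commands above, assuming batch_size is
# measured in tokens per GPU (fixed_batch_size=false) and update_cycle
# accumulates gradients before each optimizer step.
batch_size = 820          # tokens per GPU per forward pass
update_cycle = 5          # gradient accumulation steps
num_gpus = 4              # len(device_list)

effective_tokens_per_update = batch_size * update_cycle * num_gpus
print(effective_tokens_per_update)  # 16400 tokens per optimizer step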
Arguments to be explained
- special_learning_rate: if a variable's name contains special_var_name, its learning rate will be special_learning_rate. We give the fine encoder a larger learning rate.
- separate_encode: whether to encode multiple sources separately before the fine encoder.
- separate_cross_att: whether to use the separated cross-attention described in our paper.
- segment_embedding: whether to use the sinusoidal segment embedding described in our paper.
- input_type: "single_random" for single-source finetuning, "" for multi-source finetuning.
- adapter_type: "None" for no fine encoder, "Cross-attn" for a fine encoder with cross-attention.
- num_fine_encoder_layers: number of fine encoder layers.
- src_lang_tok: language token for the first source sentence. Please refer to here for the language tokens of all 25 languages (see also the sketch below).
- hyp_lang_tok: language token for the second source sentence.
- tgt_lang_tok: language token for the target sentence.
- mbart_model_code: model code for transformers.
- spm_path: sentence piece model (can be downloaded from here).
Explanations for other arguments can be found in the user manual of THUMT.
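For convenience, the language tokens accepted by src_lang_tok, hyp_lang_tok, and tgt_lang_tok can also be listed directly from the Hugging Face tokenizer; the short sketch below assumes the MBartTokenizer API from transformers (its lang_code_to_id attribute), not anything specific to this repository.

# List the 25 language tokens of mbart-large-cc25 via the Hugging Face tokenizer.
from transformers import MBartTokenizer

tokenizer = MBartTokenizer.from_pretrained("facebook/mbart-large-cc25")
print(sorted(tokenizer.lang_code_to_id))  # e.g., ['ar_AR', 'cs_CZ', ..., 'en_XX', ..., 'zh_CN']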
Inference
PYTHONPATH=${path_to_TRICE} \
python ${path_to_TRICE}/thumt/bin/translator.py \
--input ${test_src1} ${test_src2} --output ${test_tgt} \
--vocabulary ${vocab_joint} ${vocab_joint} \
--checkpoints multi_finetuned/eval/model-best.pt \
--model transformer --half \
--parameters device_list=[0,1,2,3],decode_alpha=1.0,decode_batch_size=32
# recover sentence piece tokenization
...
# calculate BLEU
...
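The post-processing scripts are left out above. As one possible approach (an assumption, not the authors' pipeline), the SentencePiece model shipped with mBART can undo the subword segmentation and sacrebleu (not listed in the prerequisites) can score the result; the file names below are placeholders.

# Hypothetical post-processing sketch: undo SentencePiece segmentation and
# compute BLEU with sacrebleu. Not the authors' scripts; adjust paths to your setup.
import sentencepiece as spm
import sacrebleu

sp = spm.SentencePieceProcessor()
sp.Load("sentence.bpe.model")

def detok(line):
    # Join the subword pieces back into a plain sentence.
    return sp.DecodePieces(line.strip().split())

with open("test.out", encoding="utf-8") as f:    # system output (${test_tgt})
    hyps = [detok(l) for l in f]
with open("test.ref", encoding="utf-8") as f:    # plain-text references
    refs = [l.strip() for l in f]

print(sacrebleu.corpus_bleu(hyps, [refs]).score)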
Contact
If you have questions, suggestions, or bug reports, please email [email protected].