About this repository
This repo contains a PyTorch implementation of the ACL 2017 paper Get To The Point: Summarization with Pointer-Generator Networks. The code framework is based on TextBox.
Environment
python >= 3.8.11
torch >= 1.6.0
Run `install.sh` to install the other requirements.
Dataset
The processed dataset can be downloaded from Google Drive. Once finished, unzip the data files (`train.src`, `train.tgt`, ...) to `./data`.
An overview of the dataset: `train`: 287,113 cases, `dev`: 13,368 cases, `test`: 11,490 cases.
Parameters
```yaml
# overall settings
data_path: 'data/'
checkpoint_dir: 'saved/'
generated_text_dir: 'generated/'

# dataset settings
max_vocab_size: 50000
src_len: 400
tgt_len: 100

# model settings
decoding_strategy: 'beam_search'
beam_size: 4
is_attention: True
is_pgen: True
is_coverage: True
cov_loss_lambda: 1.0
```
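The `cov_loss_lambda` setting scales the coverage loss from the paper, which penalizes repeated attention: at each decoder step the loss adds min(attention, coverage) per source token, where coverage is the running sum of past attention distributions. A minimal pure-Python sketch of the idea (the repository works on PyTorch tensors; the names here are illustrative, not the repo's API):

```python
# Illustrative coverage loss (See et al., 2017): coverage is the running sum
# of previous attention distributions; the per-step penalty is
# sum_i min(a_t[i], coverage[i]), scaled by cov_loss_lambda.
# Plain lists are used for clarity; the repo itself uses PyTorch tensors.
def coverage_loss(attn_dists, cov_loss_lambda=1.0):
    """attn_dists: list over decoder steps; each entry is an attention
    distribution over source tokens (weights summing to 1)."""
    coverage = [0.0] * len(attn_dists[0])
    loss = 0.0
    for a_t in attn_dists:
        loss += sum(min(a, c) for a, c in zip(a_t, coverage))
        coverage = [a + c for a, c in zip(a_t, coverage)]
    return cov_loss_lambda * loss

# The first step incurs no penalty; attending twice to the same token does.
print(coverage_loss([[1.0, 0.0], [1.0, 0.0]]))  # 1.0
```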
Log files are located in `./log`; more details can be found in the YAML files.
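Since `decoding_strategy` is `beam_search` with `beam_size: 4`, decoding keeps the four highest-scoring partial summaries at each step. A toy sketch of the idea, with a hypothetical `step_scores` function standing in for the decoder (none of these names are the repository's actual API):

```python
import math

# Toy beam search: step_scores(last_token) returns {next_token: log_prob}.
# Everything here is illustrative, not the repo's implementation.
def beam_search(step_scores, beam_size=4, steps=3):
    beams = [([0], 0.0)]  # (token sequence starting at BOS=0, log-prob score)
    for _ in range(steps):
        candidates = []
        for seq, score in beams:
            for tok, lp in step_scores(seq[-1]).items():
                candidates.append((seq + [tok], score + lp))
        # keep only the beam_size best-scoring hypotheses
        beams = sorted(candidates, key=lambda x: x[1], reverse=True)[:beam_size]
    return beams[0][0]

# With a fixed distribution, the highest-probability path wins.
best = beam_search(lambda t: {1: math.log(0.7), 2: math.log(0.3)})
print(best)  # [0, 1, 1, 1]
```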
Note: Distributed Data Parallel (DDP) is not supported yet.
Train & Evaluation
Run `fire.py`. To train from scratch, use:

```python
if __name__ == '__main__':
    config = Config(config_dict={'test_only': False,
                                 'load_experiment': None})
    train(config)
```
If you want to resume from a checkpoint, set `'load_experiment': './saved/$model_name$.pth'`. Similarly, when `'test_only'` is set to `True`, `'load_experiment'` is required.
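In short: `load_experiment` is optional for training but mandatory for evaluation. The constraint can be sketched with a hypothetical helper (`check_config` is illustrative; the real repository handles this inside its `Config` class):

```python
# Hypothetical stand-in for the config rules described above; not the
# repository's actual Config class.
def check_config(config_dict):
    if config_dict.get('test_only') and not config_dict.get('load_experiment'):
        raise ValueError("'load_experiment' is required when 'test_only' is True")
    return config_dict

check_config({'test_only': False, 'load_experiment': None})  # train from scratch
check_config({'test_only': True,
              'load_experiment': './saved/$model_name$.pth'})  # evaluation only
```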
Results
The best model was trained on a single TITAN Xp GPU (about 8 GB of memory used).
Training loss
Ablation study
| Model | ROUGE-1 | ROUGE-2 | ROUGE-L |
|---|---|---|---|
| Seq2Seq | 22.17 | 7.20 | 20.97 |
| Seq2Seq+attn | 29.35 | 12.58 | 27.38 |
| Seq2Seq+attn+pgen | 36.04 | 15.87 | 32.92 |
| Seq2Seq+attn+pgen+coverage | 39.52 | 17.85 | 36.40 |
Note: The architecture of the Seq2Seq model is based on an LSTM; I hope to replace it with a Transformer in the future.