Repo for Enhanced Seq2Seq Autoencoder via Contrastive Learning for Abstractive Text Summarization

Overview

ESACL: Enhanced Seq2Seq Autoencoder via Contrastive Learning for Abstractive Text Summarization

This repo accompanies our paper "Enhanced Seq2Seq Autoencoder via Contrastive Learning for Abstractive Text Summarization". Our code builds on the seq2seq examples in the Huggingface transformers framework; see their repo at: https://github.com/huggingface/transformers/tree/master/examples/seq2seq.

Local Setup

Tested with Python 3.7 in a virtual environment. Clone the repo, change into the repo folder, set up the virtual environment, and install the required packages:

$ python3.7 -m venv venv
$ source venv/bin/activate
$ pip install -r requirements.txt

Install apex

According to the Huggingface seq2seq examples, both fine-tuning and evaluation are 30% faster with --fp16. Using that flag requires apex:

$ git clone https://github.com/NVIDIA/apex
$ cd apex
$ pip install -v --disable-pip-version-check --no-cache-dir --global-option="--cpp_ext" --global-option="--cuda_ext" ./
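
Before passing --fp16, it may help to confirm that apex is importable from the active virtual environment. The following is a minimal sketch; the helper name has_module is hypothetical and not part of this repo.

```python
import importlib.util


def has_module(name: str) -> bool:
    """Return True if a top-level module `name` can be found in this environment."""
    return importlib.util.find_spec(name) is not None


if __name__ == "__main__":
    # "apex" is the package installed by the commands above.
    if has_module("apex"):
        print("apex found: --fp16 should be usable")
    else:
        print("apex not found: install it or drop --fp16")
```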

Data

Create a directory for data used in this work named data:

$ mkdir data

CNN/DM

$ wget https://cdn-datasets.huggingface.co/summarization/cnn_dm_v2.tgz
$ tar -xzvf cnn_dm_v2.tgz
$ mv cnn_cln data/cnndm

XSUM

$ wget https://cdn-datasets.huggingface.co/summarization/xsum.tar.gz
$ tar -xzvf xsum.tar.gz
$ mv xsum data/xsum
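
After extraction, a quick check that each data directory is complete can save a failed training run. The sketch below assumes the layout used by the Huggingface seq2seq examples (train, val, and test splits, each with a .source and a .target file); missing_files is a hypothetical helper, not part of this repo.

```python
from pathlib import Path

# Assumed layout: {train,val,test}.{source,target} inside each data directory.
SPLITS = ("train", "val", "test")
SUFFIXES = ("source", "target")


def missing_files(data_dir):
    """Return the expected file names that are absent from data_dir."""
    root = Path(data_dir)
    expected = [f"{split}.{suffix}" for split in SPLITS for suffix in SUFFIXES]
    return [name for name in expected if not (root / name).is_file()]


if __name__ == "__main__":
    for data_dir in ("data/cnndm", "data/xsum"):
        missing = missing_files(data_dir)
        status = "OK" if not missing else f"missing: {missing}"
        print(f"{data_dir}: {status}")
```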

Generate Augmented Dataset

$ python generate_augmentation.py \
    --dataset xsum \
    --n 5 \
    --augmentation1 randomdelete \
    --augmentation2 randomswap
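
The --augmentation1 and --augmentation2 flags name two perturbations, each producing one view of a document. As a rough illustration of what random deletion and random swap could look like, here is a sketch. It is not the code in generate_augmentation.py: the function names, the sentence-level granularity, and the reading of n as "number of sentences affected" are all assumptions made for this example.

```python
import random


def random_delete(sentences, n, rng):
    """Drop up to n randomly chosen sentences, always keeping at least one."""
    n = min(n, len(sentences) - 1)
    if n <= 0:
        return list(sentences)
    drop = set(rng.sample(range(len(sentences)), n))
    return [s for i, s in enumerate(sentences) if i not in drop]


def random_swap(sentences, n, rng):
    """Swap n randomly chosen pairs of sentences; the set of sentences is unchanged."""
    out = list(sentences)
    if len(out) < 2:
        return out
    for _ in range(n):
        i, j = rng.sample(range(len(out)), 2)
        out[i], out[j] = out[j], out[i]
    return out


if __name__ == "__main__":
    rng = random.Random(0)  # fixed seed so the run is repeatable
    doc = ["Sentence one.", "Sentence two.", "Sentence three.", "Sentence four."]
    print(random_delete(doc, 2, rng))  # first view: two sentences removed
    print(random_swap(doc, 2, rng))    # second view: same sentences, reordered
```

Passing an explicit random.Random instance keeps the two views reproducible without touching the global random state.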

Training

CNN/DM

Our model is warmed up using sshleifer/distilbart-cnn-12-6:

$ DATA_DIR=./data/cnndm-augmented/RandominsertionRandominsertion-NumSent-3
$ OUTPUT_DIR=./log/cnndm

$ python -m torch.distributed.launch --nproc_per_node=3  cl_finetune_trainer.py \
  --data_dir $DATA_DIR \
  --output_dir $OUTPUT_DIR \
  --learning_rate=5e-7 \
  --per_device_train_batch_size 16 \
  --per_device_eval_batch_size 16 \
  --do_train --do_eval \
  --evaluation_strategy steps \
  --freeze_embeds \
  --save_total_limit 10 \
  --save_steps 1000 \
  --logging_steps 1000 \
  --num_train_epochs 5 \
  --model_name_or_path sshleifer/distilbart-cnn-12-6 \
  --alpha 0.2 \
  --temperature 0.5 \
  --freeze_encoder_layer 6 \
  --prediction_loss_only \
  --fp16

XSUM

$ DATA_DIR=./data/xsum-augmented/RandomdeleteRandomswap-NumSent-3
$ OUTPUT_DIR=./log/xsum

$ python -m torch.distributed.launch --nproc_per_node=3  cl_finetune_trainer.py \
  --data_dir $DATA_DIR \
  --output_dir $OUTPUT_DIR \
  --learning_rate=5e-7 \
  --per_device_train_batch_size 16 \
  --per_device_eval_batch_size 16 \
  --do_train --do_eval \
  --evaluation_strategy steps \
  --freeze_embeds \
  --save_total_limit 10 \
  --save_steps 1000 \
  --logging_steps 1000 \
  --num_train_epochs 5 \
  --model_name_or_path sshleifer/distilbart-xsum-12-6 \
  --alpha 0.2 \
  --temperature 0.5 \
  --freeze_encoder \
  --prediction_loss_only \
  --fp16
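
The --temperature and --alpha flags in the commands above control the contrastive objective. As an illustration only, the sketch below computes an NT-Xent style contrastive loss over two augmented views and mixes it with a generation loss. The mixing rule alpha * contrastive + (1 - alpha) * generation, the use of cosine similarity, and all function names are assumptions for this example; see cl_seq2seq_trainer.py for what the repo actually does.

```python
import math


def cosine(u, v):
    """Cosine similarity of two nonzero vectors given as lists of floats."""
    dot = sum(a * b for a, b in zip(u, v))
    norm = math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v))
    return dot / norm


def nt_xent(view1, view2, temperature):
    """NT-Xent loss for a batch where view1[i] and view2[i] form a positive pair."""
    reps = list(view1) + list(view2)
    n = len(view1)
    total = 0.0
    for i in range(2 * n):
        pos = (i + n) % (2 * n)  # index of the other view of the same example
        # Similarities to every other representation (self is excluded).
        logits = [cosine(reps[i], reps[k]) / temperature
                  for k in range(2 * n) if k != i]
        pos_logit = cosine(reps[i], reps[pos]) / temperature
        total += math.log(sum(math.exp(x) for x in logits)) - pos_logit
    return total / (2 * n)


def combined_loss(generation_loss, contrastive_loss, alpha):
    """Assumed mixing rule: weight alpha on the contrastive term."""
    return alpha * contrastive_loss + (1.0 - alpha) * generation_loss


if __name__ == "__main__":
    v1 = [[1.0, 0.0], [0.0, 1.0]]
    v2 = [[1.0, 0.1], [0.1, 1.0]]
    cl = nt_xent(v1, v2, temperature=0.5)
    print(combined_loss(generation_loss=2.3, contrastive_loss=cl, alpha=0.2))
```

A lower temperature sharpens the softmax over similarities, so mismatched pairs are penalized more strongly.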

Evaluation

We have released the following checkpoints for pre-trained models as described in the paper:

CNN/DM

CNN/DM requires an extra postprocessing step.

$ export DATA=cnndm
$ export DATA_DIR=data/$DATA
$ export CHECKPOINT_DIR=./log/$DATA
$ export OUTPUT_DIR=output/$DATA

$ python -m torch.distributed.launch --nproc_per_node=2  run_distributed_eval.py \
    --model_name sshleifer/distilbart-cnn-12-6  \
    --save_dir $OUTPUT_DIR \
    --data_dir $DATA_DIR \
    --bs 16 \
    --fp16 \
    --use_checkpoint \
    --checkpoint_path $CHECKPOINT_DIR
    
$ python postprocess_cnndm.py \
    --src_file $OUTPUT_DIR/test_generations.txt \
    --tgt_file $DATA_DIR/test.target

XSUM

$ export DATA=xsum
$ export DATA_DIR=data/$DATA
$ export CHECKPOINT_DIR=./log/$DATA
$ export OUTPUT_DIR=output/$DATA

$ python -m torch.distributed.launch --nproc_per_node=3  run_distributed_eval.py \
    --model_name sshleifer/distilbart-xsum-12-6  \
    --save_dir $OUTPUT_DIR \
    --data_dir $DATA_DIR \
    --bs 16 \
    --fp16 \
    --use_checkpoint \
    --checkpoint_path $CHECKPOINT_DIR
Owner
Rachel Zheng