Source code for "Structure-Aware Abstractive Conversation Summarization via Discourse and Action Graphs"

Overview

Structure-Aware-BART

This repo contains the code for the following paper:

Jiaao Chen, Diyi Yang: Structure-Aware Abstractive Conversation Summarization via Discourse and Action Graphs, NAACL 2021

If you would like to refer to it, please cite the paper mentioned above.

Getting Started

These instructions will get you running the code for Structure-Aware-BART conversation summarization.

Requirements

Note that different ROUGE packages, or different versions of the same package, might produce different ROUGE scores. For transformers, we used the version released on Oct. 7, 2020; newer versions might also lead to different performance.

Install the transformers with S-BART

cd transformers

pip install --editable ./
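
To confirm that this editable install (rather than a PyPI release) is the version Python actually imports, a quick optional check like the following can help; it is only a suggested sanity check, not part of the original setup.

    # Optional sanity check: make sure the local S-BART fork of transformers is
    # the package that gets imported, not a release installed from PyPI.
    import transformers

    print(transformers.__version__)
    print(transformers.__file__)  # should point inside this repo's ./transformers directory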

Downloading the data

Please download the dataset (including the pre-processed graphs) and put it in the data folder here

Pre-processing the data

The data folder you download from the above link already contains all the pre-processed files (including the extracted graphs) from the SAMSum corpus.

Extract Discourse Graphs

Here we use the data and code from here to pre-train a conversation discourse parser, and then use that parser to extract discourse graphs for the SAMSum dataset.
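
For orientation, the extracted discourse graphs are consumed by the model as utterance-level adjacency and relation inputs (cf. the adj and relation arguments passed in src/train.py). The snippet below is only a minimal sketch of that conversion under assumed formats; the pre-processed graphs are already included in the downloaded data, and the link format, the relation2id mapping, and the links_to_graph helper are hypothetical.

    # Hypothetical sketch: convert parsed discourse links into utterance-level
    # adjacency and relation-id matrices of the kind the model consumes.
    import torch

    def links_to_graph(num_utts, links, relation2id):
        """links: list of (head_idx, tail_idx, relation_name) from the discourse parser.
        relation2id: mapping from relation name to a positive integer id (0 = no edge)."""
        adj = torch.zeros(num_utts, num_utts, dtype=torch.long)
        rel = torch.zeros(num_utts, num_utts, dtype=torch.long)
        for head, tail, name in links:
            adj[head, tail] = 1
            rel[head, tail] = relation2id[name]
        return adj, rel

    # Example: a 3-utterance conversation with two parsed links.
    adj, rel = links_to_graph(3, [(0, 1, "QAP"), (1, 2, "Comment")], {"QAP": 1, "Comment": 2})
    print(adj)
    print(rel)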

Extract Action Graphs

Please go through ./src/data/extract_actions.ipynb to extract action graphs.
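
As a rough illustration of what the notebook produces, action graphs are built from "who-did-what" triples extracted from the utterances. The sketch below uses spaCy dependency parsing purely for illustration; the actual notebook may rely on different tooling and rules, and the extract_action_triples helper and its output format are assumptions.

    # Illustrative only: extract (subject, verb, object) action triples from one
    # utterance with spaCy. The actual rules live in ./src/data/extract_actions.ipynb.
    import spacy

    nlp = spacy.load("en_core_web_sm")

    def extract_action_triples(utterance):
        """Return (subject, verb lemma, object) triples found in one utterance."""
        triples = []
        for token in nlp(utterance):
            if token.pos_ == "VERB":
                subjects = [c for c in token.children if c.dep_ in ("nsubj", "nsubjpass")]
                objects = [c for c in token.children if c.dep_ in ("dobj", "obj", "attr")]
                for s in subjects:
                    for o in objects:
                        triples.append((s.text, token.lemma_, o.text))
        return triples

    print(extract_action_triples("Amanda baked cookies for Jerry."))
    # expected something like [('Amanda', 'bake', 'cookies')], depending on the parser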

Training models

This section contains instructions for training the conversation summarization models.

The generated summaries on the test set for the baseline BART model and for S-BART (both trained with seed 42) are in the ./src/baseline and ./src/composit folders, respectively.

The wandb training logs for S-BART with different seeds (0, 1, 42) are provided in ./src/Weights&Biases.pdf

Training baseline BART model

Please run ./train_base.sh to train the BART baseline models.

Training S-BART model

Please run ./train_multi_graph.sh to train the S-BART model.

Evaluating models

An example Jupyter notebook (./src/eval.ipynb) is provided for evaluating the model on the test set.
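
If you only need test-set ROUGE scores outside the notebook, a minimal sketch along the following lines works with the rouge pip package; the file paths are placeholders, and (as noted in Requirements) different ROUGE packages and versions can yield different scores.

    # Minimal ROUGE sketch (see ./src/eval.ipynb for the full evaluation procedure).
    # The file paths below are placeholders, not files shipped with this repo.
    from rouge import Rouge

    with open("generated_summaries.txt") as f:
        hyps = [line.strip() for line in f]
    with open("reference_summaries.txt") as f:
        refs = [line.strip() for line in f]

    scores = Rouge().get_scores(hyps, refs, avg=True)
    print({name: round(metrics["f"], 4) for name, metrics in scores.items()})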

Comments
  • ADSC dataset

    Hi,

    I tried several runs of the S-BART w. Discourse&Action (last row of Table 2) on the SAMSum dataset, and the results are sometimes higher/lower than the BART baseline, which is probably caused by random seeds. Nevertheless, the performance gain on the ADSC dataset seems to be more substantial according to Table 3. I was wondering if it's possible to also release the code to test the model on that dataset?

    Thanks.

    opened by chijames 8
  • How to batch training samples?

    Hello, thanks for the nice paper and code! I notice that the decoder attends to every utterance after the BART encoder. So I am wondering how to batch these, since the input to the model is List[torch.Tensor], where len(list) = batch_size and torch.Tensor.shape = [num_utts_in_one_conversation, seq_len]? (A generic padding sketch is given after this issue list.)

    In your code

    loss_tensors = self._step(batch) # in SummarizationModule/train_step
    outputs = self(src_ids, attention_mask=src_mask, decoder_input_ids=decoder_input_ids, use_cache=False, discourse_graph = self.hparams.discourse_graph, adj = adj, segmented_encoder = self.hparams.segmented_encoder, relation = self.hparams.relation, action_adj = action_adj, actions = actions, actions_mask = actions_mask,) # in SummarizationModule/_step
    

    But it seems that BaseTransformer does not define a forward, train_step, or __call__ function? Thanks for your help.

    opened by Hannibal046 3
  • Training multi-model

    @jiaaoc Used versions: torch==1.3.0, pytorch-lightning==0.8.1, rouge==1.0.0, transformers==3.2.0 (installed with pip install --editable 'transformers/'). Running the bash script train_multi_graph.sh throws the following error:

    GPU available: True, used: True
    INFO:lightning:GPU available: True, used: True
    TPU available: False, using: 0 TPU cores
    INFO:lightning:TPU available: False, using: 0 TPU cores
    CUDA_VISIBLE_DEVICES: [0]
    INFO:lightning:CUDA_VISIBLE_DEVICES: [0]
    normal graph
    dicsource_graph
    Traceback (most recent call last):
      File "/Structure-Aware-BART-main/src/train.py", line 540, in <module>
        main(args)
      File "/Structure-Aware-BART-main/src/train.py", line 506, in main
        logger=logger,
      File "/Structure-Aware-BART-main/src/lightning_base.py", line 700, in generic_train
        trainer.fit(model)
      File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/trainer.py", line 918, in fit
        self.single_gpu_train(model)
      File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/distrib_parts.py", line 167, in single_gpu_train
        self.optimizers, self.lr_schedulers, self.optimizer_frequencies = self.init_optimizers(model)
      File "/usr/local/lib/python3.7/dist-packages/pytorch_lightning/trainer/optimizers.py", line 18, in init_optimizers
        optim_conf = model.configure_optimizers()
      File "/Structure-Aware-BART-main/src/lightning_base.py", line 181, in configure_optimizers
        new_params_id += list(map(id, model.model.discourse_encoder.parameters()))  +\
      File "/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py", line 585, in __getattr__
        type(self).__name__, name))
    AttributeError: 'BartModel' object has no attribute 'discourse_encoder'
    
    opened by atharvou23 2
  • Error when loading data

    Bash script train_base.sh throws the following error:

    Global seed set to 42
    Traceback (most recent call last):
      File "train.py", line 540, in <module>
        main(args)
      File "train.py", line 502, in main
        checkpoint_callback=get_checkpoint_callback(
      File "/export/home/dialogue-sum/sa-bart/src/callbacks.py", line 100, in get_checkpoint_callback
        checkpoint_callback = ModelCheckpoint(
    TypeError: __init__() got an unexpected keyword argument 'filepath'

    Could be caused by an inconsistent pytorch_lightning version. The README does not specify which version the authors used.

    opened by muggin 2
  • error when run train_base.sh

    When I run ./train_base.sh, I get the following error. Do you have any idea?

    Traceback (most recent call last):
      File "train.py", line 543, in <module>
        main(args)
      File "train.py", line 523, in main
        trainer.test()
      File "/home/wxu/anaconda3/envs/bart/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 712, in test
        results = self.__test_using_best_weights(ckpt_path, test_dataloaders)
      File "/home/wxu/anaconda3/envs/bart/lib/python3.7/site-packages/pytorch_lightning/trainer/trainer.py", line 724, in __test_using_best_weights
        'ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.'
    pytorch_lightning.utilities.exceptions.MisconfigurationException: ckpt_path is "best", but ModelCheckpoint is not configured to save the best model.

    opened by xwjim 1
  • BART-Large

    Hi,

    I notice that in one earlier paper of yours, the decoder is BART-large, which leads to much better performance. Have you tried using BART-large in this paper? If not, why?

    Thanks.

    opened by chijames 1
  • Warning when loading multi-graph model

    The script throws the following warning when running the multi-graph training.

    Some weights of BartForConditionalGeneration were not initialized from the model checkpoint at facebook/bart-base and are newly initialized: ['model.decoder.layers.0.resweight', 'model.decoder.layers.0.resweight_2', 'model.decoder.layers.0.discourse_attn.k_proj.weight', 'model.decoder.layers.0.discourse_attn.k_proj.bias', 'model.decoder.layers.0.discourse_attn.v_proj.weight', 'model.decoder.layers.0.discourse_attn.v_proj.bias', 'model.decoder.layers.0.discourse_attn.q_proj.weight', 'model.decoder.layers.0.discourse_attn.q_proj.bias', 'model.decoder.layers.0.discourse_attn.out_proj.weight', 'model.decoder.layers.0.discourse_attn.out_proj.bias', 'model.decoder.layers.0.discourse_attn_layer_norm.weight', 'model.decoder.layers.0.discourse_attn_layer_norm.bias', 'model.decoder.layers.0.action_attn.k_proj.weight', 'model.decoder.layers.0.action_attn.k_proj.bias', 'model.decoder.layers.0.action_attn.v_proj.weight', 'model.decoder.layers.0.action_attn.v_proj.bias', 'model.decoder.layers.0.action_attn.q_proj.weight', 'model.decoder.layers.0.action_attn.q_proj.bias', 'model.decoder.layers.0.action_attn.out_proj.weight', 'model.decoder.layers.0.action_attn.out_proj.bias', 'model.decoder.layers.0.action_attn_layer_norm.weight', 'model.decoder.layers.0.action_attn_layer_norm.bias', 'model.decoder.layers.0.composit_layer.weight', 'model.decoder.layers.0.composit_layer.bias', 'model.decoder.layers.0.composit_layer_norm.weight', 'model.decoder.layers.0.composit_layer_norm.bias', 'model.decoder.layers.1.resweight', 'model.decoder.layers.1.resweight_2', 'model.decoder.layers.1.discourse_attn.k_proj.weight', 'model.decoder.layers.1.discourse_attn.k_proj.bias', 'model.decoder.layers.1.discourse_attn.v_proj.weight', 'model.decoder.layers.1.discourse_attn.v_proj.bias', 'model.decoder.layers.1.discourse_attn.q_proj.weight', 'model.decoder.layers.1.discourse_attn.q_proj.bias', 'model.decoder.layers.1.discourse_attn.out_proj.weight', 'model.decoder.layers.1.discourse_attn.out_proj.bias', 'model.decoder.layers.1.discourse_attn_layer_norm.weight', 'model.decoder.layers.1.discourse_attn_layer_norm.bias', 'model.decoder.layers.1.action_attn.k_proj.weight', 'model.decoder.layers.1.action_attn.k_proj.bias', 'model.decoder.layers.1.action_attn.v_proj.weight', 'model.decoder.layers.1.action_attn.v_proj.bias', 'model.decoder.layers.1.action_attn.q_proj.weight', 'model.decoder.layers.1.action_attn.q_proj.bias', 'model.decoder.layers.1.action_attn.out_proj.weight', 'model.decoder.layers.1.action_attn.out_proj.bias', 'model.decoder.layers.1.action_attn_layer_norm.weight', 'model.decoder.layers.1.action_attn_layer_norm.bias', 'model.decoder.layers.1.composit_layer.weight', 'model.decoder.layers.1.composit_layer.bias', 'model.decoder.layers.1.composit_layer_norm.weight', 'model.decoder.layers.1.composit_layer_norm.bias', 'model.decoder.layers.2.resweight', 'model.decoder.layers.2.resweight_2', 'model.decoder.layers.2.discourse_attn.k_proj.weight', 'model.decoder.layers.2.discourse_attn.k_proj.bias', 'model.decoder.layers.2.discourse_attn.v_proj.weight', 'model.decoder.layers.2.discourse_attn.v_proj.bias', 'model.decoder.layers.2.discourse_attn.q_proj.weight', 'model.decoder.layers.2.discourse_attn.q_proj.bias', 'model.decoder.layers.2.discourse_attn.out_proj.weight', 'model.decoder.layers.2.discourse_attn.out_proj.bias', 'model.decoder.layers.2.discourse_attn_layer_norm.weight', 'model.decoder.layers.2.discourse_attn_layer_norm.bias', 'model.decoder.layers.2.action_attn.k_proj.weight', 
'model.decoder.layers.2.action_attn.k_proj.bias', 'model.decoder.layers.2.action_attn.v_proj.weight', 'model.decoder.layers.2.action_attn.v_proj.bias', 'model.decoder.layers.2.action_attn.q_proj.weight', 'model.decoder.layers.2.action_attn.q_proj.bias', 'model.decoder.layers.2.action_attn.out_proj.weight', 'model.decoder.layers.2.action_attn.out_proj.bias', 'model.decoder.layers.2.action_attn_layer_norm.weight', 'model.decoder.layers.2.action_attn_layer_norm.bias', 'model.decoder.layers.2.composit_layer.weight', 'model.decoder.layers.2.composit_layer.bias', 'model.decoder.layers.2.composit_layer_norm.weight', 'model.decoder.layers.2.composit_layer_norm.bias', 'model.decoder.layers.3.resweight', 'model.decoder.layers.3.resweight_2', 'model.decoder.layers.3.discourse_attn.k_proj.weight', 'model.decoder.layers.3.discourse_attn.k_proj.bias', 'model.decoder.layers.3.discourse_attn.v_proj.weight', 'model.decoder.layers.3.discourse_attn.v_proj.bias', 'model.decoder.layers.3.discourse_attn.q_proj.weight', 'model.decoder.layers.3.discourse_attn.q_proj.bias', 'model.decoder.layers.3.discourse_attn.out_proj.weight', 'model.decoder.layers.3.discourse_attn.out_proj.bias', 'model.decoder.layers.3.discourse_attn_layer_norm.weight', 'model.decoder.layers.3.discourse_attn_layer_norm.bias', 'model.decoder.layers.3.action_attn.k_proj.weight', 'model.decoder.layers.3.action_attn.k_proj.bias', 'model.decoder.layers.3.action_attn.v_proj.weight', 'model.decoder.layers.3.action_attn.v_proj.bias', 'model.decoder.layers.3.action_attn.q_proj.weight', 'model.decoder.layers.3.action_attn.q_proj.bias', 'model.decoder.layers.3.action_attn.out_proj.weight', 'model.decoder.layers.3.action_attn.out_proj.bias', 'model.decoder.layers.3.action_attn_layer_norm.weight', 'model.decoder.layers.3.action_attn_layer_norm.bias', 'model.decoder.layers.3.composit_layer.weight', 'model.decoder.layers.3.composit_layer.bias', 'model.decoder.layers.3.composit_layer_norm.weight', 'model.decoder.layers.3.composit_layer_norm.bias', 'model.decoder.layers.4.resweight', 'model.decoder.layers.4.resweight_2', 'model.decoder.layers.4.discourse_attn.k_proj.weight', 'model.decoder.layers.4.discourse_attn.k_proj.bias', 'model.decoder.layers.4.discourse_attn.v_proj.weight', 'model.decoder.layers.4.discourse_attn.v_proj.bias', 'model.decoder.layers.4.discourse_attn.q_proj.weight', 'model.decoder.layers.4.discourse_attn.q_proj.bias', 'model.decoder.layers.4.discourse_attn.out_proj.weight', 'model.decoder.layers.4.discourse_attn.out_proj.bias', 'model.decoder.layers.4.discourse_attn_layer_norm.weight', 'model.decoder.layers.4.discourse_attn_layer_norm.bias', 'model.decoder.layers.4.action_attn.k_proj.weight', 'model.decoder.layers.4.action_attn.k_proj.bias', 'model.decoder.layers.4.action_attn.v_proj.weight', 'model.decoder.layers.4.action_attn.v_proj.bias', 'model.decoder.layers.4.action_attn.q_proj.weight', 'model.decoder.layers.4.action_attn.q_proj.bias', 'model.decoder.layers.4.action_attn.out_proj.weight', 'model.decoder.layers.4.action_attn.out_proj.bias', 'model.decoder.layers.4.action_attn_layer_norm.weight', 'model.decoder.layers.4.action_attn_layer_norm.bias', 'model.decoder.layers.4.composit_layer.weight', 'model.decoder.layers.4.composit_layer.bias', 'model.decoder.layers.4.composit_layer_norm.weight', 'model.decoder.layers.4.composit_layer_norm.bias', 'model.decoder.layers.5.resweight', 'model.decoder.layers.5.resweight_2', 'model.decoder.layers.5.discourse_attn.k_proj.weight', 'model.decoder.layers.5.discourse_attn.k_proj.bias', 
'model.decoder.layers.5.discourse_attn.v_proj.weight', 'model.decoder.layers.5.discourse_attn.v_proj.bias', 'model.decoder.layers.5.discourse_attn.q_proj.weight', 'model.decoder.layers.5.discourse_attn.q_proj.bias', 'model.decoder.layers.5.discourse_attn.out_proj.weight', 'model.decoder.layers.5.discourse_attn.out_proj.bias', 'model.decoder.layers.5.discourse_attn_layer_norm.weight', 'model.decoder.layers.5.discourse_attn_layer_norm.bias', 'model.decoder.layers.5.action_attn.k_proj.weight', 'model.decoder.layers.5.action_attn.k_proj.bias', 'model.decoder.layers.5.action_attn.v_proj.weight', 'model.decoder.layers.5.action_attn.v_proj.bias', 'model.decoder.layers.5.action_attn.q_proj.weight', 'model.decoder.layers.5.action_attn.q_proj.bias', 'model.decoder.layers.5.action_attn.out_proj.weight', 'model.decoder.layers.5.action_attn.out_proj.bias', 'model.decoder.layers.5.action_attn_layer_norm.weight', 'model.decoder.layers.5.action_attn_layer_norm.bias', 'model.decoder.layers.5.composit_layer.weight', 'model.decoder.layers.5.composit_layer.bias', 'model.decoder.layers.5.composit_layer_norm.weight', 'model.decoder.layers.5.composit_layer_norm.bias', 'model.discourse_encoder.attention_0.W', 'model.discourse_encoder.attention_0.a', 'model.discourse_encoder.attention_0.one_hot_embedding.weight', 'model.discourse_encoder.attention_0.layer_norm.weight', 'model.discourse_encoder.attention_0.layer_norm.bias', 'model.discourse_encoder.attention_1.W', 'model.discourse_encoder.attention_1.a', 'model.discourse_encoder.attention_1.one_hot_embedding.weight', 'model.discourse_encoder.attention_1.layer_norm.weight', 'model.discourse_encoder.attention_1.layer_norm.bias', 'model.discourse_encoder.out_att.W', 'model.discourse_encoder.out_att.a', 'model.discourse_encoder.out_att.one_hot_embedding.weight', 'model.discourse_encoder.out_att.layer_norm.weight', 'model.discourse_encoder.out_att.layer_norm.bias', 'model.discourse_encoder.fc.weight', 'model.discourse_encoder.fc.bias', 'model.discourse_encoder.layer_norm.weight', 'model.discourse_encoder.layer_norm.bias', 'model.action_encoder.attention_0.W', 'model.action_encoder.attention_0.a', 'model.action_encoder.attention_0.layer_norm.weight', 'model.action_encoder.attention_0.layer_norm.bias', 'model.action_encoder.attention_1.W', 'model.action_encoder.attention_1.a', 'model.action_encoder.attention_1.layer_norm.weight', 'model.action_encoder.attention_1.layer_norm.bias', 'model.action_encoder.out_att.W', 'model.action_encoder.out_att.a', 'model.action_encoder.out_att.layer_norm.weight', 'model.action_encoder.out_att.layer_norm.bias', 'model.action_encoder.fc.weight', 'model.action_encoder.fc.bias', 'model.action_encoder.layer_norm.weight', 'model.action_encoder.layer_norm.bias'] You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

    Is this expected?

    opened by muggin 1
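
Regarding the batching question above ("How to batch training samples?"): a generic way to batch a variable number of utterances per conversation is to pad along the utterance dimension. The sketch below is not the repo's actual collate function; it assumes utterances are already padded to a shared seq_len, and the collate_conversations helper and pad_id value are hypothetical.

    # Generic padding sketch for inputs shaped List[Tensor[num_utts_i, seq_len]];
    # not the repo's collate function. Assumes a shared seq_len across conversations.
    import torch
    from torch.nn.utils.rnn import pad_sequence

    def collate_conversations(conversations, pad_id=1):
        """Returns [batch, max_num_utts, seq_len] plus a mask over real utterances."""
        batch = pad_sequence(conversations, batch_first=True, padding_value=pad_id)
        utt_mask = pad_sequence(
            [torch.ones(c.size(0), dtype=torch.bool) for c in conversations],
            batch_first=True, padding_value=0)
        return batch, utt_mask

    convs = [torch.randint(0, 100, (3, 8)), torch.randint(0, 100, (5, 8))]
    padded, mask = collate_conversations(convs)
    print(padded.shape, mask.shape)  # torch.Size([2, 5, 8]) torch.Size([2, 5])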
Owner: GT-SALT (Social and Language Technologies Lab)