This repository contains the code for our paper VDA (published in the EMNLP 2021 main conference).

Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models

Overview

We propose Virtual Data Augmentation (VDA), a general framework for robustly fine-tuning pre-trained language models on downstream tasks. VDA uses a masked language model with Gaussian noise to generate virtual examples that improve robustness, and adopts regularized training to further ensure their semantic relevance and diversity.
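To make the idea of a virtual example concrete, here is a minimal sketch in plain Python. It is not the code from this repository: it assumes that Gaussian noise is added to the masked language model's scores for one token position, that the noisy scores are normalized with a softmax, and that the virtual embedding is the resulting probability-weighted average of the vocabulary embeddings. The function names and the toy vocabulary are hypothetical.

```python
import math
import random


def softmax(scores):
    """Numerically stable softmax over a list of floats."""
    peak = max(scores)
    exps = [math.exp(s - peak) for s in scores]
    total = sum(exps)
    return [e / total for e in exps]


def virtual_embedding(mlm_logits, embedding_table, variance, rng):
    """Mix vocabulary embeddings using a noise-perturbed MLM distribution.

    mlm_logits:      one score per vocabulary entry for a single token position
    embedding_table: one embedding vector (list of floats) per vocabulary entry
    variance:        variance of the zero-mean Gaussian noise
    """
    std = math.sqrt(variance)
    # Assumption: the noise is added to the scores before normalization.
    noisy = [score + rng.gauss(0.0, std) for score in mlm_logits]
    probs = softmax(noisy)
    dim = len(embedding_table[0])
    # The virtual example is a probability-weighted average of real embeddings.
    return [
        sum(p * row[d] for p, row in zip(probs, embedding_table))
        for d in range(dim)
    ]


rng = random.Random(0)
toy_logits = [2.0, 0.5, -1.0]                      # toy 3-word vocabulary
toy_table = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]   # toy 2-d embeddings
virtual = virtual_embedding(toy_logits, toy_table, variance=0.01, rng=rng)
```

Because the result is a convex combination of real embeddings, it stays inside the region spanned by the vocabulary embeddings, and a larger variance moves it further from the embedding of the most likely token.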

Train VDA

In the following section, we describe how to train a model with VDA using our code.

Training

Data

To evaluate VDA, we use 6 text classification datasets: Yelp, IMDB, AGNews, MR, QNLI and MRPC. These datasets can be downloaded from the GoogleDisk.

After downloading the two zipped files, users should unzip the data folder, which contains the training, validation and test data of the 6 datasets. The Robust folder contains the examples for testing robustness.

Training scripts

We release VDA with 4 base models. For single-sentence classification tasks, we use the text_classifier_xxx.py files; for sentence-pair classification tasks, we use the text_pair_classifier_xxx.py files:

  • text_classifier.py and text_pair_classifier.py: BERT-base+VDA

  • text_classifier_freelb.py and text_pair_classifier_freelb.py: FreeLB+VDA on BERT-base

  • text_classifier_smart.py and text_pair_classifier_smart.py: SMART+VDA on BERT-base, where we only use the smooth-inducing adversarial regularization.

  • text_classifier_smix.py and text_pair_classifier_smix.py: Smix+VDA on BERT-base, where we remove the adversarial data augmentation for fair comparison.

We provide example scripts for both training and testing VDA on the 6 datasets. In run_train.sh, we provide 6 examples for training on the Yelp and QNLI datasets. This script calls text_classifier_xxx.py for training (xxx refers to the base model). The arguments are explained below:

  • --dataset: Training file path.
  • --mlm_path: Pre-trained checkpoints to start with. For now we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.)
  • --save_path: Path where the fine-tuned checkpoint file is saved.
  • --max_length: Max sequence length. (For Yelp/IMDB/AG, we use 512; for MR/QNLI/MRPC, we use 256.)
  • --max_epoch: The maximum number of training epochs. (For most datasets and models, we use 10.)
  • --batch_size: The batch size. (We set the batch size to the largest value that fits in GPU memory. Note that too small a batch size may cause model collapse.)
  • --num_label: The number of labels. (For AG, we use 4; for the other datasets, we use 2.)
  • --lr: Learning rate.
  • --num_warmup: The rate of warm-up steps.
  • --variance: The variance of the Gaussian noise.
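The dataset-specific settings listed above can be collected in a small helper that assembles the command-line arguments. This is only a sketch: the short dataset keys, the helper name and the example paths are hypothetical, and the values passed for the batch size, learning rate, warm-up rate and variance are placeholders rather than the settings used in the paper. Only the flag names, the sequence lengths, the label counts and the epoch count come from the list above.

```python
# Short dataset keys are hypothetical names chosen for this sketch.
LONG_TEXT_DATASETS = {"yelp", "imdb", "ag"}    # max_length 512 per the list above
SHORT_TEXT_DATASETS = {"mr", "qnli", "mrpc"}   # max_length 256 per the list above


def build_train_args(name, dataset_path, save_path, batch_size, lr,
                     num_warmup, variance, mlm_path="bert-base-uncased"):
    """Return the documented training flags as a list of strings."""
    key = name.lower()
    if key not in LONG_TEXT_DATASETS | SHORT_TEXT_DATASETS:
        raise ValueError(f"unknown dataset: {name}")
    options = {
        "--dataset": dataset_path,
        "--mlm_path": mlm_path,
        "--save_path": save_path,
        "--max_length": 512 if key in LONG_TEXT_DATASETS else 256,
        "--max_epoch": 10,                      # used for most datasets and models
        "--batch_size": batch_size,
        "--num_label": 4 if key == "ag" else 2,
        "--lr": lr,
        "--num_warmup": num_warmup,
        "--variance": variance,
    }
    argv = []
    for flag, value in options.items():
        argv += [flag, str(value)]
    return argv


# Placeholder paths and hyperparameters, for illustration only.
argv = build_train_args("qnli", "data/qnli", "checkpoints/qnli_vda",
                        batch_size=32, lr=2e-5, num_warmup=0.1, variance=1.0)
print(" ".join(["python", "text_pair_classifier.py"] + argv))
```

QNLI is a sentence-pair task, so the example prints a command for text_pair_classifier.py; a single-sentence dataset such as Yelp would use text_classifier.py instead.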

For the results in the paper, we use Nvidia Tesla V100 32G and Nvidia 3090 24G GPUs to train our models. Using different types of devices or different versions of CUDA or other software may lead to slightly different performance.

Evaluation

During training, our code reports the original accuracy on the test set of each of the 6 datasets, which measures the accuracy of our model. Our evaluation code for robustness is based on a modified version of BERT-Attack. It outputs the Attack Accuracy, Query Numbers and Perturbation Ratio metrics.

Before evaluation, please download the robustness evaluation datasets from the GoogleDisk. Then, following the commonly used settings, users need to download and process the cosine similarity matrix following TextFooler.
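For readers unfamiliar with the term, a cosine similarity matrix stores the cosine of the angle between every pair of word vectors. The sketch below shows the computation on toy vectors in plain Python. It only illustrates the concept: the function name is hypothetical, and the real matrix, its file format and the word vectors it is built from should be obtained by following the TextFooler instructions.

```python
import math


def cosine_similarity_matrix(vectors):
    """Return the pairwise cosine similarities of a list of vectors."""
    normalized = []
    for vec in vectors:
        norm = math.sqrt(sum(x * x for x in vec))
        if norm == 0.0:
            raise ValueError("cannot normalize a zero vector")
        normalized.append([x / norm for x in vec])
    # After normalization, the dot product of two rows is their cosine similarity.
    return [
        [sum(a * b for a, b in zip(u, v)) for v in normalized]
        for u in normalized
    ]


toy_vectors = [[1.0, 0.0], [0.0, 2.0], [3.0, 3.0]]
similarity = cosine_similarity_matrix(toy_vectors)
```

The matrix is symmetric with ones on the diagonal, so entry [i][j] can be looked up directly when checking whether word j is an acceptable substitute for word i.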

Based on the checkpoints of the fine-tuned models, we use the run_test.sh script to test robustness on the Yelp and QNLI datasets. It is based on the bert_robust.py file. The arguments are explained below:

  • --data_path: Training file path.
  • --mlm_path: Pre-trained checkpoints to start with. For now we support BERT-based models (bert-base-uncased, bert-large-uncased, etc.)
  • --tgt_path: The fine-tuned checkpoints file.
  • --num_label: The number of labels. (For AG, we use 4; for the other datasets, we use 2.)

The script is expected to output results in the following form:

original accuracy is 0.960000, attack accuracy is 0.533333, query num is 687.680556, perturb rate is 0.177204
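When collecting results over several runs, it can help to turn this line into numbers. The following sketch parses the format shown above with a regular expression. The function name and the dictionary keys are hypothetical, and the pattern assumes the output keeps exactly this wording.

```python
import re

# Pattern written against the sample output line shown above.
RESULT_PATTERN = re.compile(
    r"original accuracy is (?P<original_accuracy>[\d.]+), "
    r"attack accuracy is (?P<attack_accuracy>[\d.]+), "
    r"query num is (?P<query_num>[\d.]+), "
    r"perturb rate is (?P<perturb_rate>[\d.]+)"
)


def parse_result_line(line):
    """Extract the four reported metrics from one result line as floats."""
    match = RESULT_PATTERN.search(line)
    if match is None:
        raise ValueError("line does not match the expected result format")
    return {name: float(value) for name, value in match.groupdict().items()}


sample = ("original accuracy is 0.960000, attack accuracy is 0.533333, "
          "query num is 687.680556, perturb rate is 0.177204")
metrics = parse_result_line(sample)
```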

Citation

Please cite our paper if you use VDA in your work:

@inproceedings{zhou2021vda,
  author    = {Kun Zhou and Wayne Xin Zhao and Sirui Wang and Fuzheng Zhang and Wei Wu and Ji-Rong Wen},
  title     = {Virtual Data Augmentation: A Robust and General Framework for Fine-tuning Pre-trained Models},
  booktitle = {{EMNLP} 2021},
  publisher = {The Association for Computational Linguistics},
}
Owner
RUCAIBox
An enthusiastic group that aims to create beautiful things with AI