MixText: Linguistically-Informed Interpolation of Hidden Space for Semi-Supervised Text Classification

Overview

MixText

This repo contains the code for the following paper:

Jiaao Chen, Zichao Yang, Diyi Yang: MixText: Linguistically-Informed Interpolation of Hidden Space for Semi-Supervised Text Classification. In Proceedings of the 58th Annual Meeting of the Association for Computational Linguistics (ACL 2020)

If you use or refer to this work, please cite the paper above.

Getting Started

These instructions will get you up and running with the MixText code.

Requirements

  • Python 3.6 or higher
  • PyTorch >= 1.3.0
  • pytorch_transformers (since renamed to transformers)
  • pandas, NumPy, pickle
  • Fairseq
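
A minimal environment setup along these lines should work; the exact package versions beyond PyTorch are assumptions, not versions pinned by the authors, and pickle ships with Python:

pip install "torch>=1.3.0" pytorch_transformers pandas numpy fairseq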

Code Structure

|__ data/
        |__ yahoo_answers_csv/ --> Datasets for Yahoo Answers
            |__ back_translate.ipynb --> Jupyter Notebook for back translating the dataset
            |__ classes.txt --> Classes for Yahoo Answers dataset
            |__ train.csv --> Original training dataset
            |__ test.csv --> Original testing dataset
            |__ de_1.pkl --> Back translated training dataset with German as middle language
            |__ ru_1.pkl --> Back translated training dataset with Russian as middle language

|__ code/
        |__ transformers/ --> Code copied from huggingface/transformers
        |__ read_data.py --> Code for reading the datasets; forming the labeled training set, unlabeled training set, development set, and test set; and building dataloaders
        |__ normal_bert.py --> Code for the BERT baseline model
        |__ normal_train.py --> Code for training the BERT baseline model
        |__ mixtext.py --> Code for our proposed TMix/MixText model
        |__ train.py --> Code for training/testing TMix/MixText

Downloading the data

Please download the datasets and put them in the data folder. You can find Yahoo Answers, AG News, and DBpedia here, and IMDB here.

Pre-processing the data

For Yahoo Answers, we concatenate the question title, question content, and best answer to form the text to be classified. The pre-processed Yahoo Answers dataset can be downloaded here.
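
For illustration, the concatenation could be reproduced with pandas roughly as follows. This is a sketch, not a script from this repo; it assumes the standard Yahoo Answers CSV layout (no header row; label, title, content, answer columns):

import pandas as pd

# The raw Yahoo Answers CSV has no header row; the column order is
# label, question title, question content, best answer.
df = pd.read_csv('./data/yahoo_answers_csv/train.csv', header=None,
                 names=['label', 'title', 'content', 'answer'])

# Join the three text fields into the string to be classified.
df['text'] = (df['title'].fillna('') + ' ' +
              df['content'].fillna('') + ' ' +
              df['answer'].fillna(''))
texts = df['text'].tolist()
labels = (df['label'] - 1).tolist()  # labels in the CSV are 1-indexed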

Note that for AG News and DBpedia, we only use the content (without titles) for classification, and for IMDB we do not perform any pre-processing.

We utilize Fairseq to perform back translation on the training dataset. Please refer to ./data/yahoo_answers_csv/back_translate.ipynb for details.

Here, we have put two examples of back-translated data, de_1.pkl and ru_1.pkl, in ./data/yahoo_answers_csv/ as well. You can use them directly for Yahoo Answers, or generate your own back-translated data following ./data/yahoo_answers_csv/back_translate.ipynb.
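
If you prefer to script the back translation instead of using the notebook, a minimal sketch using Fairseq's pretrained WMT'19 models via torch.hub could look like the following. Treat the details as assumptions: the dict-keyed-by-index pickle layout only mirrors how read_data.py appears to index de_1.pkl, train_texts is a placeholder for your own list of training sentences, and the hub models additionally require the fastBPE and sacremoses packages:

import pickle
import torch

# Pretrained English<->German translators from fairseq's torch.hub entries.
en2de = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.en-de.single_model',
                       tokenizer='moses', bpe='fastbpe')
de2en = torch.hub.load('pytorch/fairseq', 'transformer.wmt19.de-en.single_model',
                       tokenizer='moses', bpe='fastbpe')
en2de.eval()
de2en.eval()

def back_translate(text, temperature=0.9):
    # English -> German -> English; sampling yields diverse paraphrases.
    german = en2de.translate(text, sampling=True, temperature=temperature)
    return de2en.translate(german, sampling=True, temperature=temperature)

# train_texts: your list of original training sentences (placeholder).
paraphrases = {i: back_translate(t) for i, t in enumerate(train_texts)}
with open('./data/yahoo_answers_csv/de_1.pkl', 'wb') as f:
    pickle.dump(paraphrases, f)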

Training models

This section contains instructions for training models on Yahoo Answers with 10 labeled examples per class.

Training BERT baseline model

Please run ./code/normal_train.py to train the BERT baseline model (using only the labeled training data):

python ./code/normal_train.py --gpu 0,1 --n-labeled 10 --data-path ./data/yahoo_answers_csv/ \
--batch-size 8 --epochs 20 

Training TMix model

Please run ./code/train.py to train the TMix model (using only the labeled training data):

python ./code/train.py --gpu 0,1 --n-labeled 10 --data-path ./data/yahoo_answers_csv/ \
--batch-size 8 --batch-size-u 1 --epochs 50 --val-iteration 20 \
--lambda-u 0 --T 0.5 --alpha 16 --mix-layers-set 7 9 12 --separate-mix True 

Training MixText model

Please run ./code/train.py to train the MixText model (using both labeled and unlabeled training data):

python ./code/train.py --gpu 0,1,2,3 --n-labeled 10 \
--data-path ./data/yahoo_answers_csv/ --batch-size 4 --batch-size-u 8 --epochs 20 --val-iteration 1000 \
--lambda-u 1 --T 0.5 --alpha 16 --mix-layers-set 7 9 12 \
--lrmain 0.000005 --lrlast 0.0005

Comments
  • Can you provide all back translation data?

    Thanks for your interesting work! I am trying to run some experiments using your code. I find it too expensive to augment the unlabeled data through back translation myself because of limited resources, and I can only find the back-translation data for the Yahoo dataset in this codebase. Can you provide the back-translation data for all datasets? Thanks!

    opened by ghost 7
  • Reproducing UDA Results

    Hi @diyiy @jiaaoc ,

    just a quick question: Can I also use the code found in this repo to reproduce your UDA results reported in the paper? If so, how?

    Thanks :)

    opened by timoschick 6
  • Unlabelled Data Formatting

    Hey,

    I was trying to test your model on my custom data, which contains both labelled and unlabelled sentences, though I am not sure how I need to format, structure, and store my unlabelled data.

    I was able to run it successfully on the preprocessed Yahoo dataset but there were only labelled examples present in it as far as I could observe.

    opened by RachitBansal 6
  • I want to know what 102 means

    Sorry for disturbing you again. I tried to change max_seq_len to 512 and to change BERT to Electra, but I'm not sure whether the 102 in this line of code needs to be changed:

    https://github.com/GT-SALT/MixText/blob/f17198d98e1bbb012d8e33cd26e228dd43bb4673/code/train.py#L313

    opened by callmeYe 5
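
    A note on this thread: in the standard bert-base-uncased vocabulary, id 102 is the [SEP] token (and 101 is [CLS]). Whether it needs changing depends on the tokenizer you switch to, and it is easy to verify rather than hard-code:

        from pytorch_transformers import BertTokenizer

        tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')
        # Prints [101, 102]: 101 is [CLS], 102 is [SEP].
        print(tokenizer.convert_tokens_to_ids(['[CLS]', '[SEP]']))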
  • Question about the data augmentation.

    Hello, I am following your work! I have a question about Figure 2 in your paper. Why does the figure show that you only augment the unlabeled data? Looking forward to your answer!

    opened by JHR0717 5
  • Pretty low accuracy for yahoo answers mixtext

    Hi, firstly, I quite liked the paper and enjoyed reading it. I tried using this code on the Yahoo Answers dataset, but unfortunately the best accuracy does not cross 0.24. Since I have 2 GPUs at my disposal, I set batch-size to 2 and batch-size-u to 4 (I also tried batch-size 3 with batch-size-u 4) and val-iteration to 2000. I was wondering whether I should change other parameters to make it work? (I understand that reducing the batch size should have some impact, but I'm not sure the impact can be this significant, so it might be due to some other factor or weighting.)

    The command I used (i downloaded the dataset from the link provided and placed it in the directory)

    python ./code/train.py --gpu 0,1 --n-labeled 10 --data-path ./data/yahoo_answers_csv/ --batch-size 2 --batch-size-u 4 --epochs 20 --val-iteration 2000 --lambda-u 1 --T 0.5 --alpha 16 --mix-layers-set 7 9 12 --lrmain 0.000005 --lrlast 0.0005

    My second question is about the use of args.val_iteration. As I understand it, it is in effect the number of batches processed per epoch. How does it work if I have 10 labeled examples per class, i.e. 100 labeled examples for the 10 Yahoo Answers classes? With a batch size of 2, that would be 50 batches for the data loader. So does it repeat instances in a cyclic manner, or just always pick 2 instances at random? (See the sketch after this thread.)

    Thanks

    opened by sb1992 5
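
    A note on the second question: judging from the try/except around unlabeled_train_iter.next() visible in train.py (see the traceback in a later thread), the code appears to follow the common pattern of re-creating an exhausted iterator, so a small labeled set is cycled through repeatedly until val_iteration batches have been drawn in each epoch. A minimal sketch of that pattern, not the repo's exact code:

        labeled_train_iter = iter(labeled_trainloader)
        for batch_idx in range(args.val_iteration):
            try:
                inputs_x, targets_x = next(labeled_train_iter)
            except StopIteration:
                # Loader exhausted: restart it, cycling through the labeled set again.
                labeled_train_iter = iter(labeled_trainloader)
                inputs_x, targets_x = next(labeled_train_iter)
            # ... forward/backward pass on (inputs_x, targets_x)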
  • Sequences too long for backtranslation

    I was trying to use the back-translation code with the IMDB polarity dataset, but the sequences are very long and fairseq gives an error. How do you handle such long sequences? Thanks for the code.

    opened by monkeysforever 4
  • Changing the sequence length for BERT & performance on data with multi-class labels.

    I've been working on this and got results close to those reported in the paper for the IMDB dataset. However, I'm facing a couple of issues.

    1. Changing the sequence length that BERT accepts hasn't been straightforward. I want to change it to 128. I couldn't find it in 'mixtext.py'. 'normal_bert.py' contains 'length=256' in the definition of 'forward', but I'm unsure if this is the one to change, as I don't see 'length' being used anywhere else. Please let me know how to change the sequence length for both normal_bert and mixtext.

    2. I tried to run this on sklearn's NewsGroup dataset, which contains 20 labels. I've changed 'read_data.py' appropriately to take in this data and also performed back translations as described, but I couldn't get good results with 'MixText': whatever amount of labeled data I use, I get almost identical, very low accuracies, and training time also stays the same regardless of the amount of labeled data. I'm assuming there is no issue with my data reading and pre-processing, because 'normal_bert' runs fine. Please let me know if there is any known issue with using MixText on a multi-class dataset.

    Thanks!

    opened by abhinivesh-s 3
  • AttributeError: 'BertConfig' object has no attribute 'chunk_size_feed_forward'

    Hello, what version of transformers are you using in the code? I tried to instantiate the MixText model but got a 'no attribute chunk_size_feed_forward' error when the layers were being created.

    model = MixText(n_labels, args['mix_option']).cuda()
    

    I am using 'bert-base-uncased' as the pretrained model.

    opened by dnnxl 2
  • Loss function for supervised loss

    I have a question about what exactly the supervised loss mentioned in the paper is, the Lx here in the code. It is not exactly cross-entropy, so I was a bit confused; a bit more explanation of this loss function would be really helpful. (See the note after this thread.)

    Lx = - torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))

    opened by sb1992 2
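
    A note on this thread: the expression is cross-entropy generalized to soft targets. targets_x holds a probability distribution per example (e.g. a mixed label after TMix), and when the targets are one-hot it reduces exactly to standard cross-entropy, as this small check illustrates:

        import torch
        import torch.nn.functional as F

        outputs_x = torch.randn(4, 10)            # logits: 4 examples, 10 classes
        hard = torch.randint(0, 10, (4,))         # integer class labels
        targets_x = F.one_hot(hard, 10).float()   # same labels as one-hot distributions

        Lx = -torch.mean(torch.sum(F.log_softmax(outputs_x, dim=1) * targets_x, dim=1))
        print(torch.allclose(Lx, F.cross_entropy(outputs_x, hard)))  # True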
  • Missing Special Inputs for the BERT such as <CLS> or <SEP>

    Hi,

    I'm currently working with your repository, and it seems that some special tokens for BERT are missing (e.g. [CLS] or [SEP]).

    From my understanding, those special tokens are not added to the input text before it is fed to the BERT model.

    Is it right? Or am I missing it?

    Thank you

    opened by JJumSSu 2
  • The pre-processed Yahoo Answer dataset

    Hello, I am trying to run this project, but the link to the pre-processed Yahoo Answers dataset is invalid. Could you please fix the URL for the pre-processed Yahoo Answers dataset?

    Thanks, Rosit

    opened by rrrosita 1
  • yml file of environment

    Hello!

    I am trying to reproduce your work, but I have issues with the dependencies. Could you provide the yml file of the conda environment that you used, or at least specify the version of each package (transformers, torch, ...)?

    Thanks, Taha

    opened by TahaAslani 0
  • KeyError: unlabeled_train_iter.next()

    %run /code/train.py --gpu=0 --n-labeled=10 --data-path /yahoo_answers_csv/ --batch-size=4 --batch-size-u=8 --epochs=50 --val-iteration=20 --lambda-u=0 --T=0.5 --alpha=16 --mix-layers-set 7 9 12 --separate-mix=True

    train(labeled_trainloader, unlabeled_trainloader, model, optimizer, scheduler, criterion, epoch, n_labels, train_aug)
        204             (inputs_u, inputs_u2, inputs_ori), (length_u,
    --> 205              length_u2, length_ori) = unlabeled_train_iter.next()
        206         except:

    /torch/utils/data/dataloader.py in __next__(self)
        520             self._reset()
    --> 521         data = self._next_data()
        522         self._num_yielded += 1

    /torch/utils/data/dataloader.py in _next_data(self)
        560         index = self._next_index()  # may raise StopIteration
    --> 561         data = self._dataset_fetcher.fetch(index)  # may raise StopIteration
        562         if self._pin_memory:

    /torch/utils/data/_utils/fetch.py in fetch(self, possibly_batched_index)
         43         if self.auto_collation:
    --> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
         45         else:

    /torch/utils/data/_utils/fetch.py in <listcomp>(.0)
         43         if self.auto_collation:
    --> 44             data = [self.dataset[idx] for idx in possibly_batched_index]
         45         else:

    /code/read_data.py in __getitem__(self, idx)
        209         if self.aug is not None:
    --> 210             u, v, ori = self.aug(self.text[idx], self.ids[idx])
        211         encode_result_u, length_u = self.get_tokenized(u)

    /code/read_data.py in __call__(self, ori, idx)
         22     def __call__(self, ori, idx):
    --> 23         out1 = self.de[idx]
         24         out2 = self.ru[idx]

    KeyError: 9226

    May I ask what the reason for this is? Thank you very much. (A likely explanation follows this thread.)

    opened by SSSXHJDB 3
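
    A note on this thread: the traceback shows self.de[idx] failing for index 9226, which suggests the back-translation pickles (de_1.pkl, ru_1.pkl) cover only a subset of the training indices, so the unlabeled sampler drew an index with no back-translated counterpart. Assuming the pickle is a dict keyed by training-set index (consistent with self.de[idx] above), a quick coverage check:

        import pickle

        with open('./data/yahoo_answers_csv/de_1.pkl', 'rb') as f:
            de = pickle.load(f)

        # How many training indices have a back translation, and over what range.
        print(len(de), min(de.keys()), max(de.keys()))
        print(9226 in de)  # False would explain the KeyError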
  • Parameter problem

    I recently studied your project and carefully read your ACL paper. However, when I tried to run your code, I could not find the setting for the epoch parameter, and the settings of other parameters in the code were vague and not fully specified in the paper, especially the epoch parameter. Therefore, I cannot reproduce the results in your paper. Can you provide a table of the parameter settings for each dataset?

    Looking forward to your reply.

    opened by CodingPerson 0