CNNs for Sentence Classification in PyTorch

Introduction

This is a PyTorch implementation of Kim's Convolutional Neural Networks for Sentence Classification paper. Other implementations of the same model:

  1. Kim's implementation of the model in Theano: https://github.com/yoonkim/CNN_sentence
  2. Denny Britz has an implementation in TensorFlow: https://github.com/dennybritz/cnn-text-classification-tf
  3. Alexander Rakhlin's implementation in Keras: https://github.com/alexander-rakhlin/CNN-for-Sentence-Classification-in-Keras
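
For reference, here is a minimal sketch of the model described in Kim's paper: an embedding layer, parallel 1-D convolutions over several kernel sizes, max-over-time pooling, dropout, and a linear classifier. The class and argument names below are illustrative and may differ from model.py in this repo.

    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class TextCNN(nn.Module):
        """Illustrative Kim-style CNN for sentence classification."""
        def __init__(self, vocab_size, embed_dim=128, class_num=2,
                     kernel_num=100, kernel_sizes=(3, 4, 5), dropout=0.5):
            super().__init__()
            self.embed = nn.Embedding(vocab_size, embed_dim)
            self.convs = nn.ModuleList(
                [nn.Conv1d(embed_dim, kernel_num, k) for k in kernel_sizes])
            self.dropout = nn.Dropout(dropout)
            self.fc = nn.Linear(kernel_num * len(kernel_sizes), class_num)

        def forward(self, x):                       # x: (batch, seq_len) word ids
            x = self.embed(x).transpose(1, 2)       # (batch, embed_dim, seq_len)
            x = [F.relu(conv(x)) for conv in self.convs]             # convolve over time
            x = [F.max_pool1d(c, c.size(2)).squeeze(2) for c in x]   # max-over-time pooling
            x = self.dropout(torch.cat(x, dim=1))
            return self.fc(x)                       # (batch, class_num) logits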

Requirements

  • python 3
  • pytorch > 0.1
  • torchtext > 0.1
  • numpy

Results

I have only tried two datasets so far: MR and SST.

Dataset   Classes   Best Result                Kim's Paper Result
MR        2         77.5% (CNN-rand-static)    76.1% (CNN-rand-nostatic)
SST       5         37.2% (CNN-rand-static)    45.0% (CNN-rand-nostatic)

I haven't seriously tuned the hyper-parameters for SST.

Usage

./main.py -h

or

python3 main.py -h

You will get:

CNN text classifier

optional arguments:
  -h, --help            show this help message and exit
  -batch-size N         batch size for training [default: 50]
  -lr LR                initial learning rate [default: 0.01]
  -epochs N             number of epochs for train [default: 10]
  -dropout              the probability for dropout [default: 0.5]
  -max_norm MAX_NORM    l2 constraint of parameters
  -cpu                  disable the gpu
  -device DEVICE        device to use for iterate data
  -embed-dim EMBED_DIM
  -static               fix the embedding
  -kernel-sizes KERNEL_SIZES
                        Comma-separated kernel size to use for convolution
  -kernel-num KERNEL_NUM
                        number of each kind of kernel
  -class-num CLASS_NUM  number of class
  -shuffle              shuffle the data every epoch
  -num-workers NUM_WORKERS
                        how many subprocesses to use for data loading
                        [default: 0]
  -log-interval LOG_INTERVAL
                        how many batches to wait before logging training
                        status
  -test-interval TEST_INTERVAL
                        how many epochs to wait before testing
  -save-interval SAVE_INTERVAL
                        how many epochs to wait before saving
  -predict PREDICT      predict the sentence given
  -snapshot SNAPSHOT    filename of model snapshot [default: None]
  -save-dir SAVE_DIR    where to save the checkpoint

Train

./main.py

You will get:

Batch[100] - loss: 0.655424  acc: 59.3750%
Evaluation - loss: 0.672396  acc: 57.6923%(615/1066) 

Test

If you have constructed your own test set, you can test the model like this:

./main.py -test -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt"

The -snapshot option specifies which saved model to load. If you don't set it, the model will be trained from scratch.
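
As a rough sketch, loading a snapshot for testing looks like this on older PyTorch versions, assuming the checkpoint was written with torch.save on the whole model (older versions of this repo appear to do this; if your checkpoint stores a state_dict instead, use load_state_dict). The path below is just a placeholder:

    import torch

    # placeholder path; point this at your own checkpoint
    snapshot = "./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt"
    cnn = torch.load(snapshot)   # restores the pickled model object (old-style checkpoint)
    cnn.eval()                   # switch to evaluation mode before testing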

Predict

  • Example1

     ./main.py -predict="Hello my dear , I love you so much ." \
               -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt" 
    

    You will get:

     Loading model from [./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt]...
     
     [Text]  Hello my dear , I love you so much .
     [Label] positive
    
  • Example2

     ./main.py -predict="You just make me so sad and I have to leave you ."\
               -snapshot="./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt" 
    

    You will get:

     Loading model from [./snapshot/2017-02-11_15-50-53/snapshot_steps1500.pt]...
     
     [Text]  You just make me so sad and I have to leave you .
     [Label] negative
    

Your text must be whitespace-separated, including punctuation, and it should be longer than the maximum kernel size.

Reference

  • Yoon Kim. Convolutional Neural Networks for Sentence Classification. EMNLP 2014. https://arxiv.org/abs/1408.5882

Comments
  • Issue with the prediction function

    I've been playing with the CNN text classification lately, and it seems it trains fine, but when it comes to predictions I get this error:

    Traceback (most recent call last):
      File "main.py", line 89, in <module>
        label = train.predict(args.predict, cnn, text_field, label_field)
      File "train.py", line 88, in predict
        return label_feild.vocab.itos[predicted.data[0][0]+1]
    TypeError: 'int' object has no attribute '__getitem__'

    I changed return label_feild.vocab.itos[predicted.data[0][0]+1] to return label_feild.vocab.itos[predicted.data[0]+1], which bypassed the error, but most predictions are not accurate now. Can you please let me know what I'm missing here?

    opened by sabersf 3
  • Embedding Static

    I think the way this code implements the static embedding is wrong. Am I right? This code uses x = Variable(x) when it wants to make the embedding static, while it should be something like self.embed.weight.requires_grad = False (see the sketch after this item).

    opened by ASoleimaniB 2
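
    A minimal sketch of the freezing approach the commenter suggests (toy sizes; the repo's actual fix, if any, may differ):

     import torch
     import torch.nn as nn

     embed = nn.Embedding(1000, 128)        # toy vocabulary and embedding sizes
     fc = nn.Linear(128, 2)
     embed.weight.requires_grad = False     # "static" mode: no gradients flow into the embedding

     # hand the optimizer only the parameters that are still trainable
     params = [p for p in list(embed.parameters()) + list(fc.parameters()) if p.requires_grad]
     optimizer = torch.optim.Adam(params, lr=0.01)
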
  • prediction error

    Training seems to work fine when I run python main.py,

    but when I run python main.py -predict="I feel bad" -snapshot="snapshot/2018-01-08_13-10-29/snapshot_steps9000.pt", it gives me this error:

    RuntimeError: Given input size: (1, 3, 128). Calculated output size: (1, 0, 1). Output size is too small.

    I haven't made any changes to the code.

    opened by nitish116 2
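
    That error usually means the tokenized sentence is shorter than the largest convolution kernel, so max-pooling gets an empty output. One workaround is to pad short inputs up to the maximum kernel size before prediction; a minimal sketch (the pad token and kernel size here are assumptions):

     def pad_to_min_length(tokens, min_len=5, pad_token="<pad>"):
         """Pad a token list so it is at least as long as the largest kernel size."""
         if len(tokens) < min_len:
             tokens = tokens + [pad_token] * (min_len - len(tokens))
         return tokens

     tokens = "I feel bad".split()        # 3 tokens, fewer than a kernel size of 5
     print(pad_to_min_length(tokens))     # ['I', 'feel', 'bad', '<pad>', '<pad>']
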
  • Output raw input sentences

    Is there an easy way to get the original raw sentences instead of data objects during the test dataset evaluation?

    def eval(data_iter, model, args):
        model.eval()
        corrects, avg_loss = 0, 0
        for batch in data_iter:
            feature, target = batch.text, batch.label
            print(feature.original_sentence)
    
    opened by sunyangfu 1
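
    Batches only carry numericalized tensors, so there is no original_sentence attribute. One way to recover readable text during evaluation is to map the word ids back through the field's vocabulary; a rough sketch, assuming feature is batch-first (as it is after the transpose in train.py) and that text_field is in scope:

     def ids_to_text(feature, text_field, pad_token="<pad>"):
         """Map a (batch, seq_len) tensor of word ids back to whitespace-joined sentences."""
         itos = text_field.vocab.itos
         sentences = []
         for row in feature.tolist():                       # one row per example
             words = [itos[i] for i in row if itos[i] != pad_token]
             sentences.append(" ".join(words))
         return sentences
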
  • Preprocessing issue in mydatasets.py

    I was reading the documentation for the Torchtext Field object and I noticed that preprocessing happens after tokenization. This seems to conflict with the intention of the clean_str function, as adding it to the text field's preprocessing will split contractions, etc. on individual tokens (causing tokens with spaces in them) rather than an entire sentence. To fix this, the following statement on line 74:

    text_field.preprocessing = data.Pipeline(clean_str)

    can be replaced with something like this:

    text_field.tokenize = lambda x: clean_str(x).split()

    which will apply clean_str before tokenization (str.split() is the default tokenizer used by the Field object).

    opened by rriva002 1
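
    In code, the suggested setup looks roughly like this; clean_str below is a simplified stand-in for the function defined in mydatasets.py, and the Field construction mirrors (but may not exactly match) main.py:

     import re
     from torchtext import data   # legacy torchtext API, as used by this repo

     def clean_str(string):
         """Simplified stand-in for the repo's clean_str."""
         string = re.sub(r"[^A-Za-z0-9(),!?\'\`]", " ", string)
         return string.strip().lower()

     # clean the whole sentence first, then split on whitespace
     text_field = data.Field(lower=True, tokenize=lambda x: clean_str(x).split())
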
  • Fixed breaking changes in predict function introduced by pytorch 0.4

    PyTorch 0.4 removes tensor_type from the torchtext data Field, which causes the predict function in train.py to break.

    Switched over to getting the tensor from torch based on the migration guide: https://pytorch.org/blog/pytorch-0_4_0-migration-guide/

    opened by Rohan-B 1
  • Issue with running code on SST Dataset

    Hi @Shawn1993

    I have an issue with running the code on the SST dataset. I commented out line 73 and uncommented line 74 in main.py.

    It seems that the code directly uses the torchtext dataset package to download the SST dataset and then runs. However, this raises the error RuntimeError: Given input size: (1, 4, 128). Calculated output size: (64, 0, 1). Output size is too small. at line 44 of model.py, x = [F.max_pool1d(i, i.size(2)).squeeze(2) for i in x] # [(N, Co), ...]*len(Ks).

    Any hint on the possible reason? Thanks in advance.

    For your information, I am running on PyTorch '0.3.1.post2' and torchtext '0.2.1'.

    opened by mickeysjm 1
  • Pull request for issue #27

    GitHub seems to be having some issues since the master branch for my fork of this project is heavily modified. The only actual change in this pull request should be line 74 of mydatasets.py.

    opened by rriva002 0
  • Data load error in Python2

    Trying to run the model in python2.7 (can't upgrade my system python so I have to make do). I am getting the following error:

    Loading data...
    Traceback (most recent call last):
      File "./main.py", line 73, in <module>
        train_iter, dev_iter = mr(text_field, label_field, device=-1, repeat=False)
      File "./main.py", line 59, in mr
        train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
      File "/home/cnn-text-classification-pytorch/mydatasets.py", line 105, in splits
        examples = cls(text_field, label_field, path=path, **kwargs).examples
      File "/home/cnn-text-classification-pytorch/mydatasets.py", line 82, in __init__
        data.Example.fromlist([line, 'negative'], fields) for line in f]
      File "/home/venv/local/lib/python2.7/site-packages/torchtext/data/example.py", line 52, in fromlist
        setattr(ex, name, field.preprocess(val))
      File "/home/venv/local/lib/python2.7/site-packages/torchtext/data/field.py", line 166, in preprocess
        x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
      File "/home/venv/local/lib/python2.7/site-packages/torchtext/data/pipeline.py", line 37, in __call__
        x = pipe.call(x, *args)
      File "/home/venv/local/lib/python2.7/site-packages/torchtext/data/pipeline.py", line 53, in call
        return self.convert_token(x, *args)
      File "/home/venv/local/lib/python2.7/site-packages/torchtext/data/field.py", line 166, in <lambda>
        x = Pipeline(lambda s: six.text_type(s, encoding='utf-8'))(x)
    UnicodeDecodeError: 'utf8' codec can't decode byte 0x97 in position 110: invalid start byte
    

    I've re-encoded all the data in rt-polarity.neg and rt-polarity.pos as UTF-8, ignoring errors (to remove the non-decodable characters), but no luck. Any help?

    opened by sarahwie 0
  • update details for better use

    This pull request mainly does four things:

    1. Solve the basic import problem in Python 3.
    2. Add -early-stop and -save-best options. I ran your code on a K80 and it overfits very easily. You use iterations rather than epochs, so I did the same.
    3. Update the way the model is saved. Saving a whole model should be avoided, as the PyTorch authors advise; if you do, you will hit many problems when fine-tuning, because the entire environment must stay the same for the checkpoint to load. So I use state_dict() and load_state_dict() instead (see the sketch after this item).
    4. Make some small changes to conform to PEP 8.
    opened by oneTaken 0
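
    The state_dict approach mentioned in point 3 looks roughly like this, using the illustrative TextCNN sketch near the top of this page (paths and constructor arguments are placeholders):

     import torch

     model = TextCNN(vocab_size=20000, embed_dim=128, class_num=2)

     # saving: store only the parameters, not the pickled model object
     torch.save(model.state_dict(), "snapshot/best_steps.pt")

     # loading: rebuild the model with the same arguments, then restore its parameters
     restored = TextCNN(vocab_size=20000, embed_dim=128, class_num=2)
     restored.load_state_dict(torch.load("snapshot/best_steps.pt"))
     restored.eval()
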
  • Fixed cuda tensor bug during prediction

    Also made two other minor modifications:

    1. Added space between text and label during prediction
    2. Added KeyboardInterrupt functionality for quitting training
    opened by SeanRosario 0
  • CVE-2007-4559 Patch


    Hi, we are security researchers from the Advanced Research Center at Trellix. We have begun a campaign to patch a widespread bug named CVE-2007-4559, a 15-year-old bug in the Python tarfile package. By using extract() or extractall() on a tarfile object without sanitizing the input, a maliciously crafted .tar file could perform a directory path traversal attack. We found at least one unsanitized extractall() in your codebase and are providing a patch for you via pull request. The patch essentially checks that all tarfile members will be extracted safely and throws an exception otherwise. We encourage you to use this patch or your own solution to secure against CVE-2007-4559. Further technical information about the vulnerability can be found in this blog.

    If you have further questions you may contact us through this project's lead researcher, Kasimir Schulz.

    opened by TrellixVulnTeam 0
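
    For context, the usual patch pattern for CVE-2007-4559 is to verify that every member path stays inside the destination directory before extracting; a hedged sketch of such a check (not necessarily the exact code in the pull request):

     import os
     import tarfile

     def safe_extractall(tar: tarfile.TarFile, path: str = ".") -> None:
         """Extract only if no member escapes the destination directory."""
         base = os.path.abspath(path)
         for member in tar.getmembers():
             target = os.path.abspath(os.path.join(path, member.name))
             if target != base and not target.startswith(base + os.sep):
                 raise RuntimeError("Blocked path traversal in tar member: " + member.name)
         tar.extractall(path)
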
  • 'Field' object has no attribute 'vocab'

    Hello, excuse me: there was no problem during training, but this error occurred during prediction. I'm actually extracting predict into a standalone function:

    PATH = './snapshot/best_steps_8600.pt'
    args = confog_args()
    text_field = data.Field(lower=True)
    label_field = data.Field(sequential=False)
    args.vocabulary_size = len(text_field.vocab)
    args.cuda = args.device != -1 and torch.cuda.is_available()

    In addition, does the training data also need to be loaded when predicting?

    Looking forward to your reply. Thank you

    opened by SevenMpp 0
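
    The error usually means the fields were created fresh without rebuilding the vocabulary: text_field.vocab only exists after the vocabulary has been built over the training data. A rough sketch of the required ordering, mirroring (but not necessarily matching line-for-line) what main.py does:

     import mydatasets
     from torchtext import data   # legacy torchtext API, as used by this repo

     text_field = data.Field(lower=True)
     label_field = data.Field(sequential=False)

     # the fields must see the training data before prediction,
     # otherwise text_field.vocab does not exist
     train_data, dev_data = mydatasets.MR.splits(text_field, label_field)
     text_field.build_vocab(train_data, dev_data)
     label_field.build_vocab(train_data, dev_data)

     vocabulary_size = len(text_field.vocab)   # now safe to use
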
  • Performance on MR dataset

    Hi there, I see that the best reported accuracy for this repo for the MR dataset is 77.5% using CNN-rand-static. When I run this, using ./main.py -device=0 -static, I get much lower numbers (~70%). Two questions:

    1. What training settings are you using to get 77.5%?
    2. How are you evaluating on the MR dataset to get 77.5%?

    Thanks!

    opened by nick11roberts 2
  • RuntimeError: set_storage_offset is not allowed on Tensor created from .data or .detach()

    • Problem 1:
    Traceback (most recent call last):
      File "/cnn-text-classification-pytorch/main.py", line 112, in <module>
        train.train(train_iter, dev_iter, cnn, args)
      File "/cnn-text-classification-pytorch/train.py", line 25, in train
        feature.data.t_(), target.data.sub_(1)  # batch first, index align
    RuntimeError: set_storage_offset is not allowed on Tensor created from .data or .detach()
    
    Process finished with exit code 1
    
    • Solution to problem 1: replace both occurrences of feature.data.t_(), target.data.sub_(1) with:
     feature = feature.data.t()
     target = target.data.sub(1) 
    
    • Problem 2:
    Traceback (most recent call last):
      File "/cnn-text-classification-pytorch/main.py", line 112, in <module>
        train.train(train_iter, dev_iter, cnn, args)
      File "/cnn-text-classification-pytorch/train.py", line 43, in train
        loss.data[0],
    IndexError: invalid index of a 0-dim tensor. Use tensor.item() to convert a 0-dim tensor to a Python number
    
    Process finished with exit code 1
    
    • Solution to problem 2: replace both occurrences of loss.data[0] with loss.item().
    opened by jrothschild33 1