# Character Based CNN
This repo contains a PyTorch implementation of a character-level convolutional neural network for text classification.
The model architecture comes from the paper Character-level Convolutional Networks for Text Classification (Zhang et al., 2015): https://arxiv.org/pdf/1509.01626.pdf
There are two variants: a large and a small. You can switch between the two by changing the configuration file.
This architecture has 6 convolutional layers:
Layer | Feature maps (Large) | Feature maps (Small) | Kernel size | Pooling size
---|---|---|---|---
1 | 1024 | 256 | 7 | 3
2 | 1024 | 256 | 7 | 3
3 | 1024 | 256 | 3 | N/A
4 | 1024 | 256 | 3 | N/A
5 | 1024 | 256 | 3 | N/A
6 | 1024 | 256 | 3 | 3
and 3 fully connected layers (the last one being the output layer):

Layer | Output units (Large) | Output units (Small)
---|---|---
7 | 2048 | 1024
8 | 2048 | 1024
9 | Depends on the number of classes | Depends on the number of classes
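
For concreteness, here is a minimal PyTorch sketch of the small variant from the tables above. It is illustrative only: the actual implementation lives in src/cnn_model.py and may differ in details such as dropout placement; the input length of 150 mirrors the default max_length documented below.

```python
import torch
import torch.nn as nn

class CharCNNSmall(nn.Module):
    """Small variant: 256 feature maps, 1024-unit fully connected layers."""
    def __init__(self, n_chars=70, seq_len=150, n_classes=2):
        super().__init__()
        self.conv = nn.Sequential(
            nn.Conv1d(n_chars, 256, kernel_size=7), nn.ReLU(), nn.MaxPool1d(3),  # layer 1
            nn.Conv1d(256, 256, kernel_size=7), nn.ReLU(), nn.MaxPool1d(3),      # layer 2
            nn.Conv1d(256, 256, kernel_size=3), nn.ReLU(),                       # layer 3
            nn.Conv1d(256, 256, kernel_size=3), nn.ReLU(),                       # layer 4
            nn.Conv1d(256, 256, kernel_size=3), nn.ReLU(),                       # layer 5
            nn.Conv1d(256, 256, kernel_size=3), nn.ReLU(), nn.MaxPool1d(3),      # layer 6
        )
        # Infer the flattened size with a dummy pass instead of hard-coding it.
        with torch.no_grad():
            flat = self.conv(torch.zeros(1, n_chars, seq_len)).numel()
        self.fc = nn.Sequential(
            nn.Linear(flat, 1024), nn.ReLU(), nn.Dropout(0.5),  # layer 7
            nn.Linear(1024, 1024), nn.ReLU(), nn.Dropout(0.5),  # layer 8
            nn.Linear(1024, n_classes),                         # layer 9 (output)
        )

    def forward(self, x):  # x: (batch, n_chars, seq_len), one-hot characters
        x = self.conv(x)
        return self.fc(x.view(x.size(0), -1))
```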
## Video tutorial
If you're interested in how character-level CNNs work, as well as in a demo of this project, you can check out my YouTube video tutorial.
## Why you should care about character-level CNNs
They have very nice properties:
- They are quite powerful at text classification (see the paper's benchmarks) even though they have no notion of semantics
- You don't need to apply any text preprocessing (tokenization, lemmatization, stemming, ...) when using them
- They handle misspelled words and out-of-vocabulary (OOV) tokens gracefully
- They are faster to train than recurrent neural networks
- They are lightweight since they don't require storing a large word embedding matrix, which makes them easy to deploy in production
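
The last three points follow from how the input is encoded: each character is one-hot encoded against a small fixed alphabet, so there is no vocabulary to build and no embedding matrix to store. Here is an illustrative sketch of that quantization step (the repo's data_loader.py may differ in details):

```python
import numpy as np

# Default alphabet from the training arguments documented below.
alphabet = "abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'\"/\\|_@#$%^&*~`+-=<>()[]{}"
max_length = 150  # the repo's default --max_length
char_to_idx = {c: i for i, c in enumerate(alphabet)}

def quantize(text):
    # One-hot encode characters column by column; characters outside the
    # alphabet become all-zero columns, which is how OOV symbols are absorbed.
    matrix = np.zeros((len(alphabet), max_length), dtype=np.float32)
    for pos, char in enumerate(text.lower()[:max_length]):
        idx = char_to_idx.get(char)
        if idx is not None:
            matrix[idx, pos] = 1.0
    return matrix

print(quantize("I love pizza !").shape)  # (len(alphabet), max_length)
```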
## Training a sentiment classifier on French customer reviews
I tested this model on a set of labeled French customer reviews (over 3 million rows) and tracked the metrics with TensorboardX. I got the following results:
Split | F1 score | Accuracy
---|---|---
Train | 0.965 | 0.9366
Test | 0.945 | 0.915
## Dependencies
- numpy
- pandas
- scikit-learn
- PyTorch 0.4.1
- tensorboardX
- TensorFlow (needed to run TensorBoard on the TensorboardX logs)
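
Assuming a pip-based environment (this exact command is not from the repo), something like the following should install everything:

```bash
pip install numpy pandas scikit-learn tensorboardX tensorflow
pip install torch==0.4.1
```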
## Structure of the code
At the root of the project, you will have:
- train.py: used for training a model
- predict.py: used for testing and inference
- config.json: a configuration file for storing model parameters (number of filters, neurons)
- src: a folder that contains:
  - cnn_model.py: the actual CNN model (model initialization and forward method)
  - data_loader.py: the script responsible for processing the data and feeding it to the training loop
  - utils.py: a set of utility functions for text preprocessing (url/hashtag/user_mention removal)
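
For illustration, preprocessing steps of that kind usually boil down to small regex substitutions. The snippet below is a sketch of the idea, not the repo's actual utils.py:

```python
import re

# Illustrative url / hashtag / user_mention removal; these exact regexes
# are assumptions, not copied from the repo.
def remove_urls(text):
    return re.sub(r"https?://\S+|www\.\S+", " ", text)

def remove_hashtags(text):
    return re.sub(r"#\w+", " ", text)

def remove_user_mentions(text):
    return re.sub(r"@\w+", " ", text)

print(remove_user_mentions("thanks @someone for the tip"))
```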
## How to use the code

### Training

The code currently works only on binary labels (0/1).
Launch train.py with the following arguments:
- `data_path`: path of the data. The data should be in CSV format, with at least one column for the text and one for the label
- `validation_split`: ratio of data held out for validation. Defaults to 0.2
- `label_column`: name of the label column
- `text_column`: name of the text column
- `max_rows`: maximum number of rows to load from the dataset (I mainly use this to speed up testing)
- `chunksize`: size of the chunks when loading the data with pandas. Defaults to 500000
- `encoding`: defaults to utf-8
- `steps`: text preprocessing steps to apply to the text, such as hashtag or url removal
- `group_labels`: whether or not to group labels. Defaults to None
- `use_sampler`: whether or not to use a weighted sampler to overcome class imbalance
- `alphabet`: defaults to abcdefghijklmnopqrstuvwxyz0123456789,;.!?:'"/\|_@#$%^&*~`+-=<>()[]{} (normally you should not modify it)
- `number_of_characters`: defaults to 70
- `extra_characters`: additional characters to add to the alphabet, for example uppercase letters or accented characters
- `max_length`: the fixed length to which all documents are padded or truncated. Defaults to 150 but should be adapted to your data
- `epochs`: number of epochs
- `batch_size`: batch size. Defaults to 128
- `optimizer`: adam or sgd. Defaults to sgd
- `learning_rate`: defaults to 0.01
- `class_weights`: whether or not to use class weights in the cross-entropy loss
- `focal_loss`: whether or not to use the focal loss (see the sketch after this list)
- `gamma`: gamma parameter of the focal loss. Defaults to 2
- `alpha`: alpha parameter of the focal loss. Defaults to 0.25
- `schedule`: number of epochs after which the learning rate is halved (learning rate scheduling works only with sgd). Defaults to 3; set it to 0 to disable
- `patience`: maximum number of epochs to wait without improvement of the validation loss. Defaults to 3
- `early_stopping`: whether or not to stop the training early. Defaults to 0; set to 1 to enable
- `checkpoint`: whether or not to save the model to disk. Defaults to 1; set to 0 to disable model checkpointing
- `workers`: number of workers in the PyTorch DataLoader. Defaults to 1
- `log_path`: path of the TensorBoard log file
- `output`: path of the folder where models are saved
- `model_name`: prefix name of saved models
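
For context, `gamma` and `alpha` are the parameters of the focal loss (Lin et al., 2017), FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t). A minimal binary sketch of the idea, written against a recent PyTorch API and not necessarily identical to the repo's implementation:

```python
import torch
import torch.nn.functional as F

def binary_focal_loss(logits, targets, gamma=2.0, alpha=0.25):
    # Plain BCE first: bce = -log(p_t), so p_t can be recovered as exp(-bce).
    bce = F.binary_cross_entropy_with_logits(logits, targets, reduction="none")
    p_t = torch.exp(-bce)
    # alpha weights the positive class, (1 - alpha) the negative class.
    alpha_t = alpha * targets + (1 - alpha) * (1 - targets)
    # (1 - p_t) ** gamma down-weights examples the model already classifies well.
    return (alpha_t * (1 - p_t) ** gamma * bce).mean()
```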
Example usage:
```bash
python train.py --data_path=/data/tweets.csv --max_rows=200000
```
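
And a fuller invocation combining several of the options above (the paths and model name are placeholders; boolean options are assumed to take 0/1, as the defaults listed above suggest):

```bash
python train.py --data_path=/data/tweets.csv --text_column=text --label_column=label \
                --max_rows=200000 --use_sampler=1 --focal_loss=1 --schedule=3 \
                --early_stopping=1 --output=./models/ --model_name=char_cnn
```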
### Plotting results to TensorboardX
Run this command at the root of the project:
```bash
tensorboard --logdir=./logs/ --port=6006
```
Then go to: http://localhost:6006 (or whatever host you're using)
### Prediction
Launch predict.py with the following arguments:
- `model`: path of the pre-trained model
- `text`: input text
- `steps`: list of preprocessing steps. Defaults to lower
- `alphabet`: defaults to 'abcdefghijklmnopqrstuvwxyz0123456789-,;.!?:'"\/|_@#$%^&*~`+-=<>()[]{}\n'
- `number_of_characters`: defaults to 70
- `extra_characters`: additional characters to add to the alphabet, for example uppercase letters or accented characters
- `max_length`: the fixed length to which all documents are padded or truncated. Defaults to 150 but should be adapted to your data
Example usage:

```bash
python predict.py --model=./models/pretrained_model.pth --text="I love pizza !" --max_length=150
```
## Download pretrained models
- Sentiment analysis model on French customer reviews (3M documents): download link
When using it:
- set max_length to 300
- use extra_characters="éàèùâêîôûçëïü" (accented letters)
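
Putting that together, a prediction call with this model would look something like the following (the model filename is a placeholder):

```bash
python predict.py --model=./models/french_reviews_model.pth --text="Ce produit est excellent !" \
                  --max_length=300 --extra_characters="éàèùâêîôûçëïü"
```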
## Contributions - PRs are welcome
Here's a non-exhaustive list of potential future features to add:
- Adapt the loss for multi-class classification
- Log training and validation metrics for each epoch to a text file
- Provide notebook tutorials
## License

This project is licensed under the MIT License.