ahmedbesbes / character-based-cnn

Implementation of character based convolutional neural network
MIT License
258 stars 56 forks source link
character-based-model character-cnn convolutional-neural-network deep-neural-networks natural-language-processing nlp nlp-machine-learning paper-implementations pytorch youtube-video

Character Based CNN

MIT contributions welcome Twitter Stars

This repo contains a PyTorch implementation of a character-level convolutional neural network for text classification.

The model architecture comes from this paper: https://arxiv.org/pdf/1509.01626.pdf

Network architecture

There are two variants: a large and a small. You can switch between the two by changing the configuration file.

This architecture has 6 convolutional layers:

Layer Large Feature Small Feature Kernel Pool
1 1024 256 7 3
2 1024 256 7 3
3 1024 256 3 N/A
4 1024 256 3 N/A
5 1024 256 3 N/A
6 1024 256 3 3

and 2 fully connected layers:

Layer Output Units Large Output Units Small
7 2048 1024
8 2048 1024
9 Depends on the problem Depends on the problem

Video tutorial

If you're interested in how character CNN work as well as in the demo of this project you can check my youtube video tutorial.

Why you should care about character level CNNs

They have very nice properties:

Training a sentiment classifier on french customer reviews

I have tested this model on a set of french labeled customer reviews (of over 3 millions rows). I reported the metrics in TensorboardX.

I got the following results

F1 score Accuracy
train 0.965 0.9366
test 0.945 0.915

Training metrics

Dependencies

Structure of the code

At the root of the project, you will have:

How to use the code

Training

The code currently works only on binary labels (0/1)

Launch train.py with the following arguments:

Example usage:

python train.py --data_path=/data/tweets.csv --max_rows=200000

Plotting results to TensorboardX

Run this command at the root of the project:

tensorboard --logdir=./logs/ --port=6006

Then go to: http://localhost:6006 (or whatever host you're using)

Prediction

Launch predict.py with the following arguments:

Example usage:

python predict.py ./models/pretrained_model.pth --text="I love pizza !" --max_length=150

Download pretrained models

Contributions - PR are welcome:

Here's a non-exhaustive list of potential future features to add:

License

This project is licensed under the MIT License