This repo is deprecated. Please find the updated package here:
https://github.com/EdGENetworks/anuvada
Anuvada: Interpretable Models for NLP using PyTorch
One common criticism of deep learning is its black-box nature. To address this, researchers have developed many ways to visualise and explain a model's inferences. Examples include attention for RNNs, and activation maps, guided backpropagation and occlusion for CNNs. This library is an ongoing effort to provide high-level access to such models, built on PyTorch.
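As a rough illustration of the first technique, the sketch below shows attention pooling over RNN hidden states in NumPy. Each timestep's hidden state is scored against a context vector, the scores are softmax-normalised into weights, and the weighted sum becomes the document vector. The weights are what an attention visualisation would display. The function `attention_pool`, the tanh projection and the context vector `u` are assumptions modelled on common attention-pooling designs, not anuvada's actual implementation.

```python
import numpy as np

def softmax(z):
    # Subtract the max before exponentiating for numerical stability.
    e = np.exp(z - np.max(z))
    return e / e.sum()

def attention_pool(H, W, b, u, mask=None):
    """Pool hidden states H (T x d) into one vector with attention.

    W (d x a), b (a,) and the context vector u (a,) stand in for learned
    parameters (hypothetical). mask (T,) marks real timesteps with True;
    padded timesteps get zero weight. Returns (pooled_vector, weights).
    """
    proj = np.tanh(H @ W + b)          # (T, a) projected hidden states
    scores = proj @ u                  # (T,) one relevance score per step
    if mask is not None:
        scores = np.where(mask, scores, -np.inf)  # exclude padding
    alpha = softmax(scores)            # (T,) weights summing to 1
    return alpha @ H, alpha            # (d,) weighted sum of hidden states

rng = np.random.default_rng(0)
T, d, a = 6, 4, 3                      # timesteps, hidden size, attention size
H = rng.normal(size=(T, d))            # e.g. GRU outputs for one document
W, b, u = rng.normal(size=(d, a)), np.zeros(a), rng.normal(size=a)
doc_vec, alpha = attention_pool(H, W, b, u)
print(doc_vec.shape, alpha.round(3))
```

Passing a mask such as `np.array([True] * 4 + [False] * 2)` forces the two padded timesteps to zero weight, which is the behaviour gradient masking in an RNN classifier is meant to preserve.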
Installing
Clone this repo and add it to your Python path (for example via the PYTHONPATH environment variable).
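If you would rather not set `PYTHONPATH` in your shell, one option is to extend `sys.path` at runtime before importing. The clone location `~/code/anuvada` below is a placeholder, not a path the project prescribes, and the sketch assumes the `anuvada` package folder sits directly inside the cloned repository.

```python
import os
import sys

# Placeholder: wherever you cloned the repository (hypothetical location).
REPO_ROOT = os.path.expanduser("~/code/anuvada")

# Prepend so this checkout takes precedence over any other installed copy.
if REPO_ROOT not in sys.path:
    sys.path.insert(0, REPO_ROOT)

# After this, `import anuvada` should resolve, assuming the package
# folder lives directly inside REPO_ROOT.
```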
Getting started
Importing libraries
import anuvada
import numpy as np
import torch
import pandas as pd
from anuvada.models.classification_attention_rnn import AttentionClassifier
Creating the dataset
from anuvada.datasets.data_loader import CreateDataset
from anuvada.datasets.data_loader import LoadData
data = CreateDataset()
df = pd.read_csv('MovieSummaries/movie_summary_filtered.csv')
# passing only the first 512 samples, I don't have a GPU!
y = list(df.Genre.values)[0:512]
x = list(df.summary.values)[0:512]
x, y = data.create_dataset(x, y, folder_path='test', max_doc_tokens=500)
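The README does not spell out what `create_dataset` does internally. As an illustration of the kind of preprocessing that `max_doc_tokens` suggests, the sketch below builds `token2id` and `label2id` maps, converts each document to integer ids, and truncates or zero-pads it to a fixed length. The helper `build_padded_dataset`, whitespace tokenisation and reserving id 0 for padding are all assumptions, not anuvada's behaviour.

```python
import numpy as np

def build_padded_dataset(texts, labels, max_doc_tokens):
    """Hypothetical stand-in for a dataset builder (not anuvada's code)."""
    token2id = {"<pad>": 0}  # id 0 reserved for padding (assumption)
    label2id = {}
    x = np.zeros((len(texts), max_doc_tokens), dtype=np.int64)
    lengths = np.zeros(len(texts), dtype=np.int64)
    for i, text in enumerate(texts):
        # Naive lower-cased whitespace tokenisation, truncated to max_doc_tokens
        tokens = text.lower().split()[:max_doc_tokens]
        for j, tok in enumerate(tokens):
            # New tokens get the next free id; known tokens reuse theirs
            x[i, j] = token2id.setdefault(tok, len(token2id))
        lengths[i] = len(tokens)
    y = np.array([label2id.setdefault(lab, len(label2id)) for lab in labels],
                 dtype=np.int64)
    return x, y, token2id, label2id, lengths

x_demo, y_demo, t2i, l2i, lens = build_padded_dataset(
    ["A heist goes wrong", "Two friends take a road trip together"],
    ["Crime", "Comedy"],
    max_doc_tokens=5,
)
print(x_demo)
print(y_demo, lens)
```

The first document is shorter than `max_doc_tokens` and is padded with zeros; the second is longer and is cut off after five tokens.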
Loading the created dataset
l = LoadData()
x, y, token2id, label2id, lengths_mask = l.load_data_from_path('test')
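`lengths_mask` is passed to `fit` alongside the padded ids, but its exact format is not documented here. Assuming it encodes per-document token counts, or a mask derived from them, the sketch below shows how such lengths become a 0/1 mask that separates real tokens from padding. A mask like this is what lets an RNN or attention layer ignore padded positions. `lengths_to_mask` is a hypothetical helper.

```python
import numpy as np

def lengths_to_mask(lengths, max_len):
    """Return a (batch, max_len) float mask: 1.0 for real tokens, 0.0 for padding."""
    positions = np.arange(max_len)  # shape (max_len,)
    # Broadcast compare: position j is a real token iff j < that document's length
    return (positions[None, :] < np.asarray(lengths)[:, None]).astype(np.float32)

mask = lengths_to_mask([4, 5, 2], max_len=5)
print(mask)
```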
Converting to torch tensors
x = torch.from_numpy(x)
y = torch.from_numpy(y)
Creating and training the attention classifier
acf = AttentionClassifier(vocab_size=len(token2id), embed_size=25, gru_hidden=25, n_classes=len(label2id))
loss = acf.fit(x, y, lengths_mask, epochs=5)
Epoch 1 / 5
[========================================] 100% loss: 3.9904
Epoch 2 / 5
[========================================] 100% loss: 3.9851
Epoch 3 / 5
[========================================] 100% loss: 3.9783
Epoch 4 / 5
[========================================] 100% loss: 3.9739
Epoch 5 / 5
[========================================] 100% loss: 3.9650
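One quick way to read the loss values above: for a K-class problem, a classifier that assigns equal probability to every class has cross-entropy ln K. Comparing the logged loss with `math.log(len(label2id))` therefore shows how far the model has moved from uniform guessing. This assumes `fit` reports mean cross-entropy in nats, which the log itself does not state, and `K = 10` below is only a placeholder.

```python
import math
import numpy as np

def uniform_baseline(n_classes):
    """Cross-entropy (nats) of predicting 1/n_classes for every class."""
    return math.log(n_classes)

def mean_cross_entropy(probs, targets):
    """Mean negative log-likelihood of the true class; probs has shape (N, K)."""
    probs = np.asarray(probs, dtype=np.float64)
    return float(-np.mean(np.log(probs[np.arange(len(targets)), targets])))

K = 10  # placeholder; use len(label2id) for the real dataset
uniform = np.full((3, K), 1.0 / K)  # three samples, uniform predictions
print(uniform_baseline(K), mean_cross_entropy(uniform, [0, 4, 9]))
```

A logged loss sitting just below this baseline after a few epochs suggests the model has only started to pull away from uniform predictions.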
To do list
- Implement Attention with RNN
- Implement Attention Visualisation
- Implement working Fit Module
- Implement support for masking gradients in RNN (Working now!)
- Implement a generic data set loader
- Implement CNN Classifier with feature map visualisation
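Attention visualisation is still on the list above. As a minimal text-only sketch of the idea, assuming you already have one attention weight per token (for example from an attention layer's softmax output), the helper below prints each token next to a bar scaled to its weight. `render_attention` and its output format are hypothetical; a fuller implementation might render an HTML heatmap instead.

```python
def render_attention(tokens, weights, width=20):
    """Return one line per token with a bar scaled to the largest weight."""
    if len(tokens) != len(weights):
        raise ValueError("need exactly one weight per token")
    top = max(weights) or 1.0  # avoid division by zero if all weights are 0
    lines = []
    for tok, w in zip(tokens, weights):
        bar = "#" * round(width * w / top)  # longest bar = most-attended token
        lines.append(f"{tok:>12} {w:5.2f} {bar}")
    return "\n".join(lines)

out = render_attention(["a", "heist", "goes", "wrong"], [0.05, 0.60, 0.10, 0.25])
print(out)
```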