fastnlp / fastNLP

fastNLP: A Modularized and Extensible NLP Framework. Currently still in incubation.
https://gitee.com/fastnlp/fastNLP
Apache License 2.0

What is the input format for Hierarchical Attention Network model using fastNLP framework #362

Open hengee opened 3 years ago

hengee commented 3 years ago

I saw that there is a Hierarchical Attention Network model included in the directory reproduction/text_classification/model/HAN.py. I noticed that the input for HAN differs from that of the other models (LSTM and CNN):

HAN: input_sents (torch.LongTensor) -- [batch_size, num_sents, seq_len]
CNN / LSTM: words (torch.LongTensor) -- [batch_size, seq_len]
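
For concreteness, a toy sketch of the two layouts for a single instance (made-up word ids, not my real data):

# CNN / LSTM: one flat list of word ids per document
words = [5, 12, 7, 98, 3]
# HAN: one inner list of word ids per sentence, nested per document
input_sents = [[5, 12, 7], [98, 3]]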

I would like to know how to formulate the input for HAN under the fastNLP framework, using the fastNLP DataSet and dataloader.

Thank you in advance!

yhcc commented 3 years ago

First, make sure your DataSet includes the input_sents field, then call DataSet.set_input('input_sents'); this marks the input_sents field as an input to the model's forward function. During iteration fastNLP will try to pad this field automatically (so make sure the field is a nested list with integers as elements), convert the padded data to a torch.LongTensor, move it to the right device, and then pass it to forward.
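
A minimal sketch of what that looks like in practice (field names follow this thread; the data is a made-up toy example):

from fastNLP import DataSet

# each instance stores a nested list: one inner list of word ids per sentence
ds = DataSet({
    'input_sents': [
        [[4, 8, 15], [16, 23, 42, 7]],   # document 1: two sentences
        [[1, 2, 3, 4, 5]],               # document 2: one sentence
    ],
    'target': [0, 1],
})
ds.set_input('input_sents')  # padded and converted to a torch.LongTensor during iteration
ds.set_target('target')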

hengee commented 3 years ago

@yhcc okay noted, will try it out!

hengee commented 3 years ago

So I attempted to reproduce HAN for Document Classification using the code in this repo: fastNLP/reproduction/text_classification/model/HAN.py

I have listed my code below, from generating my dataset to training the model, mostly adapted from:

  1. fastNLP/reproduction/text_classification/model/HAN.py
  2. fastNLP/reproduction/text_classification/train_HAN.py
  3. fastNLP/reproduction/HAN-document_classification/preprocess.py (got it from the previous commit)

import numpy as np # linear algebra
import pandas as pd # data processing, CSV file I/O (e.g. pd.read_csv)
import os

input_dir = 'user_text_label.csv'

from fastNLP.io.loader import CSVLoader
from fastNLP import Vocabulary
from fastNLP.core.const import Const as C
from fastNLP.core import LRScheduler
from fastNLP.embeddings import StaticEmbedding
from fastNLP import CrossEntropyLoss, AccuracyMetric
from fastNLP.core.trainer import Trainer
from torch.optim import SGD
import torch.cuda
from torch.optim.lr_scheduler import CosineAnnealingLR
import torch
import torch.nn as nn
from torch.autograd import Variable
from fastNLP.embeddings.utils import get_embeddings
from fastNLP.core import Const as C
from fastNLP.io import DataBundle
from fastNLP.core import field
# loading custom csv dataset
data_bundle = CSVLoader().load(input_dir)

# rename the label to target to fit the framework
data_bundle.rename_field('label', 'target')

# only train available
tr_data = data_bundle.get_dataset('train')

# splitting text into words (tokenize)
tr_data.apply(lambda ins: ins['tweet_text'].split(), new_field_name='words')

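# split the single available dataset into train and test portions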
[train_data, test_data] = data_bundle.datasets['train'].split(0.3)

vocab = Vocabulary()
# build the vocabulary from the tokenized tweet_text (words) column of this dataset
vocab.from_dataset(train_data, field_name='words', no_create_entry_dataset=[test_data])
vocab.index_dataset(train_data, field_name='words')

label_vocab = Vocabulary(padding=None, unknown=None)
label_vocab.from_dataset(train_data, field_name='target', no_create_entry_dataset=[test_data])
label_vocab.index_dataset(train_data, field_name='target')

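# repackage the split datasets and vocabularies into a new DataBundle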
new_data_bundle = DataBundle()
new_data_bundle.set_dataset(train_data, 'train')
new_data_bundle.set_dataset(test_data, 'test')
new_data_bundle.set_vocab(vocab, 'vocab')
new_data_bundle.set_vocab(label_vocab, 'label_vocab')

datainfo = new_data_bundle
print(len(datainfo.datasets['train']))
print(len(datainfo.datasets['test']))

# post-process: wrap each document's flat word-id list as a single "sentence",
# producing a nested input_sents field of shape [num_sents=1, seq_len]
def make_sents(words):
    sents = [words]
    return sents

for dataset in datainfo.datasets.values():
    dataset.apply_field(make_sents, field_name='words', new_field_name='input_sents')

datainfo.datasets['train'].set_input('input_sents')
datainfo.datasets['test'].set_input('input_sents')
datainfo.datasets['train'].set_target('target')
# datainfo.datasets['test'].set_target('target')

print(datainfo.datasets['train'])
print(len(datainfo.datasets['train']))
"""
+------------------+-------------------+--------+------------------+------------------+
| user_id          | tweet_text        | target | words            | input_sents      |
+------------------+-------------------+--------+------------------+------------------+
| 123              | When u wrappi...  | 0      | [166, 78, 136... | [[166, 78, 13... |
| 124              | RT @JessicaTa...  | 0      | [2, 540344, 1... | [[2, 540344, ... |
| 125              | it's so annoy...  | 0      | [80, 29, 1837... | [[80, 29, 183... |
+------------------+-------------------+--------+------------------+------------------+
"""
vocab = datainfo.vocabs['vocab']
# embedding = StackEmbedding([StaticEmbedding(vocab), CNNCharEmbedding(vocab, 100)])
embedding = StaticEmbedding(vocab, model_dir_or_name='.vector_cache/glove.6B.300d.txt')

print(len(vocab))

print(len(datainfo.vocabs['label_vocab']))

def pack_sequence(tensor_seq, padding_value=0.0):
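    # pad a list of (num_sents_i, vec_dim) tensors into one (batch, max_num_sents, vec_dim) tensor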
    if len(tensor_seq) <= 0:
        return
    length = [v.size(0) for v in tensor_seq]
    max_len = max(length)
    size = [len(tensor_seq), max_len]
    size.extend(list(tensor_seq[0].size()[1:]))
    ans = torch.Tensor(*size).fill_(padding_value)
    if tensor_seq[0].data.is_cuda:
        ans = ans.cuda()
    ans = Variable(ans)
    for i, v in enumerate(tensor_seq):
        ans[i, :length[i], :] = v
    return ans

class HANCLS(nn.Module):
    def __init__(self, init_embed, num_cls):
        super(HANCLS, self).__init__()

        self.embed = get_embeddings(init_embed)
        self.han = HAN(input_size=300,
                       output_size=num_cls,
                       word_hidden_size=50, word_num_layers=1, word_context_size=100,
                       sent_hidden_size=50, sent_num_layers=1, sent_context_size=100
                       )

    def forward(self, input_sents):
        # input_sents [B, num_sents, seq-len] dtype long
        # target
        B, num_sents, seq_len = input_sents.size()
        input_sents = input_sents.view(-1, seq_len)  # flat
        words_embed = self.embed(input_sents)  # should be [B*num-sent, seqlen , word-dim]
        words_embed = words_embed.view(B, num_sents, seq_len, -1)  # recover # [B, num-sent, seqlen , word-dim]
        out = self.han(words_embed)

        return {C.OUTPUT: out}

    def predict(self, input_sents):
        x = self.forward(input_sents)[C.OUTPUT]
        return {C.OUTPUT: torch.argmax(x, 1)}

class HAN(nn.Module):
    def __init__(self, input_size, output_size,
                 word_hidden_size, word_num_layers, word_context_size,
                 sent_hidden_size, sent_num_layers, sent_context_size):
        super(HAN, self).__init__()

        self.word_layer = AttentionNet(input_size,
                                       word_hidden_size,
                                       word_num_layers,
                                       word_context_size)
        self.sent_layer = AttentionNet(2 * word_hidden_size,
                                       sent_hidden_size,
                                       sent_num_layers,
                                       sent_context_size)
        self.output_layer = nn.Linear(2 * sent_hidden_size, output_size)
        self.softmax = nn.LogSoftmax(dim=1)

    def forward(self, batch_doc):
        # input is a sequence of matrix
        doc_vec_list = []
        for doc in batch_doc:
            sent_mat = self.word_layer(doc)  # doc's dim (num_sent, seq_len, word_dim)
            doc_vec_list.append(sent_mat)  # sent_mat's dim (num_sent, vec_dim)
        doc_vec = self.sent_layer(pack_sequence(doc_vec_list))
        output = self.softmax(self.output_layer(doc_vec))
        return output

class AttentionNet(nn.Module):
    def __init__(self, input_size, gru_hidden_size, gru_num_layers, context_vec_size):
        super(AttentionNet, self).__init__()

        self.input_size = input_size
        self.gru_hidden_size = gru_hidden_size
        self.gru_num_layers = gru_num_layers
        self.context_vec_size = context_vec_size

        # Encoder
        self.gru = nn.GRU(input_size=input_size,
                          hidden_size=gru_hidden_size,
                          num_layers=gru_num_layers,
                          batch_first=True,
                          bidirectional=True)
        # Attention
        self.fc = nn.Linear(2 * gru_hidden_size, context_vec_size)
        self.tanh = nn.Tanh()
        self.softmax = nn.Softmax(dim=1)
        # context vector
        self.context_vec = nn.Parameter(torch.Tensor(context_vec_size, 1))
        self.context_vec.data.uniform_(-0.1, 0.1)

    def forward(self, inputs):
        # GRU part
        h_t, hidden = self.gru(inputs)  # inputs's dim (batch_size, seq_len,  word_dim)
        u = self.tanh(self.fc(h_t))
        # Attention part
        alpha = self.softmax(torch.matmul(u, self.context_vec))  # u's dim (batch_size, seq_len, context_vec_size)
        output = torch.bmm(torch.transpose(h_t, 1, 2), alpha)  # alpha's dim (batch_size, seq_len, 1)
        return torch.squeeze(output, dim=2)  # output's dim (batch_size, 2*hidden_size, 1)

model = HANCLS(init_embed=embedding, num_cls=len(label_vocab))
## 3. declare the loss, metric, and optimizer
loss = CrossEntropyLoss(pred=C.OUTPUT, target=C.TARGET)
metric = AccuracyMetric(pred=C.OUTPUT, target=C.TARGET)
optimizer = SGD([param for param in model.parameters() if param.requires_grad == True], lr=0.001, momentum=0.9, weight_decay=0)

callbacks = []
callbacks.append(LRScheduler(CosineAnnealingLR(optimizer, 5)))

device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

print(device)

for ds in data_bundle.datasets.values():
    ds.apply_field(len, C.INPUT, C.INPUT_LEN)
    ds.set_input(C.INPUT, C.INPUT_LEN)
    ds.set_target(C.TARGET)

np.warnings.filterwarnings('ignore', category=np.VisibleDeprecationWarning)   

## 4. define the train method
def train(model, data_bundle, loss, metrics, optimizer, num_epochs=1):
    trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss,
                      metrics=[metrics], dev_data=datainfo.datasets['test'], device=device,
                      check_code_level=-1, batch_size=4096, callbacks=callbacks,
                      n_epochs=num_epochs)

    trainer.train()

train(model, data_bundle, loss, metric, optimizer)
"""
---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
<ipython-input-10-95c83efcfd68> in <module>
      5     trainer.train()
      6 
----> 7 train(model, data_bundle, loss, metric, optimizer)

<ipython-input-10-95c83efcfd68> in train(model, data_bundle, loss, metrics, optimizer, num_epochs)
      3     trainer = Trainer(datainfo.datasets['train'], model, optimizer=optimizer, loss=loss, metrics=[metrics], dev_data=datainfo.datasets['test'], device=device,check_code_level=-1, batch_size=4096, callbacks=callbacks, n_epochs=num_epochs)
      4 
----> 5     trainer.train()
      6 
      7 train(model, data_bundle, loss, metric, optimizer)

~/anaconda3/envs/venv/lib/python3.6/site-packages/fastNLP/core/trainer.py in train(self, load_best_model, on_exception)
    620                 if on_exception == 'auto':
    621                     if not isinstance(e, (CallbackException, KeyboardInterrupt)):
--> 622                         raise e
    623                 elif on_exception == 'raise':
    624                     raise e

~/anaconda3/envs/venv/lib/python3.6/site-packages/fastNLP/core/trainer.py in train(self, load_best_model, on_exception)
    613             try:
    614                 self.callback_manager.on_train_begin()
--> 615                 self._train()
    616                 self.callback_manager.on_train_end()
    617 

~/anaconda3/envs/venv/lib/python3.6/site-packages/fastNLP/core/trainer.py in _train(self)
    711                 # ================= mini-batch end ==================== #
    712                 if self.validate_every<0 and self.dev_data is not None:  # 在epoch结束之后的evaluate
--> 713                     eval_res = self._do_validation(epoch=epoch, step=self.step)
    714                     eval_str = "Evaluation on dev at Epoch {}/{}. Step:{}/{}: ".format(epoch, self.n_epochs, self.step,
    715                                                                                        self.n_steps)

~/anaconda3/envs/venv/lib/python3.6/site-packages/fastNLP/core/trainer.py in _do_validation(self, epoch, step)
    726     def _do_validation(self, epoch, step):
    727         self.callback_manager.on_valid_begin()
--> 728         res = self.tester.test()
    729 
    730         is_better_eval = False

~/anaconda3/envs/venv/lib/python3.6/site-packages/fastNLP/core/tester.py in test(self)
    173                     for batch_x, batch_y in data_iterator:
    174                         _move_dict_value_to_device(batch_x, batch_y, device=self._model_device)
--> 175                         pred_dict = self._data_forward(self._predict_func, batch_x)
    176                         if not isinstance(pred_dict, dict):
    177                             raise TypeError(f"The return value of {_get_func_signature(self._predict_func)} "

~/anaconda3/envs/venv/lib/python3.6/site-packages/fastNLP/core/tester.py in _data_forward(self, func, x)
    222         x = _build_args(func, **x)
--> 223         y = self._predict_func_wrapper(**x)
    224         return y
    225 

<ipython-input-6-5cd1f54b8245> in predict(self, input_sents)
     40 
     41     def predict(self, input_sents):
---> 42         x = self.forward(input_sents)[C.OUTPUT]
     43         return {C.OUTPUT: torch.argmax(x, 1)}
     44 

<ipython-input-6-5cd1f54b8245> in forward(self, input_sents)
     31         # input_sents [B, num_sents, seq-len] dtype long
     32         # target
---> 33         B, num_sents, seq_len = input_sents.size()
     34         input_sents = input_sents.view(-1, seq_len)  # flat
     35         words_embed = self.embed(input_sents)  # should be [B*num-sent, seqlen , word-dim]

TypeError: 'int' object is not callable

"""