Renovamen / Text-Classification

PyTorch implementation of some text classification models (HAN, fastText, BiLSTM-Attention, TextCNN, Transformer) | 文本分类
MIT License

python train.py --config ./configs/transformer.yaml #2

Closed · VinACE closed this 3 years ago

VinACE commented 3 years ago

```
Loading embeddings: 400000it [00:55, 7234.47it/s]
Saving vectors to /home/a/Text-Classification/data/outputs/ag_news/sents/glove.6B.300d.txt.pth.tar
Traceback (most recent call last):
  File "train.py", line 73, in <module>
    trainer = set_trainer(config)
  File "train.py", line 28, in set_trainer
    model = models.setup(
  File "/home/a/Text-Classification/models/__init__.py", line 83, in setup
    model = Transformer(
  File "/home/a/Text-Classification/models/Transformer/transformer.py", line 53, in __init__
    self.encoder = EncoderLayer(d_model, n_heads, hidden_size, dropout)
  File "/home/a/Text-Classification/models/Transformer/encoder_layer.py", line 22, in __init__
    self.attention = MultiHeadAttention(d_model, n_heads, dropout)
  File "/home/a/Text-Classification/models/Transformer/attention.py", line 56, in __init__
    assert d_model % n_heads == 0
AssertionError
```
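For context, a stripped-down sketch (not the repo's exact code) of the constructor that raises here: multi-head attention splits the `d_model` feature dimension into `n_heads` equal slices, so `d_model` must be divisible by `n_heads`, and the constructor asserts this up front.

```python
class MultiHeadAttention:
    """Simplified sketch of the constructor in the traceback above.

    d_model is split into n_heads equal slices, one per attention head,
    so it must be divisible by n_heads.
    """

    def __init__(self, d_model: int, n_heads: int) -> None:
        assert d_model % n_heads == 0  # the line that fails in the traceback
        self.head_dim = d_model // n_heads  # feature size per head


m = MultiHeadAttention(300, 6)  # OK: 300 / 6 = 50 dims per head
# MultiHeadAttention(300, 8) would raise AssertionError, since 300 % 8 != 0
```

With 300-dimensional GloVe embeddings, any head count that does not divide 300 evenly triggers exactly this `AssertionError`.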

VinACE commented 3 years ago

My class labels were encoded starting from 0, but this code seems to expect class labels starting from 1. Not sure if the error is related to the class labels starting from 0 instead of 1.

Renovamen commented 3 years ago

Sorry for the late reply. This error occurred because emb_size % n_heads was not 0 in configs/ag_news/transformer.yaml, and has been fixed in https://github.com/Renovamen/Text-Classification/commit/0ab407beca8271f7bda98585a8b5b0e0f2de87f7. Thanks!
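When editing the YAML config yourself, a quick way to avoid this assertion is to pick n_heads from the divisors of emb_size. A hypothetical helper (`valid_head_counts` is not part of the repo):

```python
def valid_head_counts(emb_size: int, max_heads: int = 16) -> list[int]:
    """Head counts that divide emb_size evenly (i.e. won't trip the assert)."""
    return [n for n in range(1, max_heads + 1) if emb_size % n == 0]


# For 300-d GloVe embeddings:
print(valid_head_counts(300))  # [1, 2, 3, 4, 5, 6, 10, 12, 15]
```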

Also, it seems that the way you encoded the labels (starting from 0) is fine for this code. Label ids in the dataset start from 1; however, when preprocessing the dataset, this code maps each original label_id to label_id - 1:

https://github.com/Renovamen/Text-Classification/blob/0ab407beca8271f7bda98585a8b5b0e0f2de87f7/datasets/preprocess/sentence.py#L66

So in the preprocessed data, label ids start from 0.
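In other words, the shift described above amounts to (a minimal illustration, not the repo's exact code; AG News raw labels are 1-based):

```python
raw_labels = [1, 2, 3, 4]              # label ids as they appear in the raw dataset
labels = [l - 1 for l in raw_labels]   # 0-based ids used after preprocessing
print(labels)  # [0, 1, 2, 3]
```

This matches what losses like `nn.CrossEntropyLoss` expect, which assume class indices in `[0, num_classes)`.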