AnubhavGupta3377 / Text-Classification-Models-Pytorch

Implementation of State-of-the-art Text Classification Models in Pytorch

TypeError: forward() missing 1 required positional argument: 'mask' #2

Closed: BaiStone2017 closed this issue 5 years ago

BaiStone2017 commented 5 years ago

When I run train.py with Model_Transformer, I get the following error: TypeError: forward() missing 1 required positional argument: 'mask'

Can you give me the reason?

AnubhavGupta3377 commented 5 years ago

Hi,

You need to make the "mask" argument of "forward()" in "Model_Transformer/encoder.py" optional.

I'll update the code soon.

For now, you can just go to Model_Transformer/encoder.py, make the following changes, and it should work:

class Encoder(nn.Module):
    '''
    Transformer Encoder

    It is a stack of N layers.
    '''
    def __init__(self, layer, N):
        super(Encoder, self).__init__()
        self.layers = clones(layer, N)
        self.norm = LayerNorm(layer.size)

    def forward(self, x, mask=None):
        for layer in self.layers:
            x = layer(x, mask)
        return self.norm(x)

class EncoderLayer(nn.Module):
    '''
    An encoder layer

    Made up of self-attention and a feed forward layer.
    Each of these sublayers have residual and layer norm, implemented by SublayerOutput.
    '''
    def __init__(self, size, self_attn, feed_forward, dropout):
        super(EncoderLayer, self).__init__()
        self.self_attn = self_attn
        self.feed_forward = feed_forward
        self.sublayer_output = clones(SublayerOutput(size, dropout), 2)
        self.size = size

    def forward(self, x, mask=None):
        "Transformer Encoder"
        x = self.sublayer_output[0](x, lambda x: self.self_attn(x, x, x, mask)) # Encoder self-attention
        return self.sublayer_output[1](x, self.feed_forward)

That code was originally written for translation, and I later updated it to work for classification. For classification, masking is not required, so you just need to ignore that argument. Please let me know if that helps.
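
To illustrate the idea, here is a minimal, self-contained sketch (hypothetical code, not this repo's encoder; the TinyEncoder class and tensor shapes are made up for illustration). Once 'mask' defaults to None, calling the encoder with a single argument, which is exactly what the classification model does, no longer raises the TypeError:

import torch
import torch.nn as nn

class TinyEncoder(nn.Module):
    def __init__(self, d_model=16):
        super(TinyEncoder, self).__init__()
        self.norm = nn.LayerNorm(d_model)

    def forward(self, x, mask=None):
        # 'mask' is ignored for classification; without '=None', encoder(x)
        # raises: TypeError: forward() missing 1 required positional argument: 'mask'
        return self.norm(x)

encoder = TinyEncoder()
embedded_sents = torch.randn(2, 5, 16)   # (batch_size, seq_len, d_model)
encoded_sents = encoder(embedded_sents)  # no mask passed, runs fine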

AnubhavGupta3377 commented 5 years ago

I've updated the code. Try running it now. It should work.

BaiStone2017 commented 5 years ago

The issue is as follows:

Traceback (most recent call last):
  File "train.py", line 40, in <module>
    train_loss, val_accuracy = model.run_epoch(dataset.train_iterator, dataset.val_iterator, i)
  File "/mnt/e/MachineLearningProjects/classification/Text-Classification-Models-Pytorch-master/Model_Transformer/model.py", line 76, in run_epoch
    y_pred = self.__call__(x)
  File "/mnt/e/linux/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/mnt/e/MachineLearningProjects/classification/Text-Classification-Models-Pytorch-master/Model_Transformer/model.py", line 41, in forward
    encoded_sents = self.encoder(embedded_sents)
  File "/mnt/e/linux/anaconda3/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
TypeError: forward() missing 1 required positional argument: 'mask'

Sorry for providing incomplete error information. I have solved this by changing line 39 of model.py to def forward(self, x, mask=None). It is not necessary to modify encoder.py. Thanks a lot!
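
For reference, a minimal sketch of that kind of model-level fix (a hypothetical simplification; the real model.py, its layer names such as src_embed, and its dimensions differ):

import torch
import torch.nn as nn

class TinyClassifier(nn.Module):
    def __init__(self, vocab_size=100, d_model=16, num_classes=2):
        super(TinyClassifier, self).__init__()
        self.src_embed = nn.Embedding(vocab_size, d_model)
        self.encoder = nn.Linear(d_model, d_model)    # placeholder for the Transformer encoder
        self.fc = nn.Linear(d_model, num_classes)

    def forward(self, x, mask=None):                  # optional mask, unused for classification
        embedded_sents = self.src_embed(x)
        encoded_sents = self.encoder(embedded_sents)
        return self.fc(encoded_sents.mean(dim=1))

model = TinyClassifier()
x = torch.randint(0, 100, (4, 10))                    # (batch_size, seq_len) of token ids
y_pred = model(x)                                     # runs without passing a mask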