Cyh294 closed this issue 3 years ago.
Hi, PyTorch is telling you here that the mask (represented by the variable attention_mask) should be of type Byte, but is Bool. I encourage you to dig deeper into this error to understand why the mask is of type Bool; I can't help you with just this traceback. If you have more info, you can always post it here.
Thanks. I found the reason for the error: it had something to do with the Python version.
By the way, I wonder: what is the reason for changing softmax to sigmoid? Looking forward to your reply :)
I'm currently working on a regression problem. Softmax is usually applied to classification, so it's not a good fit here. I'm using sigmoid, but it could be replaced with any other activation function, such as ReLU or one of its variants. Feel free to experiment with the ones implemented in PyTorch; they should be easy to plug in!
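A small pure-Python sketch of why softmax fits classification but not regression: softmax couples its outputs into a distribution that always sums to 1, so changing one output moves all the others, while sigmoid and ReLU act on each value independently. This assumes the model predicts several independent regression targets; note also that sigmoid can only produce values in (0, 1), so targets outside that range would presumably need rescaling (an assumption, not something stated in the thread).

```python
import math


def softmax(xs):
    """Coupled outputs: each lies in (0, 1) and together they always sum to 1."""
    peak = max(xs)
    exps = [math.exp(x - peak) for x in xs]
    total = sum(exps)
    return [e / total for e in exps]


def sigmoid(x):
    """Squashes one value into (0, 1); split by sign so math.exp never overflows."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def relu(x):
    """Keeps non-negative values unchanged and clips negatives to 0 (unbounded above)."""
    return max(0.0, x)


outputs = [0.2, 1.5, -0.7]
print("softmax:", [round(v, 3) for v in softmax(outputs)])
print("sigmoid:", [round(sigmoid(v), 3) for v in outputs])
print("relu:   ", [relu(v) for v in outputs])

# Changing only the second output changes every softmax value,
# but leaves the sigmoid of the first output untouched.
bumped = [0.2, 3.0, -0.7]
print(softmax(outputs)[0] != softmax(bumped)[0])  # True
print(sigmoid(outputs[0]) == sigmoid(bumped[0]))  # True
```

In PyTorch the corresponding building blocks are `torch.sigmoid`, `torch.relu` and `torch.softmax`, so swapping one for another is a one-line change.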
Thanks for your code. When running 'training.ipynb', I got the following RuntimeError:
[Epoch 1/30]: 0%| | 0/5000 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "C:\Anaconda3\envs\pytorch\transformer-master\training.py", line 111, in
    netout = net(x.to(device))
  File "C:\Anaconda3\envs\py37\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Anaconda3\envs\pytorch\transformer-master\tst\transformer.py", line 131, in forward
    encoding = layer(encoding)
  File "C:\Anaconda3\envs\py37\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Anaconda3\envs\pytorch\transformer-master\tst\encoder.py", line 86, in forward
    x = self._selfAttention(query=x, key=x, value=x)
  File "C:\Anaconda3\envs\py37\lib\site-packages\torch\nn\modules\module.py", line 493, in __call__
    result = self.forward(*input, **kwargs)
  File "C:\Anaconda3\envs\pytorch\transformer-master\tst\multiHeadAttention.py", line 91, in forward
    self._scores = self._scores.masked_fill(attention_mask, float('-inf'))
RuntimeError: Expected object of scalar type Byte but got scalar type Bool for argument #2 'mask'
Can you tell me how to fix it?
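One plausible cause, stated as an assumption: PyTorch releases before 1.2 only accepted uint8 (Byte) masks in `masked_fill`, while 1.2 and later expect bool masks (uint8 masks were deprecated in 1.2). A mask built as bool would then trigger exactly this message on an older install. The hypothetical helper below maps a `torch.__version__` string to the mask dtype it expects, using only the standard library:

```python
def expected_mask_dtype(torch_version: str) -> str:
    """Hypothetical helper: mask dtype masked_fill expects for a torch.__version__ string.

    Assumption: PyTorch < 1.2 requires uint8 (Byte) masks; 1.2 and later expect bool.
    """
    core = torch_version.split("+")[0]  # drop local build tags such as "+cu117"
    major_str, minor_str = core.split(".")[:2]
    major, minor = int(major_str), int(minor_str)
    return "uint8" if (major, minor) < (1, 2) else "bool"


for v in ["1.1.0", "1.2.0", "1.13.1+cu117", "2.1.0"]:
    print(v, "->", expected_mask_dtype(v))
# 1.1.0 -> uint8
# 1.2.0 -> bool
# 1.13.1+cu117 -> bool
# 2.1.0 -> bool
```

With PyTorch installed you would pass `torch.__version__` instead. On a pre-1.2 install, either upgrading PyTorch or passing `attention_mask.byte()` to `masked_fill` should satisfy the dtype check. Note also that the traceback runs the project from `envs\pytorch` but loads torch from `envs\py37`, so it is worth checking which environment's interpreter the notebook actually uses.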