lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

When running the RoutingTransformerLM's example, there is an error that the tensor dimension does not match #15

Closed JunZhan2000 closed 3 years ago

JunZhan2000 commented 3 years ago

Hello, thanks for your work. When I ran the RoutingTransformerLM example, I encountered a tensor dimension mismatch error:

Traceback (most recent call last):
  File "/data/zj/code/routing-transformer/simple-language-model.py", line 39, in <module>
    y, aux_loss = model(x, input_mask = input_mask) # (1, 8192, 20000)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/routing_transformer/routing_transformer.py", line 624, in forward
    x = self.norm(x)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/modules/normalization.py", line 153, in forward
    input, self.normalized_shape, self.weight, self.bias, self.eps)
  File "/usr/local/lib/python3.6/dist-packages/torch/nn/functional.py", line 1956, in layer_norm
    torch.backends.cudnn.enabled)
RuntimeError: Given normalized_shape=[512], expected input with shape [*, 512], but got input of size[1, 8192, 128]

I skimmed the source code, and there seems to be a bug in the RoutingTransformerLM class in routing_transformer.py. With the parameters given in the example, dim is 512, emb_dim is 128, and return_embeddings defaults to False. After the routing-transformer module, the last dimension of x is emb_dim (128; the traceback shows x with size [1, 8192, 128]), but the norm module is built for a last dimension of dim (512), so the norm call raises the error. I don't yet understand what the two parameters emb_dim and dim are each for. I will keep reading the code, and I hope you can look into this issue.
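For anyone hitting the same message, here is a minimal sketch of the shape rule behind it: layer normalization requires the trailing dimensions of the input to equal normalized_shape. The helper `layer_norm_shape_ok` is hypothetical, written only for illustration; it is not part of PyTorch or routing-transformer. The numbers are taken from the example and traceback in this issue.

```python
def layer_norm_shape_ok(input_shape, normalized_shape):
    """Hypothetical helper: True if the trailing dims of input_shape
    equal normalized_shape, which is the condition layer norm checks."""
    n = len(normalized_shape)
    return (
        len(input_shape) >= n
        and tuple(input_shape[-n:]) == tuple(normalized_shape)
    )


dim, emb_dim = 512, 128           # values from the example in this issue
x_shape = (1, 8192, emb_dim)      # input size reported in the traceback

# A norm built for `dim` rejects x, whose last dimension is `emb_dim`.
mismatch = layer_norm_shape_ok(x_shape, (dim,))

# A norm built for `emb_dim` would accept the same x.
match = layer_norm_shape_ok(x_shape, (emb_dim,))

print(mismatch, match)  # False True
```

So the error goes away only when the norm's size and the last dimension of x agree, for instance when emb_dim equals dim.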

lucidrains commented 3 years ago

@guokr233 oops, you are right! fixed in the latest https://github.com/lucidrains/routing-transformer/releases/tag/1.4.0