lucidrains / routing-transformer

Fully featured implementation of Routing Transformer
MIT License

receives_context causes tensor mismatch error #25

Closed. WeForgot closed this issue 3 years ago.

WeForgot commented 3 years ago

Hi all,

I am sorry if this is a dumb question, but I am running into an issue that I can't seem to solve. My model is a generative encoder/decoder model where the encoder is a vision transformer and the decoder is a routing transformer (this repo 😄). The output is continuous-valued, so I cannot use the autoregressive wrapper.

For the longest time I used this without passing in "receives_context", which obviously was silly and kind of defeated the whole point of having the ViT head. When I use the flag, though, I get the error below.

Traceback (most recent call last):
  File "main.py", line 327, in <module>
    main(args)
  File "main.py", line 196, in main
    layer_loss, color_loss, position_loss, aux_loss = model(feature, pad_label, mask=pad_mask, use_activations=use_activations)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\jnew2\source\repos\NSA\model\model.py", line 205, in forward
    x, aux_loss = self.decoder(y, context=context, input_mask=feature_mask)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 645, in forward
    x, loss = self.layers(x, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\reversible.py", line 171, in forward
    res, loss = cast_return(f(x, **f_args))
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 134, in forward
    return self.fn(x, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 558, in forward
    global_out, loss = self.global_attn(q, k, v, query_mask = input_mask, key_mask = context_mask)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\torch\nn\modules\module.py", line 1015, in _call_impl
    return forward_call(*input, **kwargs)
  File "c:\Users\jnew2\anaconda3\envs\NSA\lib\site-packages\routing_transformer\routing_transformer.py", line 374, in forward
    dists, aux_loss = self.kmeans(torch.cat((q, k), dim=2), update_kmeans)
RuntimeError: Tensors must have same number of dimensions: got 3 and 4

I ran through the code and I can definitely see where it is breaking, but I really have no idea where to even start with fixing it. For what it is worth, the dims of everything are consistent:

x = torch.Size([8, 230, 64])
context = torch.Size([8, 64])
input_mask = torch.Size([8, 230])

and the parameters I am initializing the routing transformer with are:

RoutingTransformer(dim = 64, depth = 2, max_seq_len = 256, heads = 16, ff_glu = True, use_scale_norm = True, causal = True, receives_context=True)
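
Putting those together, a stripped-down version of the failing call looks roughly like this (the tensors are random stand-ins; only the shapes and the RoutingTransformer arguments match my actual setup):

import torch
from routing_transformer import RoutingTransformer

decoder = RoutingTransformer(dim = 64, depth = 2, max_seq_len = 256, heads = 16,
                             ff_glu = True, use_scale_norm = True, causal = True,
                             receives_context = True)

y = torch.randn(8, 230, 64)           # continuous-valued decoder input
context = torch.randn(8, 64)          # pooled ViT output, one vector per image
pad_mask = torch.ones(8, 230).bool()  # padding mask for the decoder input

# this call hits the RuntimeError shown in the traceback above
x, aux_loss = decoder(y, context = context, input_mask = pad_mask)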

Once again, this is probably largely an issue of me misunderstanding the code, but this setup works with other transformer architectures and I am not sure where to go from here.

Cheers!

WeForgot commented 3 years ago

I hate leaving issues and then almost immediately going "oh wait...". Turns out I forgot that the ViT library I was using returns only the pooled class token rather than the full hidden representations, which is why my context was 2-D. I switched to one that returns the full token sequence and we are in business. Sorry about that, closing the issue.
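
In case anyone else hits this: the decoder just needs the context to be a sequence, i.e. (batch, context_len, dim) rather than (batch, dim). A minimal sketch of the working shapes, with random stand-in tensors and an illustrative number of patch tokens (only the RoutingTransformer arguments are my real ones):

import torch
from routing_transformer import RoutingTransformer

decoder = RoutingTransformer(dim = 64, depth = 2, max_seq_len = 256, heads = 16,
                             ff_glu = True, use_scale_norm = True, causal = True,
                             receives_context = True)

y = torch.randn(8, 230, 64)
pad_mask = torch.ones(8, 230).bool()

# full per-patch hidden states from the ViT instead of the pooled class token
# (65 = 64 patches + class token here, purely illustrative)
context = torch.randn(8, 65, 64)

x, aux_loss = decoder(y, context = context, input_mask = pad_mask)  # no shape error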