lucidrains / rotary-embedding-torch

Implementation of Rotary Embeddings, from the RoFormer paper, in PyTorch
MIT License

Error caused by tensor-type seq_len #16

Closed: cmunna0052 closed this issue 11 months ago

cmunna0052 commented 11 months ago

I am running this package inside a transformer model that is being JIT-traced with torch.jit.trace_module. The unit-test GitHub workflow throws the error:

```
E   beartype.roar.BeartypeCallHintParamViolation: Method
rotary_embedding_torch.rotary_embedding_torch.RotaryEmbedding.forward() parameter seq_len="tensor(102)"
violates type hint typing.Optional[int], as <protocol "torch.Tensor"> "tensor(102)" not int or <class "builtins.NoneType">.
```

I don't see how a tensor-type seq_len is even possible here, given that it is derived from t.shape. Is there any reason why this would occur?
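A minimal sketch (not from this repo) of the likely cause: under torch.jit.trace, elements of tensor.shape are recorded in the graph as 0-dim Tensors rather than Python ints, so a seq_len computed from t.shape arrives at forward() as tensor(102) and trips beartype's Optional[int] check. The function name shape_probe below is hypothetical, for illustration only.

```python
import torch

def shape_probe(t):
    seq_len = t.shape[-2]   # a plain int in eager mode,
    print(type(seq_len))    # but a 0-dim Tensor while tracing
    return t + 0

x = torch.randn(1, 102, 64)
shape_probe(x)                   # prints <class 'int'>
torch.jit.trace(shape_probe, x)  # prints <class 'torch.Tensor'>
```

A caller-side workaround, under that assumption, would be to pass seq_len explicitly as a Python int (or not at all); the fix below is on the library side.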

lucidrains commented 11 months ago

@cmunna0052 ahh, i'll just remove the runtime validation
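For context, a hypothetical sketch of what removing the runtime validation could look like: dropping the @beartype decorator leaves the Optional[int] annotation as documentation only, so a traced Tensor seq_len no longer raises. The class body here is a placeholder, not the library's actual code.

```python
from typing import Optional
import torch
from torch import nn

class RotaryEmbedding(nn.Module):
    # @beartype  <- removed: the Optional[int] hint below is now
    # advisory only, so a 0-dim Tensor seq_len produced by tracing
    # no longer raises BeartypeCallHintParamViolation
    def forward(self, t: torch.Tensor, seq_len: Optional[int] = None) -> torch.Tensor:
        if seq_len is None:
            seq_len = t.shape[-2]  # works for both int and traced Tensor
        # placeholder body; the real module applies rotary embeddings here
        return t
```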