I am running this package within a transformer model that is being JIT-traced with torch.jit.trace_module. The unit-test GitHub workflow is throwing the following error:
```
E beartype.roar.BeartypeCallHintParamViolation: Method
  rotary_embedding_torch.rotary_embedding_torch.RotaryEmbedding.forward() parameter seq_len="tensor(102)"
  violates type hint typing.Optional[int], as <protocol "torch.Tensor"> "tensor(102)" not int or <class "builtins.NoneType">.
```
I don't see how a tensor-typed seq_len is even possible here, given that it is derived from t.shape. Is there any reason why this would occur?
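For reference, here is a minimal sketch of what I suspect is happening (the function `f` and the tensor shapes below are just placeholders, not the package's actual code; the key point is that shape access seems to be recorded as a tensor op while tracing):

```python
import torch

def f(t):
    # In eager mode this is a plain Python int.
    # Under torch.jit.trace (at least in the environment where the test
    # fails), the shape access appears to be recorded as a tensor op,
    # so seq_len shows up as a 0-dim Tensor such as tensor(102),
    # which then trips beartype's Optional[int] check.
    seq_len = t.shape[-2]
    print(type(seq_len), seq_len)
    return t * 2

x = torch.randn(1, 102, 64)
f(x)                               # <class 'int'> 102
traced = torch.jit.trace(f, (x,))  # seq_len prints as a Tensor while tracing
```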