Closed tanmay-bakshi closed 1 year ago
The input token IDs should be `long`, not `bfloat16`, when using the Triton attention implementation, as they're fed to an embedding layer.
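For context, here is a minimal PyTorch sketch (not the repo's actual code) of why the dtype matters: `nn.Embedding` performs an index lookup, so it requires integer indices and will raise a dtype error if it receives `bfloat16` token IDs.

```python
import torch
import torch.nn as nn

# Hypothetical sizes for illustration only.
vocab_size, d_model = 32, 8
embedding = nn.Embedding(vocab_size, d_model)

# Token IDs as produced by a tokenizer: int64 ("long").
token_ids = torch.randint(0, vocab_size, (2, 4))

# If an over-eager dtype cast turns the IDs into bfloat16...
bad_ids = token_ids.to(torch.bfloat16)

# ...cast them back to long before the embedding lookup.
fixed_ids = bad_ids.long()
hidden = embedding(fixed_ids)   # works; embedding(bad_ids) would raise
print(hidden.shape)             # torch.Size([2, 4, 8])
```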
Thanks for catching this, Tanmay! Merging your fix here, and also mirroring it to our HuggingFace repo.