eigenvectorBazuz closed this 1 month ago
This happens because the code converts QK^T to fp16. I will submit a PR later.
Merged @Uwwal's changes! Let me know if this fixes your issue @eigenvectorBazuz, otherwise I'll take a look at it more deeply later this week.
I tried to run the given small examples in a new conda env with triton 3.0.0 and torch and encountered two problems:
(1)
TypeError: apply() takes no keyword arguments
Which I solved by passing the arguments positionally (i.e. mask instead of mask=mask, etc.). But then I ran into a thornier problem:
(2)
This I am at a loss how to solve on my own...
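For anyone else hitting error (1): as far as I can tell it comes from `torch.autograd.Function.apply`, which only accepts positional arguments. Below is a minimal sketch of the workaround, not the library's actual code. `MaskedScale` and `masked_scale` are hypothetical names made up for illustration; the idea is a thin wrapper that accepts keywords and forwards them positionally.

```python
import torch


class MaskedScale(torch.autograd.Function):
    """Hypothetical stand-in for a custom autograd Function (not from the library)."""

    @staticmethod
    def forward(ctx, x, mask, scale):
        ctx.save_for_backward(mask)
        ctx.scale = scale
        return x * mask * scale

    @staticmethod
    def backward(ctx, grad_out):
        (mask,) = ctx.saved_tensors
        # One return value per forward input; mask and scale get no gradient.
        return grad_out * mask * ctx.scale, None, None


def masked_scale(x, mask=None, scale=1.0):
    """Keyword-friendly wrapper: callers may use names, apply() sees positions."""
    if mask is None:
        mask = torch.ones_like(x)
    # Calling MaskedScale.apply(x, mask=mask, scale=scale) would raise the TypeError.
    return MaskedScale.apply(x, mask, scale)


x = torch.tensor([1.0, 2.0, 3.0], requires_grad=True)
mask = torch.tensor([1.0, 0.0, 1.0])
y = masked_scale(x, mask=mask, scale=2.0)
y.sum().backward()
```

With a wrapper like this, call sites can keep `mask=mask` while the `apply()` call itself stays positional.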