Closed lucidrains closed 4 months ago
Hey Phil, it looks like PyTorch auto-moves regular floats/ints to the GPU, but not if one converts them to a torch.Tensor first (which is done in einx to get a tensor object for the 0.). I tried a couple of ways to fix this, but am not happy with how they turned out.
I'll keep the issue open for now, and will let you know when I have a solution.
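For illustration, the difference can be seen with a minimal sketch (the variable names and shapes here are made up, not from the issue): PyTorch promotes a plain Python float alongside the tensor inputs, whereas `torch.tensor(0.)` is created on the default device and stays there.

```python
import torch

mask = torch.tensor([True, False])
x = torch.tensor([1.0, 2.0])

# A plain Python float has no device of its own; torch.where promotes it
# alongside the tensor inputs, so this works on CPU and GPU alike:
out = torch.where(mask, x, 0.)
print(out)  # tensor([1., 0.])

# Converting the scalar first pins it to the default device (CPU here);
# with CUDA inputs this can trigger a device-mismatch error, depending on
# the PyTorch version:
scalar = torch.tensor(0.)
print(scalar.device)  # cpu
```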
@fferflo yup, not a high priority issue!
Just under lightspeed this time, but floats/ints should be handled correctly with v0.2.1. In your example, the 0. argument is now forwarded directly to torch.where:
```python
>>> print(einx.where('b j, b h i j, -> b h i j', key_padding_mask, attn_similarities, 0., graph=True))
import torch
def op0(i0, i1, i2):
    x0 = torch.reshape(i0, (2, 1, 1, 1024))
    x1 = torch.where(x0, i1, i2)
    return x1
```
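For reference, the traced op0 above is equivalent to the following plain-PyTorch computation. The batch/head/sequence sizes are assumptions inferred from the reshape to (2, 1, 1, 1024); the head count of 8 is purely illustrative.

```python
import torch

# Assumed shapes: batch=2, heads=8, sequence length=1024
key_padding_mask = torch.rand(2, 1024) > 0.5       # 'b j'
attn_similarities = torch.randn(2, 8, 1024, 1024)  # 'b h i j'

# op0: broadcast the mask over heads and query positions, then select,
# with the Python float 0. forwarded directly to torch.where
x0 = key_padding_mask.reshape(2, 1, 1, 1024)
out = torch.where(x0, attn_similarities, 0.)
print(out.shape)  # torch.Size([2, 8, 1024, 1024])
```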
Thanks for bringing it up!
@fferflo works perfectly, thank you Florian! 🙏
Hey Florian! Thank you again for einx; it continues to be very useful 🙏
So I ran into a small gotcha recently, best explained in a short code snippet.
There are workarounds of course, so I'll leave this up to you whether you'd like to make it more aligned with PyTorch's behavior!
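One such workaround (a hypothetical sketch, not part of the einx API) is to build the fill value as a tensor on the same device as the inputs before passing it in, so no implicit CPU tensor sneaks into a GPU computation:

```python
import torch

mask = torch.tensor([True, False])
x = torch.tensor([1.0, 2.0])

# Workaround: create the scalar tensor explicitly on the inputs' device
# instead of relying on implicit conversion of the Python float.
fill = torch.tensor(0., device=x.device)
out = torch.where(mask, x, fill)
print(out)  # tensor([1., 0.])
```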