fferflo / einx

Universal Tensor Operations in Einstein-Inspired Notation for Python.
https://einx.readthedocs.io/en/stable/
MIT License

einx.where auto moving float / int type inputs to correct device #7

Closed: lucidrains closed this issue 4 months ago

lucidrains commented 5 months ago

Hey Florian! Thank you again for einx; it continues to be very useful :pray:

So I ran into a small gotcha recently, best explained with a short code snippet:

import torch
import einx

attn_similarities = torch.randn((2, 8, 1024, 1024)).cuda()
key_padding_mask = torch.randint(0, 2, (2, 1024)).bool().cuda()  # upper bound is exclusive, so 0/1 gives a mix of True/False

# pytorch where

pytorch_masked = torch.where(
    einx.rearrange('b j -> b 1 1 j', key_padding_mask),
    attn_similarities,
    0.
)

# einx where

masked = einx.where(
    'b j, b h i j, -> b h i j',
    key_padding_mask,
    attn_similarities,
    0.
)

# fails since 0. does not get moved to the correct device

There are workarounds of course (one is sketched below), so I'll leave it up to you whether you'd like to align this with PyTorch's behavior!
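
A minimal sketch of such a workaround, reusing the imports and tensors from the snippet above (the second variant assumes einx accepts a pre-built tensor as the filler argument):

# stay in plain PyTorch and drop the scalar filler entirely
masked = attn_similarities.masked_fill(
    ~einx.rearrange('b j -> b 1 1 j', key_padding_mask),
    0.
)

# or pass the filler as a tensor that already lives on the right device
fill = torch.zeros((), device=attn_similarities.device)
masked = einx.where('b j, b h i j, -> b h i j', key_padding_mask, attn_similarities, fill)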

fferflo commented 5 months ago

Hey Phil, it looks like PyTorch auto-moves regular Python floats/ints to the GPU, but not once they have been converted to a torch.Tensor (which is what einx does to get a tensor object for the 0.). I tried a couple of ways to fix this, but am not happy with how they turned out.
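
A minimal sketch of that behavior (assuming a CUDA device is available; the names are illustrative):

import torch

mask = torch.randint(0, 2, (4,), device='cuda').bool()
values = torch.randn(4, device='cuda')

# a plain Python scalar is promoted by torch.where with no device issue
torch.where(mask, values, 0.)

# a CPU tensor filler is not moved automatically and raises
# "Expected all tensors to be on the same device ..."
# (roughly what einx ran into after wrapping the 0. in a tensor)
torch.where(mask, values, torch.zeros(1))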

I'll keep the issue open for now, and will let you know when I have a solution.

lucidrains commented 5 months ago

@fferflo yup, not a high priority issue!

fferflo commented 4 months ago

Just under lightspeed this time, but floats/ints should be handled correctly as of v0.2.1. In your example, the 0. argument is now forwarded directly to torch.where:

>>> print(einx.where('b j, b h i j, -> b h i j', key_padding_mask, attn_similarities, 0., graph=True))
import torch
def op0(i0, i1, i2):
    x0 = torch.reshape(i0, (2, 1, 1, 1024))
    x1 = torch.where(x0, i1, i2)
    return x1
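
As a quick sanity check (a sketch assuming einx >= 0.2.1 and a CUDA device), the two formulations from the original snippet should now agree:

import torch
import einx

attn_similarities = torch.randn((2, 8, 1024, 1024)).cuda()
key_padding_mask = torch.randint(0, 2, (2, 1024)).bool().cuda()

pytorch_masked = torch.where(
    einx.rearrange('b j -> b 1 1 j', key_padding_mask),
    attn_similarities,
    0.
)

einx_masked = einx.where(
    'b j, b h i j, -> b h i j',
    key_padding_mask,
    attn_similarities,
    0.
)

assert torch.equal(pytorch_masked, einx_masked)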

Thanks for bringing it up!

lucidrains commented 4 months ago

@fferflo works perfectly, thank you Florian! 🙏