Lightning-AI / lightning-thunder

Make PyTorch models up to 40% faster! Thunder is a source-to-source compiler for PyTorch. It enables using different hardware executors at once, across one or thousands of GPUs.
Apache License 2.0

`thunder.jit`ted `Tensor.masked_fill` and `Tensor.masked_fill_` return `torch.float32` tensor even when input is `torch.int64` #1083

Closed crcrpar closed 2 months ago

crcrpar commented 2 months ago

Note: If you have a model or program that is not supported yet but should be, please use the program coverage template.

🐛 Bug

As per title.

To Reproduce

```python
import torch
import thunder

if __name__ == "__main__":
    with torch.device("cuda"):
        a = torch.randint(0, 64, size=(384, 16), dtype=torch.int64)
        b = a.clone().detach()
        mask = torch.randn(size=(384, 16)).to(torch.bool)

    def f(a, mask):
        return a.masked_fill(mask, 0.0)

    expected = f(a, mask)
    jitted = thunder.jit(f)
    actual = jitted(a, mask)

    print(f"# `Tensor.masked_fill`: {expected.dtype = }, {actual.dtype = }")
```
```
$ python a.py
# `Tensor.masked_fill`: expected.dtype = torch.int64, actual.dtype = torch.float32
```
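Eager PyTorch keeps the input dtype because the scalar fill value is cast to the tensor's dtype before being written. A plausible (unverified) explanation for the jitted behavior is that the op gets lowered to an elementwise select in which the Python float `0.0` participates in ordinary type promotion with the `int64` tensor. The minimal sketch below illustrates that promotion path; the use of `torch.where` here is an assumption, not a confirmed trace of Thunder's decomposition:

```python
import torch

a = torch.randint(0, 64, size=(4, 4), dtype=torch.int64)
mask = torch.randn(size=(4, 4)) > 0

# Eager masked_fill casts the scalar to the input dtype: the result stays int64.
print(a.masked_fill(mask, 0.0).dtype)  # torch.int64

# If the op were instead lowered to an elementwise select with the raw Python
# float, promotion of a float scalar with an int64 tensor yields the default
# float dtype (assumed cause, for illustration only).
print(torch.where(mask, 0.0, a).dtype)  # torch.float32
```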

Expected behavior

The jitted function should return a tensor with the same dtype as the eager PyTorch result (`torch.int64` in this example).
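Until the dtype handling is fixed, one possible workaround, assuming the promotion described above is the cause, is to pass an integer fill value so that no float scalar enters the computation:

```python
def f(a, mask):
    # Hypothetical workaround: an int literal matches `a`'s dtype category,
    # so type promotion cannot produce a floating-point result.
    return a.masked_fill(mask, 0)
```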

Environment

Additional context