import torch
# TorchScript executor/fuser configuration for this repro.
# These private torch._C setters take no handle, so they change global JIT
# state. They are called before any function is scripted further down.
# NOTE(review): these are private, unstable APIs and may be missing on newer
# PyTorch builds — confirm against the installed version.
# Enable nvfuser and disable the TensorExpr fuser, presumably so that any
# fusion in this repro goes through nvfuser.
torch._C._jit_set_nvfuser_enabled(True)
torch._C._jit_set_texpr_fuser_enabled(False)
# Turn on the profiling executor and profiling mode. The scripted function is
# later called several times, presumably so the profiled/optimized graph gets
# a chance to run.
torch._C._jit_set_profiling_executor(True)
torch._C._jit_set_profiling_mode(True)
# Inputs mix rank, device and dtype:
#   x: 0-dim float16 on cuda:0
#   y: 0-dim float32 on the default device (CPU unless overridden)
#   z: 1-D, one-element float16 on cuda:0
#   q: 0-dim float32 (torch.float) on cuda:0. Not used in case 1; presumably
#      used by later cases outside this excerpt.
# Requires a CUDA device at index 0.
x = torch.tensor(0, device='cuda:0', dtype=torch.float16)
y = torch.tensor(0, dtype=torch.float32)
z = torch.tensor([0], device='cuda:0', dtype=torch.float16)
q = torch.tensor(0, device='cuda:0', dtype=torch.float)
# Case 1: reference result in eager mode. The expression is the same one the
# scripted function `f` returns, so the printed dtypes can be compared directly.
print("========== case 1 ==========")
# x (0-dim fp16, CUDA) * y (0-dim fp32, CPU) + z (1-D fp16, CUDA).
# NOTE(review): under PyTorch's documented type-promotion rules, zero-dim
# operands do not up-cast a dimensioned operand of the same category. This is
# therefore expected to print torch.float16 — confirm on the target build.
o = x * y + z
print("Eager:", o.dtype)
@torch.jit.script
def f(x: torch.Tensor, y: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Return ``x * y + z``, compiled with TorchScript.

    This is the scripted counterpart of the eager expression in case 1. Its
    result dtype is printed after each call so it can be compared with eager
    mode.

    Args:
        x: First multiplicand.
        y: Second multiplicand.
        z: Addend, broadcast against ``x * y``.

    Returns:
        The tensor ``x * y + z``, with dtype decided by type promotion.
    """
    return x * y + z
# Call the scripted function three times and report the result dtype after
# each call.
# NOTE(review): the calls are presumably repeated because the profiling
# executor enabled above can only specialize/fuse the graph after profiling
# runs, so the dtype may change between runs — confirm.
for label in ("JIT 1:", "JIT 2:", "JIT 3:"):
    o = f(x, y, z)
    print(label, o.dtype)
# 🐛 Describe the bug
# I am seeing ... (the rest of this sentence is missing from the original report)