pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

grid sampler op needs to register fp32 autocast #7305

Open lingzhi98 opened 2 months ago

lingzhi98 commented 2 months ago

🐛 Bug

grid_sampler cannot run under automatic mixed precision (autocast) mode.

Steps to reproduce the behavior:

import torch
import torch.nn.functional as F
import numpy as np
import torch_xla
import torch_xla.core.xla_model as xm

xla_device = xm.xla_device()

sz = 5
input_arr = torch.from_numpy(np.arange(sz * sz).reshape(1, 1, sz, sz)).to(xla_device, dtype=torch.bfloat16)
indices = torch.from_numpy(np.array([-1, -1, -0.5, -0.5, 0,0, 0.5, 0.5, 1,1]).reshape(1, 1, 5, 2)).to(xla_device, dtype=torch.bfloat16)

with torch.amp.autocast("xla", dtype=torch.bfloat16):
  out = F.grid_sample(input_arr, indices)
xm.mark_step()
print(input_arr)
print(out)

RuntimeError: grid_sampler_2d_cpu not implemented for BFloat16

Expected behavior

Autocast should cast the inputs of grid_sampler to fp32.

Environment

Additional context

torch.cuda sets the autocast mode of grid_sampler to promote, but torch_xla cannot do the same. Because torch_xla has no lowering for grid_sampler_2d/grid_sampler_3d, the op falls back to the CPU implementation, which does not support bfloat16. Maybe we should set the autocast mode to fp32 first, and change it to promote once the lowering is ready. By the way, grid_sampler itself needs no lowering, because it does nothing but dispatch to grid_sampler_2d, grid_sampler_3d and cudnn_grid_sampler. https://github.com/pytorch/pytorch/blob/main/aten/src/ATen/native/GridSampler.cpp#L1046
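Until an fp32 autocast policy is registered, roughly the same effect can be had by upcasting both tensors by hand before the call. The sketch below runs on plain CPU tensors so that it stays self-contained. `grid_sample_fp32` is a hypothetical helper name, not a torch or torch_xla API, and on an XLA device the call would presumably still fall back to the CPU kernel, just with a dtype that kernel supports.

```python
import torch
import torch.nn.functional as F


def grid_sample_fp32(input_arr, grid, **kwargs):
    """Hypothetical helper: run F.grid_sample with float32 inputs."""
    # Upcast manually, mimicking what an fp32 autocast policy would do.
    return F.grid_sample(input_arr.float(), grid.float(), **kwargs)


sz = 5
# Same data as the reproducer above, built directly in torch on CPU.
input_arr = torch.arange(sz * sz).reshape(1, 1, sz, sz).to(torch.bfloat16)
grid = torch.tensor([-1, -1, -0.5, -0.5, 0, 0, 0.5, 0.5, 1, 1])
grid = grid.reshape(1, 1, 5, 2).to(torch.bfloat16)

out = grid_sample_fp32(input_arr, grid, align_corners=False)
print(out.dtype, out.shape)
```

The output stays in float32. Whether to cast it back to bfloat16 afterwards is left to the caller.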

JackCaoG commented 2 months ago

Yeah, I confirmed that we don't have a lowering for grid_sampler_2d; the metrics report shows:

Counter: aten::grid_sampler_2d
  Value: 1

I guess the solution here would be to actually lower this op.
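For anyone wanting to run the same check: counters whose names start with `aten::` mark ops that were executed through the CPU fallback. A minimal sketch of filtering them out of a list of counter names follows. `fallback_ops` is a hypothetical helper, and the sample list is hand-written for illustration rather than taken from a real run.

```python
def fallback_ops(counter_names):
    """Return the counter names that indicate a CPU fallback (aten:: prefix)."""
    return sorted(name for name in counter_names if name.startswith("aten::"))


# On a real setup the names would come from torch_xla:
#   import torch_xla.debug.metrics as met
#   fallback_ops(met.counter_names())
# Here an illustrative, hand-written list stands in for that call.
sample = ["CreateXlaTensor", "aten::grid_sampler_2d", "xla::add"]
print(fallback_ops(sample))
```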

JackCaoG commented 2 months ago

Also found https://github.com/pytorch/xla/issues/6581

lingzhi98 commented 2 months ago

Lowering grid_sampler_2d is not enough; we should also lower grid_sampler_3d to fully support the grid sampler op. That's why I suggest setting the autocast mode to fp32 first. Or is there a plan to support grid_sampler_3d? I haven't found one so far.

ManfeiBai commented 2 months ago

Hi @wonjoolee95, is it OK to assign this ticket to you?