pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

xla gpu train ResizeBicubic is not supported #7267

Open mars1248 opened 2 months ago

mars1248 commented 2 months ago

❓ Questions and Help

I ran this code:

import unittest
import torch
import torch.nn.functional as F
import torch_xla
import torch_xla.core.xla_model as xm
class TestInterpolate(unittest.TestCase):

    def test_upsample(self):
        # Create a 2x2 input (shape (1, 1, 2, 2), i.e. N, C, H, W)
        input_tensor = torch.arange(1, 5).view(1, 1, 2, 2).float().to(xm.xla_device())
        # scale_factor=2 should double the spatial size
        scale_factor = 2

        # Upsample with interpolate
        output_tensor = F.interpolate(input_tensor, scale_factor=scale_factor, mode='bicubic')
        print(output_tensor)
        # Check whether the output tensor has the size we expect
        # The expected output shape is (1, 1, 4, 4)
        expected_output_size = torch.Size([1, 1, 4, 4])

        self.assertEqual(output_tensor.size(), expected_output_size)

if __name__ == '__main__':
    unittest.main()

and got this error:

======================================================================
ERROR: test_upsample (__main__.TestInterpolate)
----------------------------------------------------------------------
Traceback (most recent call last):
  File "resampler.py", line 16, in test_upsample
    print(output_tensor)
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor.py", line 473, in __repr__
    return torch._tensor_str._str(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor_str.py", line 697, in _str
    return _str_intern(self, tensor_contents=tensor_contents)
  File "/opt/conda/lib/python3.8/site-packages/torch/_tensor_str.py", line 432, in _str_intern
    self = self.to("cpu")
RuntimeError: Error while lowering: [] aten::upsample_bicubic2d, xla_shape=f32[1,1,4,4]{3,2,1,0}, dynamic_dims: (), output_size=(4, 4), align_corners=0
Error: torch_xla/csrc/resize_ops.cpp:299 : Resize kernel: ResizeBicubic is not supported

JackCaoG commented 2 months ago

Yeah, I don't think this is a GPU-specific issue; we probably just need to lower this op. cc @wonjoolee95, let's add this to our to-do list.

mars1248 commented 2 months ago

@wonjoolee95 Can you support it now?

wonjoolee95 commented 1 month ago

Hey @vanbasten23, is this something we can add as an onboarding task?

JackCaoG commented 1 month ago

I think you can check whether we can use the upstream decomposition directly (https://github.com/pytorch/pytorch/blob/c6cce976b25708e66ca2d7d7da1cd27c942e7a22/torch/_decomp/decompositions.py#L4422), similar to what I did in https://github.com/pytorch/xla/commit/5e51ca2365c2955bebdaaa62dfccbab3caa6cbda
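To show what a decomposition of `aten::upsample_bicubic2d` has to compute, here is a minimal pure-Python sketch (no torch or torch_xla) of bicubic upsampling for one 2D channel with `align_corners=False`. It assumes the Keys cubic kernel with A = -0.75 and the half-pixel source-coordinate mapping, where only the four tap indices are clamped to the border. As far as I know, this matches PyTorch's bicubic kernel, but it is not the upstream `decompositions.py` code, and the helper names (`cubic_weight`, `upsample_1d`, `upsample_bicubic2d`) are hypothetical. It also takes the scale as `in_size / out_size`, which agrees with `scale_factor=2` here but may differ from PyTorch for fractional scale factors. A real decomposition would express the same index math and weighted sums with tensor ops (indexing/gather plus elementwise arithmetic) that XLA already lowers.

```python
import math
from typing import List

# Keys cubic-convolution parameter (assumed to match PyTorch's bicubic mode).
A = -0.75


def cubic_weight(x: float) -> float:
    """Keys cubic kernel W(x); nonzero only for |x| < 2."""
    x = abs(x)
    if x <= 1.0:
        return ((A + 2.0) * x - (A + 3.0)) * x * x + 1.0
    if x < 2.0:
        return ((A * x - 5.0 * A) * x + 8.0 * A) * x - 4.0 * A
    return 0.0


def upsample_1d(row: List[float], out_size: int) -> List[float]:
    """Resample one row with 4 cubic taps (align_corners=False)."""
    in_size = len(row)
    scale = in_size / out_size  # assumption: scale = in/out
    out = []
    for d in range(out_size):
        # Half-pixel mapping; the coordinate itself is not clamped.
        real = (d + 0.5) * scale - 0.5
        base = math.floor(real)
        t = real - base
        acc = 0.0
        for k in range(-1, 3):  # taps at base-1, base, base+1, base+2
            idx = min(max(base + k, 0), in_size - 1)  # clamp tap to border
            acc += row[idx] * cubic_weight(t - k)
        out.append(acc)
    return out


def upsample_bicubic2d(img: List[List[float]], out_h: int, out_w: int) -> List[List[float]]:
    """Separable bicubic upsampling: resample every row, then every column."""
    rows = [upsample_1d(r, out_w) for r in img]                # (in_h, out_w)
    cols = [upsample_1d(list(c), out_h) for c in zip(*rows)]   # (out_w, out_h)
    return [list(r) for r in zip(*cols)]                       # (out_h, out_w)


if __name__ == "__main__":
    # Same input as the issue: [[1, 2], [3, 4]] upsampled 2x to 4x4.
    for r in upsample_bicubic2d([[1.0, 2.0], [3.0, 4.0]], 4, 4):
        print(r)
```

On the issue's input, this sketch produces a first output row of [0.68359375, 1.015625, 1.5625, 1.89453125]. Values below 1 at the border are expected, because the cubic kernel has negative lobes. Before porting this logic to an XLA lowering, compare it against `F.interpolate(..., mode='bicubic')` on CPU to validate it.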