DeepLink-org / DIOPI


[Fix] fix ascend memory format config #1317

Closed DoorKickers closed 3 months ago

DoorKickers commented 3 months ago

Motivation and Context

This PR fixes an issue introduced by commit 1647d7e9.

Description

After commit 1647d7e9, the following basic training code no longer works:

import torch
import torch_dipu

device = torch.device("cuda")

m = torch.nn.Conv2d(3, 32, kernel_size=3).to(device)
input = torch.randn(1, 3, 255, 255).to(device)
output = m.forward(input)

print(output)

It raises the following runtime error:

 File "/python3.9/site-packages/torch/nn/modules/conv.py", line 459, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: unsupported memory format Preserve
Exception raised from empty_tensor_restride at ../c10/core/TensorImpl.h:2378 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7fc029d29047 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7fc029ce423d in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libc10.so)
frame #2: <unknown function> + 0x1768998 (0x7fc0134a2998 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #3: <unknown function> + 0x176bbe6 (0x7fc0134a5be6 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #4: at::detail::empty_generic(c10::ArrayRef<long>, c10::Allocator*, c10::DispatchKeySet, c10::ScalarType, c10::optional<c10::MemoryFormat>) + 0x14 (0x7fc01349ddb4 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #5: dipu::native::dipu_aten::empty(c10::ArrayRef<long>, c10::optional<c10::ScalarType>, c10::optional<c10::Layout>, c10::optional<c10::Device>, c10::optional<bool>, c10::optional<c10::MemoryFormat>) + 0xea (0x7fc00a89f90a in /root/workspace/deeplink.framework/dipu/torch_dipu/libtorch_dipu.so)
frame #6: <unknown function> + 0x1f9a95 (0x7fc00a89fa95 in /root/workspace/deeplink.framework/dipu/torch_dipu/libtorch_dipu.so)
frame #7: dipu::native::dipu_convolution_overrideable(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x247 (0x7fc00aac04a7 in /root/workspace/deeplink.framework/dipu/torch_dipu/libtorch_dipu.so)
frame #8: c10::impl::wrap_kernel_functor_unboxed_<c10::impl::detail::WrapFunctionIntoFunctor_<c10::CompileTimeFunctionPointer<at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long), &dipu::native::dipu_convolution_overrideable>, at::Tensor, c10::guts::typelist::typelist<at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long> >, at::Tensor (at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long)>::call(c10::OperatorKernel*, c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x52 (0x7fc00a904012 in /root/workspace/deeplink.framework/dipu/torch_dipu/libtorch_dipu.so)
frame #9: at::_ops::convolution_overrideable::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x236 (0x7fc0146a9636 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #10: at::native::_convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long, bool, bool, bool, bool) + 0xdda (0x7fc0138747ba in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #11: <unknown function> + 0x2bf3dc6 (0x7fc01492ddc6 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #12: <unknown function> + 0x2bf3e47 (0x7fc01492de47 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #13: at::_ops::_convolution::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, bool, c10::ArrayRef<c10::SymInt>, long, bool, bool, bool, bool) + 0x29b (0x7fc0140d2dfb in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #14: at::native::convolution(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<long>, c10::ArrayRef<long>, bool, c10::ArrayRef<long>, long) + 0x21d (0x7fc0138693fd in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #15: <unknown function> + 0x2bf39a5 (0x7fc01492d9a5 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #16: <unknown function> + 0x2bf3a0f (0x7fc01492da0f in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #17: at::_ops::convolution::redispatch(c10::DispatchKeySet, at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, bool, c10::ArrayRef<c10::SymInt>, long) + 0x12f (0x7fc0140995ef in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #18: <unknown function> + 0x40e6353 (0x7fc015e20353 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #19: <unknown function> + 0x40e7247 (0x7fc015e21247 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #20: at::_ops::convolution::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, bool, c10::ArrayRef<c10::SymInt>, long) + 0x223 (0x7fc0140d2203 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #21: at::native::conv2d_symint(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, long) + 0x1de (0x7fc01386c95e in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #22: <unknown function> + 0x2dafb81 (0x7fc014ae9b81 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #23: at::_ops::conv2d::call(at::Tensor const&, at::Tensor const&, c10::optional<at::Tensor> const&, c10::ArrayRef<long>, c10::ArrayRef<c10::SymInt>, c10::ArrayRef<long>, long) + 0x204 (0x7fc0146aa3c4 in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_cpu.so)
frame #24: <unknown function> + 0x6005ec (0x7fc028f685ec in /opt/miniconda3/envs/zlt_dipu_dev_test_py39/lib/python3.9/site-packages/torch/lib/libtorch_python.so)
frame #25: python3() [0x507387]
<omitting python frames>
frame #28: python3() [0x4f8123]
frame #29: python3() [0x504f81]
frame #31: python3() [0x4f8123]
frame #32: python3() [0x504f81]
frame #34: python3() [0x4e6b2a]
frame #38: python3() [0x5c1dc7]
frame #39: python3() [0x5bddd0]
frame #40: python3() [0x45674e]
frame #44: <unknown function> + 0x29d90 (0x7fc02a438d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #45: __libc_start_main + 0x80 (0x7fc02a438e40 in /lib/x86_64-linux-gnu/libc.so.6)
frame #46: python3() [0x5885ce]

This bug causes every model that contains conv2d layers to fail to train properly.

After diagnosing the problem, I found that the bug is caused by code generated by autogen_diopi_wrapper.py.

The erroneous generated code is as follows:

//  convolution_overrideable(Tensor input, Tensor weight, Tensor? bias, int[] stride, int[] padding, int[] dilation, bool transposed, int[] output_padding, int groups) -> Tensor
at::Tensor dipu_convolution_overrideable(const at::Tensor& input, const at::Tensor& weight, const c10::optional<at::Tensor>& bias, at::IntArrayRef stride, at::IntArrayRef padding, at::IntArrayRef dilation, bool transposed, at::IntArrayRef output_padding, int64_t groups) {
  c10::OptionalDeviceGuard guard(at::device_of(input));
  dipu::profile::RecordBlockCreator _(__FUNCTION__);
  if (dumpOpArgLevel() > 0) {
    printf("--%-50s %-30s \n", "[convolution_overrideable]:", "diopiConvolution2d");
  }
  int64_t batch_size = input.size(0);
  int64_t height = input.size(2);
  int64_t width = input.size(3);
  int64_t out_channel = weight.size(0);
  auto kernel_size = weight.sizes().slice(2);
  int64_t out_height = (height + 2 * padding[0] - dilation[0] * (kernel_size[0] - 1) - 1) / stride[0] + 1;
  int64_t out_width = (width + 2 * padding[1] - dilation[1] * (kernel_size[1] - 1) - 1) / stride[1] + 1;
  c10::SmallVector<int64_t, 4> output_size = {batch_size, out_channel, out_height, out_width};
  at::Tensor out = nodispatch::empty(output_size, input.options(), at::MemoryFormat::Preserve);
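
For reference, the shape arithmetic above uses the standard conv2d output-size formula and is correct. Plugging in the values from the repro script (kernel 3, stride 1, padding 0, dilation 1, the Conv2d defaults) as a quick sanity check:

# Sanity check of the generated output-size computation, using the values from
# the repro script above: Conv2d(3, 32, kernel_size=3) on a 1x3x255x255 input.
height, width, kernel, stride, padding, dilation = 255, 255, 3, 1, 0, 1
out_height = (height + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1  # 253
out_width = (width + 2 * padding - dilation * (kernel - 1) - 1) // stride + 1    # 253
# output_size would be [1, 32, 253, 253]; the shape is fine, the problem is only
# the memory format passed to nodispatch::empty.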

The at::MemoryFormat::Preserve flag means that a tensor's memory format should stay consistent across operations or transformations. It is meaningless on this line: the empty tensor being created has no existing layout to preserve, so the caller must specify a concrete memory format for it:

at::Tensor out = nodispatch::empty(output_size, input.options(), at::MemoryFormat::Preserve);

In PyTorch's source code, empty_tensor_restride contains a check that rejects exactly this case:

  void empty_tensor_restride(MemoryFormat memory_format) {
    if (has_symbolic_sizes_strides_) {
      empty_tensor_restride_symint(memory_format);
      return;
    }
#ifdef DEBUG
    TORCH_INTERNAL_ASSERT(
        compute_numel() == numel_,
        "If you are seeing this error, that means empty_tensor_restride was "
        "called before setting correct numel");
#endif
    switch (memory_format) {
      case MemoryFormat::Contiguous: {
        // dim_ is a virtual call, don't repeat it
        const auto dim_ = dim();
        sizes_and_strides_.resize(dim_);
        if (dim_ > 0) {
          bool overflowed = false;
          const auto last_idx = dim_ - 1;
          sizes_and_strides_.stride_at_unchecked(last_idx) = 1;
          for (auto i = last_idx - 1; i >= 0; --i) {
            overflowed |= c10::mul_overflows(
                sizes_and_strides_.stride_at_unchecked(i + 1),
                std::max<int64_t>(
                    sizes_and_strides_.size_at_unchecked(i + 1), 1),
                std::addressof(sizes_and_strides_.stride_at_unchecked(i)));
          }
          TORCH_CHECK(!overflowed, "Stride calculation overflowed");
        }
        break;
      }
      case MemoryFormat::ChannelsLast: {
        TORCH_CHECK(
            dim() == 4, "required rank 4 tensor to use channels_last format");
        set_sizes_and_strides(sizes(), get_channels_last_strides_2d(sizes()));
        break;
      }
      case MemoryFormat::ChannelsLast3d: {
        TORCH_CHECK(
            dim() == 5,
            "required rank 5 tensor to use channels_last_3d format");
        set_sizes_and_strides(sizes(), get_channels_last_strides_3d(sizes()));
        break;
      }
      case MemoryFormat::Preserve:
        TORCH_CHECK(false, "unsupported memory format ", memory_format);
        // Cleaning warning messages, no need to break as TORCH_CHECK(false)
        // terminates flow.
        // break;
      case MemoryFormat::NumOptions:
        TORCH_INTERNAL_ASSERT(false, "invalid memory format ", memory_format);
    }
    // recompute contiguous flag, as currently NHWC/NCHW flags are not mutually
    // exclusive see #24090
    refresh_contiguous();
  }
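
The same check can be hit from plain PyTorch, without DIPU involved. Below is a minimal sketch; the expected behaviour is inferred from the empty_tensor_restride code above. preserve_format is only meaningful for operators such as empty_like, which have a source tensor whose layout can actually be preserved.

import torch

x = torch.randn(1, 3, 8, 8)

# preserve_format is fine for *_like operators: the new tensor reuses x's layout.
y = torch.empty_like(x, memory_format=torch.preserve_format)

# A plain empty has nothing to preserve, so it needs a concrete memory format.
ok_contig = torch.empty((1, 3, 8, 8), memory_format=torch.contiguous_format)
ok_nhwc = torch.empty((1, 3, 8, 8), memory_format=torch.channels_last)

try:
    bad = torch.empty((1, 3, 8, 8), memory_format=torch.preserve_format)
except RuntimeError as e:
    print(e)  # expected to be "unsupported memory format Preserve", as in the trace above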

Further debugging revealed that the incorrect code generation is caused by a combination of the code generation logic and the Ascend convert config file.

In op_memory_format_converter.py, the following logic reads the memory format parameter from the config file:

class ConvertConfig(object):
    # This class is used to load and parse the convert_config.yaml
    def __init__(self, config_yaml):
        self.convert_dict = dict()
        self.convert_config_yaml = config_yaml
        self.default_layout = "empty"
        assert isinstance(config_yaml, list)
        for config in config_yaml:
            assert isinstance(config, dict)
            for interface in config.keys():
                if interface == "common_config":
                    detail = config[interface]
                    assert isinstance(detail, dict)
                    if "layout" in detail:
                        self.default_layout = self.layout2memoryformat(detail["layout"])
                    pass
                    # may add common behavior
            for interface in config.keys():
                if interface != "common_config":
                    self.convert_dict.setdefault(interface, dict())
                    detail = config[interface]
                    assert isinstance(detail, dict)
                    if "layout" in detail:
                        self.convert_dict[interface]["layout"] = (
                            self.layout2memoryformat(detail["layout"])
                        )

    def layout2memoryformat(self, layout):
        # used when parsing convert_config.yaml; returns the memory format based on NCHW/NHWC and other layouts.
        assert isinstance(layout, str)
        if "NCHW" in layout:
            return "contiguous"
        if "NLC" in layout:
            return "channellast"
        if "NHWC" in layout:
            return "channellast"
        if "NDHWC" in layout:
            return "channellast"
        return "preserve"

    def interface2memoryformat(self, interface):
        # return the preferred memory format based on the DIOPI interface.
        interface_stripped = interface.strip().split("(")[0]
        if (interface_stripped not in self.convert_dict) or (
            "layout" not in self.convert_dict[interface_stripped]
        ):
            return self.default_layout
        else:
            return self.convert_dict[interface_stripped]["layout"]

It reads each op's memory format from ascend/convert_config.yaml; if an op does not specify one, it falls back to the default memory format defined by common_config in that file:

- common_config:
    layout: NCHW

- diopiAdamW:
    dtype: (float64)->float32
    layout: ND

- diopiSoftmax:
    dtype: (float64)->float32
    layout: ND

- diopiBaddbmm:
    dtype: (float64)->float32

- diopiBaddbmmInp:
    dtype: (float64)->float32

- diopiSoftmaxBackward:
    dtype: (float64)->float32
    layout: ND

- diopiLogSoftmax:
    dtype: (float64)->float32
    layout: ND

- diopiLogSoftmaxBackward:
    dtype: (float64)->float32
    layout: ND

- diopiGelu:
    dtype: (float64)->float32

- diopiGeluBackward:
    dtype: (float64)->float32

- diopiConvolution2d:
    dtype: (float64)->float16

- diopiConvolution2dBackward:
    dtype: (float64)->float16

And here lies the problem: after the "layout: ND" lines were added, the code generator stops using common_config (which maps to contiguous). Since "ND" matches none of the layouts recognized by layout2memoryformat, it falls through to "preserve", so the generator emits the conv2d code with memory format Preserve, which causes this bug.
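
The layout mapping itself can be reproduced directly with the ConvertConfig class quoted above. A small sketch, where the config literal is a hypothetical minimal excerpt rather than the real ascend/convert_config.yaml:

# Assumes the ConvertConfig class from op_memory_format_converter.py (quoted above) is in scope.
config = [
    {"common_config": {"layout": "NCHW"}},
    {"diopiSoftmax": {"dtype": "(float64)->float32", "layout": "ND"}},
]
cc = ConvertConfig(config)
print(cc.layout2memoryformat("NCHW"))             # "contiguous"
print(cc.layout2memoryformat("ND"))               # "preserve": "ND" matches none of the recognized layouts
print(cc.interface2memoryformat("diopiSoftmax"))  # "preserve", from its own layout entry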

In order to test and run the Ascend CI train-one-iter for models that contain conv2d, I deleted these "layout: ND" lines as a preliminary fix for this bug. @jingguo-st

Use cases (Optional)

BC-breaking (Optional)

Checklist

Before PR:

After PR: