NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

inconsistent parallelization found between TV34 and TV4. Producer is required to be in Global Memory based on parallelization strategy. #3228

Open · tfogal opened this issue 1 day ago

tfogal commented 1 day ago
Traceback (most recent call last):
  File "/home/tfogal/dev/tfx-test-cases/phi3/big.py", line 71, in <module>
    loss.backward()
  File "/usr/local/lib/python3.10/dist-packages/torch/_tensor.py", line 581, in backward
    torch.autograd.backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/__init__.py", line 347, in backward
    _engine_run_backward(
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/graph.py", line 825, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 307, in apply
    return user_fn(self, *args)
  File "/usr/local/lib/python3.10/dist-packages/torch/autograd/function.py", line 600, in wrapper
    outputs = fn(ctx, *args)
  File "/home/tfogal/dev/thunder/thunder/executors/torch_autograd.py", line 96, in backward
    grads = ctx.compiled_backward([saved_tensors_list, ctx.saved_other], args)
  File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "/usr/local/lib/python3.10/dist-packages/torch/amp/autocast_mode.py", line 44, in decorate_autocast
    return func(*args, **kwargs)
  File "thunder.backward_fn_24", line 29, in backward_fn
  File "/home/tfogal/dev/thunder/thunder/executors/nvfuserex_impl.py", line 456, in __call__
    return fd.execute(args, **kwargs)
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 235, in execute
    return self._execute(
RuntimeError:  INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp":800, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Inconsistent parallelization found between TV34 (T34_l___bfloat[ iblockIdx.x300{( ceilDiv(( ceilDiv(8192, 8) ), blockDim.x) )}, iblockIdx.y302{( ceilDiv(128, 1) )}, iUS303{1}, iS299{8}, ithreadIdx.x301{blockDim.x} ]) and TV4(T4_l___bfloat[ iblockIdx.x216{( ceilDiv(( ceilDiv(8192, 8) ), blockDim.x) )}, iblockIdx.y219{( ceilDiv(( 1 * 128 ), 1) )}, iUS220{1}, iS215{8}, ithreadIdx.x217{blockDim.x} ] ca_pos( 5 )). Producer is required to be in Global Memory based on parallelization strategy. RAW flags: (blockIdx.y)
Exception raised from SyncMap at /opt/pytorch/nvfuser/csrc/device_lower/analysis/sync_information.cpp:800 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x77816bb54cff in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x77816becb323 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x2785 (0x77816bda8ec5 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x3b6bbb (0x77816bdd8bbb in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0xdec (0x77816bddaa4c in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType, long, long, long, long) + 0x9f1 (0x77816bf3f531 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x6d88da (0x77816c0fa8da in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x492 (0x77816c102912 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x4ba (0x77816c10d39a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x796 (0x77816c32f196 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x1bb13e (0x77816bbdd13e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x236a9f (0x77816bc58a9f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x2c7250 (0x77816bce9250 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #37: torch::autograd::PyNode::apply(std::vector<at::Tensor, std::allocator<at::Tensor> >&&) + 0x95 (0x7782d87b7d35 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #38: <unknown function> + 0x4fb8dab (0x7782cfd2ddab in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #39: torch::autograd::Engine::evaluate_function(std::shared_ptr<torch::autograd::GraphTask>&, torch::autograd::Node*, torch::autograd::InputBuffer&, std::shared_ptr<torch::autograd::ReadyQueue> const&) + 0xdde (0x7782cfd280de in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #40: torch::autograd::Engine::thread_main(std::shared_ptr<torch::autograd::GraphTask> const&) + 0x56c (0x7782cfd292fc in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #41: torch::autograd::Engine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x2a9 (0x7782cfd21999 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cpu.so)
frame #42: torch::autograd::python::PythonEngine::thread_init(int, std::shared_ptr<torch::autograd::ReadyQueue> const&, bool) + 0x75 (0x7782d87b2855 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_python.so)
frame #43: <unknown function> + 0xdc253 (0x7782e5db8253 in /lib/x86_64-linux-gnu/libstdc++.so.6)
frame #44: <unknown function> + 0x94ac3 (0x7782e5f9cac3 in /lib/x86_64-linux-gnu/libc.so.6)
frame #45: <unknown function> + 0x126850 (0x7782e602e850 in /lib/x86_64-linux-gnu/libc.so.6)
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git75f141b
# cuda version: 12.6
# nvfuser version: 0.2.12+gitc52063f
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id20(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[1, 128, 16384], contiguity=[None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T1 = fd.define_tensor(shape=[128, 8192], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T14 = fd.ops.slice(T0, start_indices=[0, 0, 0], end_indices=[1, 128, 8192], strides=[1, 1, 1])
    T15 = fd.ops.cast(T14, dtype=DataType.Float)
    T20 = fd.ops.reshape(T1, new_shape=[1, 128, 8192])
    T33 = fd.ops.slice(T0, start_indices=[0, 0, 8192], end_indices=[1, 128, 16384], strides=[1, 1, 1])
    T34 = fd.ops.neg(T15)
    T35 = fd.ops.cast(T20, dtype=DataType.Float)
    T36 = fd.ops.cast(T33, dtype=DataType.Float)
    T37 = fd.ops.exp(T34)
    T38 = fd.ops.mul(T36, T35)
    S39 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T40 = fd.ops.add(S39, T37)
    T41 = fd.ops.mul(T15, T38)
    T42 = fd.ops.reciprocal(T40)
    T43 = fd.ops.neg(T41)
    T44 = fd.ops.mul(T43, T42)
    T45 = fd.ops.mul(T44, T42)
    T46 = fd.ops.mul(T45, T37)
    T47 = fd.ops.neg(T46)
    T48 = fd.ops.mul(T42, T38)
    T49 = fd.ops.mul(T15, T42)
    T50 = fd.ops.add(T48, T47)
    T51 = fd.ops.mul(T49, T35)
    T52 = fd.ops.cast(T50, dtype=DataType.BFloat16)
    T53 = fd.ops.cast(T51, dtype=DataType.BFloat16)
    S54 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T55 = fd.ops.pad(T52, [0, 8192, 0, 0, 0, 0], S54)
    S56 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T57 = fd.ops.pad(T53, [8192, 0, 0, 0, 0, 0], S56)
    T58 = fd.ops.cast(T55, dtype=DataType.Float)
    T59 = fd.ops.cast(T57, dtype=DataType.Float)
    T60 = fd.ops.add(T59, T58)
    T61 = fd.ops.cast(T60, dtype=DataType.BFloat16)
    T65 = fd.ops.reshape(T61, new_shape=[128, 16384])
    T66 = fd.ops.permute(T65, dims=[1, 0])
    fd.add_output(T65)
    fd.add_output(T66)

with FusionDefinition() as fd:
    nvfuser_fusion_id20(fd)

inputs = [
    torch.randn(2097152, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 128, 16384), (2097152, 16384, 1)),
    torch.randn(1048576, dtype=torch.bfloat16, device='cuda:0').as_strided((128, 8192), (8192, 1)),
]
fd.execute(inputs)
jacobhinkle commented 1 day ago

Seems similar to #1757. See also #1728. Note that if the input T1 is already broadcast, we do not hit this error.
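
A minimal sketch of that observation, assuming "already broadcast" means T1 arrives with its size-1 leading dimension made explicit: the caller passes the second input as [1, 128, 8192], the fusion declares T1 with that shape, and the reshape to [1, 128, 8192] (T20 above) becomes unnecessary. The shapes below mirror the repro; whether this avoids the assert is taken from the comment above, not re-verified here.

import torch

# Hypothetical inputs for a variant of the repro in which T1 already carries
# the broadcast dimension (shape [1, 128, 8192] instead of [128, 8192]).
inputs_broadcast = [
    torch.randn(2097152, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 128, 16384), (2097152, 16384, 1)),
    torch.randn(1048576, dtype=torch.bfloat16, device='cuda:0').as_strided((1, 128, 8192), (1048576, 8192, 1)),
]
# The fusion definition would then declare
#     T1 = fd.define_tensor(shape=[1, 128, 8192], contiguity=[None, True, True],
#                           dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
# and consume T1 directly where T20 is used above.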

naoyam commented 20 hours ago

@tfogal Can you try the repro with NVFUSER_ENABLE=id_model(all)? It seems to work for me.
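
For reference, one way to apply that suggestion to the standalone repro above, assuming nvFuser reads NVFUSER_ENABLE from the environment when its options are first queried; setting the variable in the shell before launching Python (NVFUSER_ENABLE=id_model(all) python big.py) avoids any ordering question.

import os

# Enable the IdModel-based analysis suggested above; set it before importing
# nvfuser so the option is visible by the time the fusion is compiled.
os.environ["NVFUSER_ENABLE"] = "id_model(all)"

from nvfuser import FusionDefinition, DataType
# ... then build and execute the fusion exactly as in the repro above.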