crcrpar opened 2 days ago
```python
# CUDA devices:
#  0: NVIDIA RTX 6000 Ada Generation
# torch version: 2.6.0a0+git62eea62
# cuda version: 12.6
# nvfuser version: 0.2.23+git7b92716
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id4(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[64], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[16, 64], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T5 = fd.ops.broadcast_in_dim(T0, shape=[16, 64], broadcast_dims=[1])
    T6 = fd.ops.cast(T1, dtype=DataType.Float)
    T7 = fd.ops.cast(T5, dtype=DataType.Float)
    T8 = fd.ops.add(T6, T7)
    T9 = fd.ops.mul(T8, T8)
    T10 = fd.ops.mul(T9, T8)
    S11 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T12 = fd.ops.mul(S11, T8)
    S13 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T14 = fd.ops.mul(S13, T10)
    T15 = fd.ops.add(T8, T14)
    S16 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T17 = fd.ops.mul(S16, T15)
    T18 = fd.ops.tanh(T17)
    S19 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T20 = fd.ops.add(S19, T18)
    T21 = fd.ops.mul(T12, T20)
    T22 = fd.ops.abs(T21)
    T23 = fd.ops.max(T22, dims=[0, 1], keepdim=False, dtype=DataType.Null)
    T24 = fd.ops.cast(T23, dtype=DataType.Double)
    T25 = fd.ops.ne(T24, T24)
    S26 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T27 = fd.ops.gt(T24, S26)
    S28 = fd.define_scalar(1.00000e-12, dtype=DataType.Double)
    T29 = fd.ops.where(T27, T24, S28)
    T30 = fd.ops.where(T25, T24, T29)
    S31 = fd.define_scalar(448.000, dtype=DataType.Double)
    T32 = fd.ops.reciprocal(T30)
    T33 = fd.ops.mul(S31, T32)
    T34 = fd.ops.cast(T33, dtype=DataType.Float)
    T38 = fd.ops.broadcast_in_dim(T34, shape=[16, 64], broadcast_dims=[])
    T39 = fd.ops.mul(T21, T38)
    T40 = fd.ops.ne(T39, T39)
    S41 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T42 = fd.ops.gt(T39, S41)
    S43 = fd.define_scalar(-448.000, dtype=DataType.Double)
    T44 = fd.ops.where(T42, T39, S43)
    T45 = fd.ops.where(T40, T39, T44)
    T46 = fd.ops.ne(T45, T45)
    S47 = fd.define_scalar(448.000, dtype=DataType.Double)
    T48 = fd.ops.lt(T45, S47)
    S49 = fd.define_scalar(448.000, dtype=DataType.Double)
    T50 = fd.ops.where(T48, T45, S49)
    T51 = fd.ops.where(T46, T45, T50)
    fd.add_output(T34)
    fd.add_output(T51)

with FusionDefinition() as fd:
    nvfuser_fusion_id4(fd)

inputs = [
    torch.testing.make_tensor((64,), dtype=torch.bfloat16, device='cuda:0'),
    torch.testing.make_tensor((16, 64), dtype=torch.bfloat16, device='cuda:0'),
]
fd.execute(inputs)
```
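For readability, the fused region is approximately the bias add, a tanh-approximate GELU, and torchao-style float8 scaling: compute the amax of the GELU output, form `scale = 448 / amax` (448 being the float8_e4m3fn max), and clamp the scaled result to ±448. A rough eager-mode sketch reconstructed from the fusion definition above (the NaN-propagating `where` chains are simplified into clamps; `fused_region_sketch` is a hypothetical helper, not part of the repro):

```python
import torch

def fused_region_sketch(bias: torch.Tensor, x: torch.Tensor):
    # bias: [64] bf16, x: [16, 64] bf16 (matmul output without bias)
    h = x.float() + bias.float()
    # tanh-approximate GELU; constants match S11 / S13 / S16 in the fusion
    gelu = 0.5 * h * (1.0 + torch.tanh(0.797885 * (h + 0.044715 * h * h * h)))
    # float8 scale: 448 / amax, with amax clamped away from zero (S26 = 1e-12)
    amax = gelu.abs().amax().double().clamp_min(1e-12)
    scale = (448.0 / amax).float()
    # scaled activation, clamped to the float8_e4m3fn range
    scaled = (gelu * scale).clamp(-448.0, 448.0)
    return scale, scaled
```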
```
Traceback (most recent call last):
  File "/home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/__init__.py", line 317, in execute
    results = self._execute(
RuntimeError: INTERNAL ASSERT FAILED at "/home/mkozuki/ghq/github.com/crcrpar/Fuser/csrc/device_lower/analysis/sync_information.cpp":827, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
Inconsistent parallelization found between TV52 (T52_l___bfloat[ iS422{1}, iUS424{1}, ithreadIdx.x425{16}_p, iV421{4} ]) and TV2(T2_l___bfloat[ iS199{1}, iUS201{1}, ithreadIdx.x202{16}_p, iS198{4} ] ca_pos( 4 )). Producer is required to be in Global or Shared Memory based on parallelization strategy. RAW flags: (threadIdx.x)
Exception raised from SyncMap at /home/mkozuki/ghq/github.com/crcrpar/Fuser/csrc/device_lower/analysis/sync_information.cpp:827 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xcb (0x72ef20c4811b in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x3b (0x72ef20c4837b in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::SyncMap::SyncMap(nvfuser::Fusion*) + 0x35bd (0x72ef20b0873d in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x33b20f (0x72ef20b3b20f in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0x1579 (0x72ef20b39039 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::KernelExecutor::compile(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::SchedulerType) + 0x7bf (0x72ef20f1abbf in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x72fc2f (0x72ef20f2fc2f in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x75fad2 (0x72ef20f5fad2 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x9ee (0x72ef20f5ea6e in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x179 (0x72ef20f51cb9 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >, std::vector<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >, std::allocator<std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > > >) const + 0xacc (0x72ef210cc5ec in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: <unknown function> + 0x13c688 (0x72ef2093c688 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0x13ba47 (0x72ef2093ba47 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: <unknown function> + 0x1cb301 (0x72ef209cb301 in /home/mkozuki/ghq/github.com/crcrpar/Fuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
```
Steps to reproduce

```python
import torch
import torch.nn as nn
from torchao.float8 import convert_to_float8_training

import thunder
from thunder.tests.make_tensor import make_tensor


def main():
    batch_size, in_features, out_features = 16, 32, 64
    device = torch.device("cuda")
    dtype = torch.bfloat16
    bias = True

    model = nn.Sequential(
        nn.Linear(in_features, out_features, bias=bias),
        nn.GELU(approximate="tanh"),
        nn.Linear(out_features, out_features, bias=bias),
    ).to(device=device, dtype=dtype)
    fp8_model = convert_to_float8_training(model)
    x = make_tensor((batch_size, in_features), device=device, dtype=dtype)

    jitted = thunder.jit(
        fp8_model,
        executors=[thunder.get_executor("torch"), thunder.get_executor("nvfuser")],
    )
    actual = jitted(x)


if __name__ == "__main__":
    main()
```
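Not part of the failure itself, but once the fusion compiles, a hypothetical way to sanity-check the jitted module against the eager torchao model would be to extend the end of `main()` above:

```python
# Hypothetical addition to the end of main(), assuming fd.execute no longer asserts:
expected = fp8_model(x)  # eager torchao float8 reference
torch.testing.assert_close(actual, expected)
```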
Just FYI, this script works if the model is just a single `nn.Linear(in_features, out_features, bias=bias)`; a sketch of that reduced variant is shown below.
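The reduced variant only changes the model construction; the rest of the repro script stays the same (this is a sketch of the change described above, not a verbatim excerpt):

```python
# Reduced model from the note above: a single Linear, no GELU and no second Linear.
# With this model the same thunder.jit + nvfuser path runs without the assert.
model = nn.Linear(in_features, out_features, bias=bias).to(device=device, dtype=dtype)
```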
@crcrpar could we get the Thunder command as well?
@kevinstephano I added a "Steps to reproduce" section to the description.