NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

IdModel buildAllGraphs() fails in presegment pass #2253

Open jjsjann123 opened 1 month ago

jjsjann123 commented 1 month ago

Not sure if this is a real issue or just a mis-use of IdModel.

Here's the repro script below (it will likely be rendered obsolete after #2252).

Basically, passing the fusion from the program below to the IdModel constructor triggers an assert during buildLoopGraph(). I can help put together a cpp test if this turns out to be a real issue. cc'ing @naoyam

import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    S0 = fd.define_scalar(None, dtype=DataType.Int)
    T1 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T2 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T4 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, None, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, 1, -1, -1], contiguity=[True, None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[4, 1, 3, 2, 0])
    T6 = fd.ops.cast(T1, dtype=DataType.Float)
    T7 = fd.ops.cast(T4, dtype=DataType.Float)
    S8 = fd.define_scalar(0, dtype=DataType.Int)
    S9 = fd.define_scalar(32, dtype=DataType.Int)
    S10 = fd.define_scalar(32, dtype=DataType.Int)
    S11 = fd.define_scalar(4096, dtype=DataType.Int)
    S12 = fd.define_scalar(0, dtype=DataType.Int)
    V13 = fd.define_vector([S9, S10, S11, S12], dtype=DataType.Int)
    T14 = fd.ops.full(shape=V13, fill_value=S8, dtype=DataType.BFloat16)
    S15 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T16 = fd.ops.pad(T14, [0, 128, 0, 0, 0, 0, 0, 0], S15)
    S17 = fd.define_scalar(0, dtype=DataType.Int)
    S18 = fd.define_scalar(32, dtype=DataType.Int)
    S19 = fd.define_scalar(32, dtype=DataType.Int)
    S20 = fd.define_scalar(4096, dtype=DataType.Int)
    S21 = fd.define_scalar(0, dtype=DataType.Int)
    V22 = fd.define_vector([S18, S19, S20, S21], dtype=DataType.Int)
    T23 = fd.ops.full(shape=V22, fill_value=S17, dtype=DataType.BFloat16)
    S24 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T25 = fd.ops.pad(T23, [0, 128, 0, 0, 0, 0, 0, 0], S24)
    T26 = fd.ops.cast(T2, dtype=DataType.Float)
    T27 = fd.ops.mul(T7, T26)
    T28 = fd.ops.cast(T27, dtype=DataType.BFloat16)
    T29 = fd.ops.mul(T6, T26)
    T30 = fd.ops.slice(T28, start_indices=[0, 0, 0, 0], end_indices=[32, 32, 4096, 64], strides=[1, 1, 1, 1])
    T31 = fd.ops.slice(T28, start_indices=[0, 0, 0, 64], end_indices=[32, 32, 4096, 128], strides=[1, 1, 1, 1])
    T32 = fd.ops.cast(T30, dtype=DataType.Float)
    T33 = fd.ops.neg(T32)
    T34 = fd.ops.cast(T33, dtype=DataType.BFloat16)
    S35 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T36 = fd.ops.pad(T34, [64, 0, 0, 0, 0, 0, 0, 0], S35)
    T37 = fd.ops.cast(T36, dtype=DataType.Float)
    T38 = fd.ops.add(T29, T37)
    S39 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T40 = fd.ops.pad(T31, [0, 64, 0, 0, 0, 0, 0, 0], S39)
    T41 = fd.ops.cast(T40, dtype=DataType.Float)
    T42 = fd.ops.add(T38, T41)
    T43 = fd.ops.cast(T42, dtype=DataType.BFloat16)
    S44 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T45 = fd.ops.pad(T43, [0, 0, 0, 0, 0, 0, 0, 0], S44)
    T46 = fd.ops.cast(T16, dtype=DataType.Float)
    T47 = fd.ops.cast(T45, dtype=DataType.Float)
    T48 = fd.ops.add(T46, T47)
    T49 = fd.ops.cast(T48, dtype=DataType.BFloat16)
    T50 = fd.ops.cast(T3, dtype=DataType.Float)
    T51 = fd.ops.mul(T7, T50)
    T52 = fd.ops.cast(T51, dtype=DataType.BFloat16)
    T53 = fd.ops.mul(T6, T50)
    T54 = fd.ops.slice(T52, start_indices=[0, 0, 0, 0], end_indices=[32, 32, 4096, 64], strides=[1, 1, 1, 1])
    T55 = fd.ops.slice(T52, start_indices=[0, 0, 0, 64], end_indices=[32, 32, 4096, 128], strides=[1, 1, 1, 1])
    T56 = fd.ops.cast(T54, dtype=DataType.Float)
    T57 = fd.ops.neg(T56)
    T58 = fd.ops.cast(T57, dtype=DataType.BFloat16)
    S59 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T60 = fd.ops.pad(T58, [64, 0, 0, 0, 0, 0, 0, 0], S59)
    T61 = fd.ops.cast(T60, dtype=DataType.Float)
    T62 = fd.ops.add(T53, T61)
    S63 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T64 = fd.ops.pad(T55, [0, 64, 0, 0, 0, 0, 0, 0], S63)
    T65 = fd.ops.cast(T64, dtype=DataType.Float)
    T66 = fd.ops.add(T62, T65)
    T67 = fd.ops.cast(T66, dtype=DataType.BFloat16)
    S68 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T69 = fd.ops.pad(T67, [0, 0, 0, 0, 0, 0, 0, 0], S68)
    T70 = fd.ops.cast(T25, dtype=DataType.Float)
    T71 = fd.ops.cast(T69, dtype=DataType.Float)
    T72 = fd.ops.add(T70, T71)
    T73 = fd.ops.cast(T72, dtype=DataType.BFloat16)
    S74 = fd.define_scalar(32, dtype=DataType.Int)
    S75 = fd.define_scalar(32, dtype=DataType.Int)
    S76 = fd.define_scalar(1, dtype=DataType.Int)
    S77 = fd.define_scalar(4096, dtype=DataType.Int)
    S78 = fd.define_scalar(128, dtype=DataType.Int)
    V79 = fd.define_vector([S74, S75, S76, S77, S78], dtype=DataType.Int)
    T80 = fd.ops.reshape(T49, new_shape=V79)
    S81 = fd.define_scalar(32, dtype=DataType.Int)
    S82 = fd.define_scalar(32, dtype=DataType.Int)
    S83 = fd.define_scalar(1, dtype=DataType.Int)
    S84 = fd.define_scalar(4096, dtype=DataType.Int)
    S85 = fd.define_scalar(128, dtype=DataType.Int)
    V86 = fd.define_vector([S81, S82, S83, S84, S85], dtype=DataType.Int)
    T87 = fd.ops.reshape(T73, new_shape=V86)
    T88 = fd.ops.cat([T87, T80, T5], dim=2)
    fd.add_output(T88)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    2,
    torch.randn((524288,), dtype=torch.bfloat16, device='cuda:0').as_strided((32, 32, 4096, 128), (0, 0, 128, 1)),
    torch.randn((536870912,), dtype=torch.bfloat16, device='cuda:0').as_strided((32, 32, 4096, 128), (16777216, 524288, 128, 1)),
    torch.randn((536870912,), dtype=torch.bfloat16, device='cuda:0').as_strided((32, 32, 4096, 128), (16777216, 524288, 128, 1)),
    torch.randn((524288,), dtype=torch.bfloat16, device='cuda:0').as_strided((32, 32, 4096, 128), (0, 0, 128, 1)),
    torch.randn((536870912,), dtype=torch.bfloat16, device='cuda:0').as_strided((32, 32, 1, 4096, 128), (16777216, 128, 16777216, 4096, 1)),
]
fd.execute(inputs)

Backtrace

#0  0x00007ffdb754ec80 in __gnu_cxx::__normal_iterator<nvfuser::Val* const*, std::vector<nvfuser::Val*, std::allocator<nvfuser::Val
*> > >::__normal_iterator (this=0x7fffffffb250, __i=<error reading variable: Cannot access memory at address 0x50>)
    at /usr/include/c++/11/bits/stl_iterator.h:1028
#1  0x00007ffdb765739c in std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> >::cend (this=0x48)
    at /usr/include/c++/11/bits/stl_vector.h:894
#2  0x00007ffdb767515f in nvfuser::ir_utils::filterByType<nvfuser::IterDomain, std::vector<nvfuser::Val*, std::allocator<nvfuser::Val*> > > (inputs=<error reading variable: Cannot access memory at address 0x50>) at /opt/pytorch/nvfuser/csrc/ir/utils.h:215
#3  0x00007ffdb7aa4744 in nvfuser::IdModel::addReplayAs (this=0x7fffffffbd70,
    new_inputs=std::vector of length 1, capacity 1 = {...}, expr=0x7ffd72873070)
    at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:1267
#4  0x00007ffdb7aa38b0 in nvfuser::IdModel::propagatePromotionsInIELGraph (this=0x7fffffffbd70, iel_graph=...,
    iel_promotion_map=std::unordered_map with 0 elements, loop_graph=...,
    loop_graph_promotion_map=std::unordered_map with 258 elements = {...}) at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:1157
#5  0x00007ffdb7aa164c in nvfuser::IdModel::buildLoopPromotionMap (this=0x7fffffffbd70, inlining_info=...)
    at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:649
#6  0x00007ffdb7aa144b in nvfuser::IdModel::buildLoopGraph (this=0x7fffffffbd70)
    at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:576
#7  0x00007ffdb7aa2547 in nvfuser::IdModel::buildAllGraphs (this=0x7fffffffbd70)
    at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:885
#8  0x00007ffdb7a9d846 in nvfuser::IdModel::IdModel (this=0x7fffffffbd70, fusion=0x7ffda7978400, build_graphs=true,
    allow_self_mapping=true, validate=true) at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:134
#9  0x00007ffdb7d77989 in nvfuser::preseg_passes::inferenceAllocationOrder (fusion=0x7ffda7978400,
    srcs=std::vector of length 5, capacity 5 = {...}, dsts=std::vector of length 1, capacity 1 = {...})
    at /opt/pytorch/nvfuser/csrc/preseg_passes/allocation_order_inference.cpp:218
#10 0x00007ffdb7d780c5 in nvfuser::preseg_passes::AllocationDomainPass::runPass (fusion=0x7ffda7978400)
    at /opt/pytorch/nvfuser/csrc/preseg_passes/allocation_order_inference.cpp:304
#11 0x00007ffdb7d87575 in nvfuser::preseg_passes::OptimizationPass<nvfuser::preseg_passes::AllocationDomainPass>::runPass (
    fusion=0x7ffda7978400) at /opt/pytorch/nvfuser/csrc/preseg_passes/optimization_pass.h:53
#12 0x00007ffdb7d864d0 in nvfuser::preseg_passes::PreSegmenter::runPass (fusion=0x7ffda7978400)
    at /opt/pytorch/nvfuser/csrc/preseg_passes/pre_segmenter.cpp:32
#13 0x00007ffdb7c44e0b in nvfuser::preseg_passes::OptimizationPass<nvfuser::preseg_passes::PreSegmenter>::runPass (
    fusion=0x7ffda7978400) at /opt/pytorch/nvfuser/csrc/preseg_passes/optimization_pass.h:53
#14 0x00007ffdb7c33272 in nvfuser::FusionKernelRuntime::FusionKernelRuntime (this=0x7ffd728cd700,
    fusion=std::unique_ptr<nvfuser::Fusion> = {...}, args=..., serde_buffer=0x0,
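
For reference, here is a rough sketch (not from the issue) of what such a cpp test could look like. The fixture name, include paths, and helper calls are assumptions based on the usual nvFuser test layout; the IdModel constructor arguments follow frame #8 of the backtrace, and the fusion body is only a placeholder that would need to replicate the pad/slice/cat pattern from the Python repro to actually trigger the assert.

#include <gtest/gtest.h>

#include <fusion.h>
#include <id_model/id_model.h>
#include <ops/all_ops.h>
#include <test/utils.h>  // test-utility include path is approximate

namespace nvfuser {

// Hypothetical test; the name, fixture, and include paths are assumptions.
TEST_F(NVFuserTest, IdModelBuildAllGraphsPadSliceCat_CUDA) {
  Fusion fusion;
  FusionGuard fg(&fusion);

  // Placeholder graph; the real test would rebuild the full pad/slice/cat
  // pattern (including the broadcast inputs) from the Python repro above.
  TensorView* tv0 = makeSymbolicTensor(4, DataType::BFloat16);
  fusion.addInput(tv0);
  TensorView* tv1 = castOp(DataType::Float, tv0);
  fusion.addOutput(tv1);

  // With build_graphs=true the constructor runs buildAllGraphs(), which is
  // where buildLoopGraph() currently hits the failing addReplayAs() assert.
  IdModel id_model(&fusion, /*build_graphs=*/true, /*allow_self_mapping=*/true);
}

} // namespace nvfuser
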
naoyam commented 1 month ago

Can you please run the fusion without the preseg pass but with NVFUSER_ENABLE=id_model? The lowering will construct an IdModel, so if it also fails there, then whether it's the preseg pass or not shouldn't matter.

jjsjann123 commented 1 month ago

It's indeed failing. This was run on top of #2252 (where the preseg pass no longer fails), and you can see the assert hit during GpuLower:

root@124a06e5d7bc:/volume# NVFUSER_DISABLE=parallel_compile NVFUSER_ENABLE=id_model python repro_nvfuser.py
Traceback (most recent call last):
  File "/volume/repro_nvfuser.py", line 107, in <module>
    fd.execute(inputs)
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 200, in execute
    result = self._execute(
RuntimeError: replay != nullptr INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/id_model/id_model.cpp":1266, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. no replay found
Exception raised from addReplayAs at /opt/pytorch/nvfuser/csrc/id_model/id_model.cpp:1266 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x7f (0x7fce7ab6bb92 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x5d (0x7fce7ab6bdd3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: <unknown function> + 0x8883ca (0x7fce7ace33ca in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x8874fe (0x7fce7ace24fe in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: <unknown function> + 0x88529a (0x7fce7ace029a in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: <unknown function> + 0x885099 (0x7fce7ace0099 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: <unknown function> + 0x886195 (0x7fce7ace1195 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x881494 (0x7fce7acdc494 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x5519c2 (0x7fce7a9ac9c2 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::GpuLower::GpuLower(nvfuser::Fusion*, nvfuser::CompileParams const&) + 0x5ee (0x7fce7a9abbcc in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: <unknown function> + 0x73e529 (0x7fce7ab99529 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::ScheduleHeuristic, long, long, long, long) + 0x638 (0x7fce7ab792d4 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: <unknown function> + 0xa1992d (0x7fce7ae7492d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x445 (0x7fce7ae740ab in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0x4e2 (0x7fce7ae6ec80 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, bool, bool, std::optional<signed char>) const + 0x515 (0x7fce7b22a879 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #16: <unknown function> + 0x1a7700 (0x7fce7a602700 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #17: <unknown function> + 0x2ac89b (0x7fce7a70789b in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #18: <unknown function> + 0x29f95d (0x7fce7a6fa95d in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #19: <unknown function> + 0x24f72e (0x7fce7a6aa72e in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #20: <unknown function> + 0x24f800 (0x7fce7a6aa800 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #21: <unknown function> + 0x2d7c15 (0x7fce7a732c15 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #37: <unknown function> + 0x29d90 (0x7fd0bb4b5d90 in /lib/x86_64-linux-gnu/libc.so.6)
frame #38: __libc_start_main + 0x80 (0x7fd0bb4b5e40 in /lib/x86_64-linux-gnu/libc.so.6)
naoyam commented 1 month ago

Thanks for checking!

IvanYashchuk commented 1 month ago

@jjsjann123, I'm just curious: was the fusion in the issue description generated from running Thunder's microbenchmark test_llama2_qkv_split_rope_7b_train? I'm seeing the same failure, but the line number is different now:

replay != nullptr INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/id_model/id_model.cpp":786
jjsjann123 commented 1 month ago

> @jjsjann123, I'm just curious: was the fusion in the issue description generated from running Thunder's microbenchmark test_llama2_qkv_split_rope_7b_train? I'm seeing the same failure, but the line number is different now:
>
> replay != nullptr INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/id_model/id_model.cpp":786

Oh no... you shouldn't be running into this issue (i.e. the issue still stands, but it shouldn't pop up in codegen any more after #2252). But it's possible that even building the EXACT graph triggers this issue.
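
(As a hedged illustration only: one way to check that in isolation, assuming IdModel exposes a per-graph build method such as buildExactGraph() alongside buildAllGraphs(), and reusing the constructor arguments from the backtrace; the helper name is made up.)

#include <fusion.h>
#include <id_model/id_model.h>

namespace nvfuser {

// Hypothetical helper: skip buildAllGraphs() in the constructor and build only
// the EXACT graph, to see whether the assert already fires before
// buildLoopGraph() is ever reached.
void buildExactGraphOnly(Fusion* fusion) {
  IdModel id_model(fusion, /*build_graphs=*/false, /*allow_self_mapping=*/true);
  id_model.buildExactGraph(); // assumed public API
}

} // namespace nvfuser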

Can you give me a repro command for running that benchmark?

IvanYashchuk commented 1 month ago

I also see this problem when running the regular e2e benchmark now (at current commit 8baa5505b247311a63adcca6e7fa2138929c8650):

python thunder/benchmarks/benchmark_litgpt.py --compile=thunder --micro_batch_size=1 --model_name=Llama-2-7b-hf --n_layers=1
jjsjann123 commented 1 month ago

The pattern from the benchmark looks similar to what we have in this repro. I think it's caused by an accidental change here: https://github.com/NVIDIA/Fuser/pull/2298#discussion_r1616371356