NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Codegen error: `have dynamic allocations but are placed in local memory.` coming from qkv_split_rope backwards #2702

Closed: jjsjann123 closed this issue 1 week ago

jjsjann123 commented 1 month ago

This seems to be a functional regression: a new failure shows up after disabling bookend in the qkv_split_rope backward.

Error message:

Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 210, in execute
    result = self._execute(
RuntimeError: false INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/executor.cpp":492, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Allocations must be based on constant integers for local memory. However, found: T74_l[ iblockIdx.x515{( ceilDiv(( ceilDiv(( ceilDiv(( 8 * ( 8192 * 128 ) ), 8) ), blockDim.x) ), 1) )}, ithreadIdx.x514{blockDim.x}, iUS516{1}, iV512{8}, iS519{( ceilDiv(( ceilDiv(( ceilDiv(( (( (( getMetaData(T4) )).logical_size ))[1] ), 8) ), 2) ), 1) )}, iUS520{1}, iUR518{2}, bS420{1} ] ca_pos( 3 ), T74_l[ iblockIdx.x515{( ceilDiv(( ceilDiv(( ceilDiv(( 8 * ( 8192 * 128 ) ), 8) ), blockDim.x) ), 1) )}, ithreadIdx.x514{blockDim.x}, iUS516{1}, iV512{8}, iS519{( ceilDiv(( ceilDiv(( ceilDiv(( (( (( getMetaData(T4) )).logical_size ))[1] ), 8) ), 2) ), 1) )}, iUS520{1}, iUR518{2}, bS420{1} ] ca_pos( 3 ),  have dynamic allocations but are placed in local memory.
Exception raised from compileFusion at /opt/pytorch/nvfuser/csrc/executor.cpp:492 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0xf3 (0x7f3e15150829 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: nvfuser::nvfErrorFail(char const*, char const*, unsigned int, char const*, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&) + 0x53 (0x7f3e154ad263 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::FusionExecutor::compileFusion(nvfuser::Fusion*, nvfuser::KernelArgumentHolder const&, nvfuser::LaunchParams const&, nvfuser::CompileParams, nvfuser::ScheduleHeuristic, long, long, long, long) + 0x1fe7 (0x7f3e154c7257 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: <unknown function> + 0x6ab5d5 (0x7f3e156cc5d5 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::FusionKernelRuntime::compileFusionParallel(nvfuser::KernelArgumentHolder) + 0x492 (0x7f3e156d4622 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::FusionExecutorCache::runFusionWithInputs(c10::ArrayRef<c10::IValue> const&, std::optional<nvfuser::PrimDataType>, std::optional<signed char>) + 0xab3 (0x7f3e156df4a3 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: nvfuser::python_frontend::FusionDefinition::execute(c10::ArrayRef<c10::IValue> const&, std::optional<signed char>, bool, bool, bool) const + 0x394 (0x7f3e158d0a24 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: <unknown function> + 0x1ad1be (0x7f3e151ce1be in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: <unknown function> + 0x22452f (0x7f3e1524552f in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: <unknown function> + 0x2ba5f0 (0x7f3e152db5f0 in /opt/pytorch/nvfuser/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
<omitting python frames>
frame #25: <unknown function> + 0x29d90 (0x7f4057536d90 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #26: __libc_start_main + 0x80 (0x7f4057536e40 in /usr/lib/x86_64-linux-gnu/libc.so.6)

Repro script:

# Repro for the qkv_split_rope backward failure; needs a CUDA GPU and nvfuser.
import torch
from nvfuser import FusionDefinition, DataType

def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[1, -1, -1, -1], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T3 = fd.define_tensor(shape=[1, -1, -1, -1], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T4 = fd.define_tensor(shape=[1, -1, -1, -1], contiguity=[None, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    S5 = fd.define_scalar(1, dtype=DataType.Int)
    S6 = fd.define_scalar(32, dtype=DataType.Int)
    S7 = fd.define_scalar(8192, dtype=DataType.Int)
    S8 = fd.define_scalar(128, dtype=DataType.Int)
    V9 = fd.define_vector([S5, S6, S7, S8], dtype=DataType.Int)
    T10 = fd.ops.broadcast_in_dim(T0, shape=V9, broadcast_dims=[2, 3])
    T11 = fd.ops.cast(T10, dtype=DataType.Float)
    S12 = fd.define_scalar(1, dtype=DataType.Int)
    S13 = fd.define_scalar(32, dtype=DataType.Int)
    S14 = fd.define_scalar(8192, dtype=DataType.Int)
    S15 = fd.define_scalar(128, dtype=DataType.Int)
    V16 = fd.define_vector([S12, S13, S14, S15], dtype=DataType.Int)
    T17 = fd.ops.broadcast_in_dim(T1, shape=V16, broadcast_dims=[2, 3])
    T18 = fd.ops.cast(T17, dtype=DataType.Float)
    T19 = fd.ops.slice(T3, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 8192, 128], strides=[1, 1, 1, 1])
    S20 = fd.define_scalar(0, dtype=DataType.Int)
    S21 = fd.define_scalar(1, dtype=DataType.Int)
    S22 = fd.define_scalar(32, dtype=DataType.Int)
    S23 = fd.define_scalar(8192, dtype=DataType.Int)
    S24 = fd.define_scalar(0, dtype=DataType.Int)
    V25 = fd.define_vector([S21, S22, S23, S24], dtype=DataType.Int)
    T26 = fd.ops.full(shape=V25, fill_value=S20, dtype=DataType.BFloat16)
    S27 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T28 = fd.ops.pad(T26, [0, 128, 0, 0, 0, 0, 0, 0], S27)
    T29 = fd.ops.slice(T2, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 8192, 128], strides=[1, 1, 1, 1])
    T30 = fd.ops.cast(T19, dtype=DataType.Float)
    T31 = fd.ops.mul(T18, T30)
    T32 = fd.ops.cast(T31, dtype=DataType.BFloat16)
    T33 = fd.ops.mul(T11, T30)
    T34 = fd.ops.cast(T28, dtype=DataType.Float)
    T35 = fd.ops.add(T34, T33)
    T36 = fd.ops.slice(T32, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 8192, 64], strides=[1, 1, 1, 1])
    T37 = fd.ops.slice(T32, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 8192, 128], strides=[1, 1, 1, 1])
    T38 = fd.ops.cast(T36, dtype=DataType.Float)
    T39 = fd.ops.neg(T38)
    T40 = fd.ops.cast(T39, dtype=DataType.BFloat16)
    S41 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T42 = fd.ops.pad(T40, [64, 0, 0, 0, 0, 0, 0, 0], S41)
    T43 = fd.ops.cast(T42, dtype=DataType.Float)
    T44 = fd.ops.add(T35, T43)
    S45 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T46 = fd.ops.pad(T37, [0, 64, 0, 0, 0, 0, 0, 0], S45)
    T47 = fd.ops.cast(T46, dtype=DataType.Float)
    T48 = fd.ops.add(T44, T47)
    T49 = fd.ops.cast(T48, dtype=DataType.BFloat16)
    T50 = fd.ops.cast(T29, dtype=DataType.Float)
    T51 = fd.ops.mul(T18, T50)
    T52 = fd.ops.cast(T51, dtype=DataType.BFloat16)
    T53 = fd.ops.mul(T11, T50)
    T54 = fd.ops.add(T34, T53)
    T55 = fd.ops.slice(T52, start_indices=[0, 0, 0, 0], end_indices=[1, 32, 8192, 64], strides=[1, 1, 1, 1])
    T56 = fd.ops.slice(T52, start_indices=[0, 0, 0, 64], end_indices=[1, 32, 8192, 128], strides=[1, 1, 1, 1])
    T57 = fd.ops.cast(T55, dtype=DataType.Float)
    T58 = fd.ops.neg(T57)
    T59 = fd.ops.cast(T58, dtype=DataType.BFloat16)
    S60 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T61 = fd.ops.pad(T59, [64, 0, 0, 0, 0, 0, 0, 0], S60)
    T62 = fd.ops.cast(T61, dtype=DataType.Float)
    T63 = fd.ops.add(T54, T62)
    S64 = fd.define_scalar(0.00000, dtype=DataType.Double)
    T65 = fd.ops.pad(T56, [0, 64, 0, 0, 0, 0, 0, 0], S64)
    T66 = fd.ops.cast(T65, dtype=DataType.Float)
    T67 = fd.ops.add(T63, T66)
    T68 = fd.ops.cast(T67, dtype=DataType.BFloat16)
    S69 = fd.define_scalar(1, dtype=DataType.Int)
    S70 = fd.define_scalar(8, dtype=DataType.Int)
    S71 = fd.define_scalar(4, dtype=DataType.Int)
    S72 = fd.define_scalar(8192, dtype=DataType.Int)
    S73 = fd.define_scalar(128, dtype=DataType.Int)
    V74 = fd.define_vector([S69, S70, S71, S72, S73], dtype=DataType.Int)
    T75 = fd.ops.reshape(T4, new_shape=V74)
    S76 = fd.define_scalar(1, dtype=DataType.Int)
    S77 = fd.define_scalar(8, dtype=DataType.Int)
    S78 = fd.define_scalar(4, dtype=DataType.Int)
    S79 = fd.define_scalar(8192, dtype=DataType.Int)
    S80 = fd.define_scalar(128, dtype=DataType.Int)
    V81 = fd.define_vector([S76, S77, S78, S79, S80], dtype=DataType.Int)
    T82 = fd.ops.reshape(T49, new_shape=V81)
    S83 = fd.define_scalar(1, dtype=DataType.Int)
    S84 = fd.define_scalar(8, dtype=DataType.Int)
    S85 = fd.define_scalar(4, dtype=DataType.Int)
    S86 = fd.define_scalar(8192, dtype=DataType.Int)
    S87 = fd.define_scalar(128, dtype=DataType.Int)
    V88 = fd.define_vector([S83, S84, S85, S86, S87], dtype=DataType.Int)
    T89 = fd.ops.reshape(T68, new_shape=V88)
    T90 = fd.ops.cast(T75, dtype=DataType.Float)
    T91 = fd.ops.sum(T90, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T92 = fd.ops.cast(T91, dtype=DataType.BFloat16)
    S93 = fd.define_scalar(1, dtype=DataType.Int)
    S94 = fd.define_scalar(8, dtype=DataType.Int)
    S95 = fd.define_scalar(1, dtype=DataType.Int)
    S96 = fd.define_scalar(8192, dtype=DataType.Int)
    S97 = fd.define_scalar(128, dtype=DataType.Int)
    V98 = fd.define_vector([S93, S94, S95, S96, S97], dtype=DataType.Int)
    T99 = fd.ops.broadcast_in_dim(T92, shape=V98, broadcast_dims=[1, 3, 4])
    T100 = fd.ops.cast(T82, dtype=DataType.Float)
    T101 = fd.ops.sum(T100, dims=[0, 2], keepdim=False, dtype=DataType.Null)
    T102 = fd.ops.cast(T101, dtype=DataType.BFloat16)
    S103 = fd.define_scalar(1, dtype=DataType.Int)
    S104 = fd.define_scalar(8, dtype=DataType.Int)
    S105 = fd.define_scalar(1, dtype=DataType.Int)
    S106 = fd.define_scalar(8192, dtype=DataType.Int)
    S107 = fd.define_scalar(128, dtype=DataType.Int)
    V108 = fd.define_vector([S103, S104, S105, S106, S107], dtype=DataType.Int)
    T109 = fd.ops.broadcast_in_dim(T102, shape=V108, broadcast_dims=[1, 3, 4])
    T110 = fd.ops.cat([T89, T109, T99], dim=2)
    T111 = fd.ops.permute(T110, dims=[0, 3, 1, 2, 4])
    S112 = fd.define_scalar(1, dtype=DataType.Int)
    S113 = fd.define_scalar(8192, dtype=DataType.Int)
    S114 = fd.define_scalar(6144, dtype=DataType.Int)
    V115 = fd.define_vector([S112, S113, S114], dtype=DataType.Int)
    T116 = fd.ops.reshape(T111, new_shape=V115)
    fd.add_output(T116)

with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)

inputs = [
    torch.randn((1048576,), dtype=torch.bfloat16, device='cuda:0').as_strided((8192, 128), (128, 1)),
    torch.randn((1048576,), dtype=torch.bfloat16, device='cuda:0').as_strided((8192, 128), (128, 1)),
    torch.randn((33554432,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 8192, 128), (33554432, 1048576, 128, 1)),
    torch.randn((33554432,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 8192, 128), (33554432, 1048576, 128, 1)),
    torch.randn((33554432,), dtype=torch.bfloat16, device='cuda:0').as_strided((1, 32, 8192, 128), (33554432, 1048576, 128, 1)),
]
fd.execute(inputs)  # fails in compileFusion with the RuntimeError above on affected builds

jjsjann123 commented 1 month ago

cc'ing @wujingyue :sob:

wujingyue commented 1 month ago

Apparently, the bug predated https://github.com/NVIDIA/Fuser/commit/fe34321ad0a36a6fe96a625b3dc81e1ed716ef1b. I'm running a git bisect...

Bisect hit 1c80008afd1719dc3e36ccafc0294e9cb73a8b70, which broke the build. I'll try git bisect skip on the commits between that and 0cc22b910b5a6302444850019ae1e4f92974bc83...

https://github.com/NVIDIA/Fuser/commit/2c3c08fa994286018aa3a53aefc201e49fb26593 is the first commit that broke the reproducer. cc @liqiangxl
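For a bisect where some commits don't build, git bisect run can handle the skipping automatically: the script it runs exits with 0 for a good commit, 1 for a bad one, and 125 for a commit that cannot be tested, which git then skips. A minimal sketch, assuming a local build command and a repro file name that are both hypothetical (adjust them to your checkout):

```python
import subprocess
import sys
from typing import Sequence

# Exit codes interpreted by `git bisect run`.
GOOD = 0    # commit does not reproduce the bug
BAD = 1     # any code in 1..127 except 125 marks the commit bad
SKIP = 125  # commit cannot be tested (e.g. the build is broken)


def classify(build_ok: bool, repro_failed: bool) -> int:
    """Map one bisect step's outcome to a git-bisect-run exit code."""
    if not build_ok:
        return SKIP
    return BAD if repro_failed else GOOD


def bisect_step(build_cmd: Sequence[str], repro_cmd: Sequence[str]) -> int:
    """Build the checked-out commit, then run the repro against it."""
    if subprocess.run(list(build_cmd)).returncode != 0:
        return classify(build_ok=False, repro_failed=False)
    repro_failed = subprocess.run(list(repro_cmd)).returncode != 0
    return classify(build_ok=True, repro_failed=repro_failed)


if __name__ == "__main__" and len(sys.argv) == 2:
    # Hypothetical build command; replace with the one your nvFuser checkout uses.
    build = ["pip", "install", "--no-build-isolation", "-v", "-e", "."]
    sys.exit(bisect_step(build, [sys.executable, sys.argv[1]]))
```

Saved as, say, bisect_step.py (a hypothetical name), it would be driven with git bisect run python bisect_step.py repro_2702.py, where repro_2702.py holds the repro script from the issue description. An uncaught RuntimeError makes Python exit with status 1, so commits that hit the assertion are marked bad.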

liqiangxl commented 1 month ago

This seems to be due to a bug in sortAndRFactor. The correct order should be

reduction_rf_tv: T76_l[ ..., rS459{( ceilDiv(( ceilDiv(4, 2) ), 1) )}rf, iUS456{1}, iV452{8}, iUS460{1}rf, rUR458{2}rf ]

instead of

reduction_rf_tv: T76_l[ ..., iUS456{1}, iV452{8}, rS459{( ceilDiv(( ceilDiv(4, 2) ), 1) )}rf, iUS460{1}rf, rUR458{2}rf ]

During lowering, ceilDiv(4, 2) is replaced with ceilDiv(ceilDiv(T4.logical_size[1LL], 8), 2).
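Why a constant versus a symbolic extent matters can be sketched with a simplified model: treat a local-memory tensor's allocation as the product of the extents of its loop IDs at or to the right of its compute-at position, and note that the executor.cpp check requires that size to be a constant integer. This is an illustration under that assumption, not nvFuser's allocation logic, which has additional rules this ignores. The axis list copies T74_l from the error message, where iS519 sits to the right of ca_pos( 3 ); the helper names are hypothetical.

```python
from math import prod
from typing import List, Tuple, Union

# An extent is either a compile-time constant (int) or a symbolic expression (str).
Extent = Union[int, str]


def ceil_div(a: int, b: int) -> int:
    """Ceiling division for positive integers."""
    return -(-a // b)


def t74_axes(t4_dim1: Extent) -> List[Tuple[str, Extent]]:
    """Loop IDs of T74_l as printed in the error, with iS519 derived from T4's dim 1."""
    if isinstance(t4_dim1, int):
        is519: Extent = ceil_div(ceil_div(ceil_div(t4_dim1, 8), 2), 1)
    else:
        is519 = f"ceilDiv(ceilDiv(ceilDiv({t4_dim1}, 8), 2), 1)"
    return [
        ("iblockIdx.x515", "ceilDiv(ceilDiv(ceilDiv(8 * (8192 * 128), 8), blockDim.x), 1)"),
        ("ithreadIdx.x514", "blockDim.x"),
        ("iUS516", 1),
        ("iV512", 8),
        ("iS519", is519),
        ("iUS520", 1),
        ("iUR518", 2),
        ("bS420", 1),
    ]


def local_allocation_size(axes: List[Tuple[str, Extent]], ca_pos: int) -> Extent:
    """Simplified model: multiply the extents of all axes at positions >= ca_pos."""
    extents = [ext for _, ext in axes[ca_pos:]]
    if all(isinstance(e, int) for e in extents):
        return prod(extents)
    # One symbolic extent makes the whole size symbolic, i.e. a dynamic allocation.
    return " * ".join(str(e) for e in extents)


# With T4's dim 1 known to be 32, every extent right of ca_pos is constant.
print(local_allocation_size(t74_axes(32), ca_pos=3))  # 32
# With the symbolic logical size substituted in, the size is no longer a constant.
print(local_allocation_size(t74_axes("getMetaData(T4).logical_size[1]"), ca_pos=3))
```

In this model the constant 4 lets the split extents fold to literals, while the substituted T4.logical_size[1LL] leaves a symbolic factor in the allocation, which appears to be what the assertion rejects.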

liqiangxl commented 1 month ago

ceilDiv(4, 2) should also be symbolic instead of constant. Not sure why ceilDiv(i14, 8) is changed to 4 in T53_g = __bfloat2float(T70_l).

Inputs:
  T4_g[ bS12{1}, iS13{i14}, iS412{8192}, iS413{128} ], __bfloat
Outputs:
  T57_g[ bS256{1}, iS394{8}, bS258{1}, iS395{8192}, iS396{128} ], __bfloat

%kernel_math {
T70_l[ bS325{1}, iS330{8}rf, iS331{( ceilDiv(i14, 8) )}rf, iS414{8192}, iS415{128} ] = view( T4_g[ bS12{1}, iS13{i14}, iS412{8192}, iS413{128} ] )
T53_g[ bS372{1}, iS373{8}, iS374{4}, iS375{8192}, iS376{128} ]
   = __bfloat2float(T70_l[ bS325{1}, iS330{8}rf, iS331{( ceilDiv(i14, 8) )}rf, iS414{8192}, iS415{128} ]);
T71_g[ iS377{8}, iS378{4}, iS379{8192}, iS380{128} ]
   = squeeze( T53_g[ bS372{1}, iS373{8}, iS374{4}, iS375{8192}, iS376{128} ] )
T72_l[ iS381{8}, rS382{4}, iS383{8192}, iS384{128} ]
   = reduction( T71_g[ iS377{8}, iS378{4}, iS379{8192}, iS380{128} ], op = add, initial value = float(0), allreduce = false )
T55_g[ iS388{8}, iS389{8192}, iS390{128} ]
   = __float2bfloat(T72_l[ iS381{8}, rS382{4}, iS383{8192}, iS384{128} ]);
T56_g[ bS251{1}, iS391{8}, bS253{1}, iS392{8192}, iS393{128} ]
   = broadcast( T55_g[ iS388{8}, iS389{8192}, iS390{128} ] )
T57_g[ bS256{1}, iS394{8}, bS258{1}, iS395{8192}, iS396{128} ]
   = Set( T56_g[ bS251{1}, iS391{8}, bS253{1}, iS392{8192}, iS393{128} ], cache_op=Streaming )
} // %kernel_math

liqiangxl commented 1 month ago

In ExactLogicalDomainMap, we have

  { iS331{( ceilDiv(i14, 8) )}rf; iS374{4}; iS378{4}; rS382{4} }

We should be able to derive i14 = 32, so we wouldn't need to replace i14 with T4.logical_size[1LL]; this could also fix the bug. What do you think, @jacobhinkle?
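The derivation of i14 = 32 can be checked with plain arithmetic. On its own, the exact-map equality ceilDiv(i14, 8) == 4 admits every i14 from 25 to 32; what narrows it to 32 is that the reshape of T4 from [1, i14, 8192, 128] into [1, 8, 4, 8192, 128] in the repro must preserve the element count. A sketch in plain Python, not nvFuser code:

```python
def ceil_div(a: int, b: int) -> int:
    """Ceiling division for positive integers, matching ceilDiv in the IR dumps."""
    return -(-a // b)


# Constraint from ExactLogicalDomainMap: iS331{ceilDiv(i14, 8)} maps to extent 4.
from_exact_map = [i14 for i14 in range(1, 65) if ceil_div(i14, 8) == 4]
print(from_exact_map)  # [25, 26, 27, 28, 29, 30, 31, 32]

# Constraint from the reshape T4[1, i14, 8192, 128] -> [1, 8, 4, 8192, 128]:
# both shapes must hold the same number of elements.
from_reshape = [i14 for i14 in from_exact_map if i14 * 8192 * 128 == 8 * 4 * 8192 * 128]
print(from_reshape)  # [32]
```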

jacobhinkle commented 1 month ago

> Not sure why ceilDiv(i14, 8) is changed to 4 in T53_g = __bfloat2float(T70_l)

I think it's the other way around: the 4 exists in the dynamic fusion before dynamic shape concretization:

T50_l[ ?S219{1}rf, ?S220{8}rf, ?S221{4}rf, ?S222{8192}rf, ?S223{128}rf ] = view( T4_g[ bS12{1}, iS13{i14}, iS14{i15}, iS15{i16} ] )
T53_l[ ?S238{1}, ?S239{8}, ?S240{4}, ?S241{8192}, ?S242{128} ]
   = __bfloat2float(T50_l[ ?S219{1}rf, ?S220{8}rf, ?S221{4}rf, ?S222{8192}rf, ?S223{128}rf ]);

When we concretize T50_l's root->logical transforms, the ID ?S221{4}rf is replaced with the output of a Split, which has the ceilDiv extent.
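A toy model of that concretization step, assuming the reshape of i14 into (8, 4) becomes a split whose outer extent is 8, as iS330{8}rf and iS331{( ceilDiv(i14, 8) )}rf in the earlier dump suggest. The class and function names are hypothetical, not nvFuser APIs; the point is that the ID's extent changes from the literal 4 to an expression in i14, and any later split of it stays symbolic.

```python
from dataclasses import dataclass
from typing import Dict, Tuple, Union


@dataclass(frozen=True)
class Sym:
    """A symbolic extent such as i14 (bound to T4.logical_size[1] at runtime)."""
    name: str


@dataclass(frozen=True)
class CeilDiv:
    """The extent expression ceilDiv(num, den)."""
    num: "Expr"
    den: "Expr"


Expr = Union[int, Sym, CeilDiv]


def outer_split(extent: Expr, factor: int) -> Tuple[Expr, Expr]:
    """Split with a fixed outer extent: returns (factor, ceilDiv(extent, factor))."""
    return factor, CeilDiv(extent, factor)


def render(e: Expr) -> str:
    """Print an extent the way the IR dumps do."""
    if isinstance(e, int):
        return str(e)
    if isinstance(e, Sym):
        return e.name
    return f"ceilDiv({render(e.num)}, {render(e.den)})"


def evaluate(e: Expr, env: Dict[str, int]) -> int:
    """Evaluate an extent once the symbolic sizes are bound to concrete values."""
    if isinstance(e, int):
        return e
    if isinstance(e, Sym):
        return env[e.name]
    num, den = evaluate(e.num, env), evaluate(e.den, env)
    return -(-num // den)


# Before concretization, the reshape output ID ?S221 simply carries the literal 4.
before: Expr = 4
# After concretization, the same ID is the inner output of a split of i14.
outer, after = outer_split(Sym("i14"), 8)
print(render(after))                           # ceilDiv(i14, 8)
print(evaluate(after, {"i14": 32}) == before)  # True
# A later split by 2 of this ID (compare rS459) now has a symbolic extent too.
print(render(CeilDiv(after, 2)))               # ceilDiv(ceilDiv(i14, 8), 2)
```

Both forms evaluate to the same value for the repro's inputs; only the symbolic one prevents the extent from being treated as a compile-time constant.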

wujingyue commented 1 month ago

Can someone summarize? Do we need both #2713 and #2714 or just #2714 which is being worked on? I guess the latter but wasn't sure.

jjsjann123 commented 1 month ago

> Can someone summarize? Do we need both #2713 and #2714 or just #2714 which is being worked on? I guess the latter but wasn't sure.

You got it right. Only #2714 is needed. cc'ing @liqiangxl: should we close #2713?

liqiangxl commented 1 month ago

#2713 fixes a different issue (I think there is a problem with the ordering of loop domains), so I'll disconnect it from this issue.