NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

LoadStoreOp::toInlineString should not assume it's a tensor op #3360

Closed naoyam closed 1 week ago

naoyam commented 2 weeks ago

https://github.com/NVIDIA/Fuser/blob/main/csrc/ir/nodes.cpp#L2275-L2277

A set op is allowed to take a scalar, so LoadStoreOp should not assume that its input and output are tensors. For example, this command fails:

NVFUSER_DUMP=fusion_ir_math pytest -s -v tests/python/test_python_frontend.py -k test_pad_dynamic

Traceback (most recent call last):
  File "/raid/nmaruyama/debug1/nvfuser/__init__.py", line 182, in execute
    results = self._execute(
RuntimeError: Tensor op can not be printed inline
Exception raised from toInlineString at /raid/nmaruyama/debug1/csrc/ir/nodes.cpp:2276 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, char const*) + 0x92 (0x7fa194858642 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #1: <unknown function> + 0x10bb441 (0x7fa194a54441 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #2: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #3: nvfuser::UnaryOp::toInlineString[abi:cxx11](int) const + 0x74 (0x7fa194a3ca04 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #4: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #5: nvfuser::BinaryOp::toInlineString[abi:cxx11](int) const + 0xb3 (0x7fa194a40243 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #6: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #7: nvfuser::BinaryOp::toInlineString[abi:cxx11](int) const + 0x7d (0x7fa194a4020d in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #8: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #9: nvfuser::BinaryOp::toInlineString[abi:cxx11](int) const + 0x7d (0x7fa194a4020d in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #10: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #11: nvfuser::BinaryOp::toInlineString[abi:cxx11](int) const + 0xb3 (0x7fa194a40243 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #12: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #13: nvfuser::BinaryOp::toInlineString[abi:cxx11](int) const + 0x7d (0x7fa194a4020d in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #14: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #15: nvfuser::BinaryOp::toInlineString[abi:cxx11](int) const + 0x7d (0x7fa194a4020d in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #16: nvfuser::Val::toInlineString[abi:cxx11](int) const + 0x9e (0x7fa194a0845e in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #17: nvfuser::IterDomain::toString[abi:cxx11](int) const + 0x347 (0x7fa194a55a37 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #18: <unknown function> + 0xb2f106 (0x7fa1944c8106 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #19: <unknown function> + 0xb2f02a (0x7fa1944c802a in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #20: <unknown function> + 0xb226c6 (0x7fa1944bb6c6 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #21: <unknown function> + 0x10c209d (0x7fa194a5b09d in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #22: <unknown function> + 0x10c272b (0x7fa194a5b72b in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #23: nvfuser::TensorView::toString[abi:cxx11](int) const + 0x259 (0x7fa194ec6d19 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #24: nvfuser::operator<<(std::ostream&, nvfuser::Statement const*) + 0x53 (0x7fa194a31e93 in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)
frame #25: nvfuser::Fusion::printMath(bool) + 0x2ae (0x7fa1948db3ce in /raid/nmaruyama/debug1/nvfuser/_C.cpython-310-x86_64-linux-gnu.so)

kevinstephano commented 2 weeks ago

It looks like this was found in connection with the pad operation.

liqiangxl commented 1 week ago

I implemented it so that only out() is printed, which gives something like:

T1_g_float[ iblockIdx.x18{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13 )) ) ) + ( (nvfuser_index_t)(( i16 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19 )) ) ) + ( (nvfuser_index_t)(( i22 )) ) ) ) ), 128) ), 1) )}, iUS19{1}, ithreadIdx.x17{128} ] ca_pos( 2 ) produce_pos( 3 )
   = Set( T2_l_float[ iblockIdx.x24{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13 )) ) ) + ( (nvfuser_index_t)(( i16 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19 )) ) ) + ( (nvfuser_index_t)(( i22 )) ) ) ) ), 128) ), 1) )}, iUS25{1}, ithreadIdx.x23{128} ] ca_pos( 3 ), cache_op=Streaming )
} // %kernel_math

If we want to inline everything, e.g. print every set, the output looks too verbose:

T1_g_float[ iblockIdx.x18{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13
           = Set( i11 )) ) ) + ( (nvfuser_index_t)(( i16
           = Set( i11 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19
           = Set( i11 )) ) ) + ( (nvfuser_index_t)(( i22
           = Set( i11 )) ) ) ) ), 128) ), 1) )}, iUS19{1}, ithreadIdx.x17{128} ] ca_pos( 2 ) produce_pos( 3 )
   = Set( T2_l_float[ iblockIdx.x24{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13
           = Set( i11 )) ) ) + ( (nvfuser_index_t)(( i16
           = Set( i11 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19
           = Set( i11 )) ) ) + ( (nvfuser_index_t)(( i22
           = Set( i11 )) ) ) ) ), 128) ), 1) )}, iUS25{1}, ithreadIdx.x23{128} ] ca_pos( 3 ), cache_op=Streaming )

naoyam commented 1 week ago

Does this mean we have:

i13 = Set( i11);

If so, shouldn't we print i11 instead of i13? When we have, for example, ops like:

i1 = i2 + i3

Inlined printing would be i2 + i3 instead of i1.
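
To make the inlining rule concrete, here is a minimal sketch in standalone C++. It is a toy model, not nvFuser code: `Scalar`, `symbol`, `add`, and `inlineString` are hypothetical names introduced only for illustration, and the parenthesized output format is an assumption loosely following the dumps in this thread. A scalar with a defining expression is replaced by the inlined form of that expression; a scalar without one is printed by name.

```cpp
#include <cassert>
#include <memory>
#include <string>
#include <utility>

// Toy stand-in for a scalar value. This is NOT nvFuser's Val class.
struct Scalar {
  std::string name;            // e.g. "i1"
  std::shared_ptr<Scalar> lhs; // left operand of the defining add, if any
  std::shared_ptr<Scalar> rhs; // right operand of the defining add, if any

  bool hasDefinition() const { return lhs && rhs; }
};

using ScalarPtr = std::shared_ptr<Scalar>;

// A symbol with no defining expression (a leaf of the expression tree).
ScalarPtr symbol(const std::string& name) {
  return std::make_shared<Scalar>(Scalar{name, nullptr, nullptr});
}

// A scalar defined as lhs + rhs.
ScalarPtr add(const std::string& name, ScalarPtr lhs, ScalarPtr rhs) {
  return std::make_shared<Scalar>(
      Scalar{name, std::move(lhs), std::move(rhs)});
}

// Inline printing: recurse through definitions and print only leaf symbols.
std::string inlineString(const Scalar& v) {
  if (!v.hasDefinition()) {
    return v.name;
  }
  return "( " + inlineString(*v.lhs) + " + " + inlineString(*v.rhs) + " )";
}
```

With `i1` built as `add("i1", symbol("i2"), symbol("i3"))`, `inlineString(*i1)` yields `( i2 + i3 )`, and the name `i1` never appears in the output.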

liqiangxl commented 1 week ago

Printing the input is also fine. I selected the output because the fusion should already contain the definition of the original set. For example, the fusion has:

i13
= Set( i11, cache_op=Streaming )

The whole fusion:

%kernel_math {
f6 = (float)(7);
f9 = float(2.5) * f6;
i11 = (int64_t)(f9);
i13
   = Set( i11, cache_op=Streaming )
i29 = (nvfuser_index_t)(i13);
i16
   = Set( i11, cache_op=Streaming )
i31 = (nvfuser_index_t)(i16);
i19
   = Set( i11, cache_op=Streaming )
i33 = (nvfuser_index_t)(i19);
i22
   = Set( i11, cache_op=Streaming )
i35 = (nvfuser_index_t)(i22);
T2_l_float[ iblockIdx.x24{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13 )) ) ) + ( (nvfuser_index_t)(( i16 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19 )) ) ) + ( (nvfuser_index_t)(( i22 )) ) ) ) ), 128) ), 1) )}, iUS25{1}, ithreadIdx.x23{128} ] ca_pos( 3 )
   = pad( T0_g_float[ iS30{( ceilDiv(( ceilDiv(( 1 * ( i1 * i2 ) ), 128) ), 1) )}, iS31{1}, iS29{128} ], {0, 0, i29, i31, i33, i35} )
T1_g_float[ iblockIdx.x18{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13 )) ) ) + ( (nvfuser_index_t)(( i16 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19 )) ) ) + ( (nvfuser_index_t)(( i22 )) ) ) ) ), 128) ), 1) )}, iUS19{1}, ithreadIdx.x17{128} ] ca_pos( 2 ) produce_pos( 3 )
   = Set( T2_l_float[ iblockIdx.x24{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( i13 )) ) ) + ( (nvfuser_index_t)(( i16 )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( i19 )) ) ) + ( (nvfuser_index_t)(( i22 )) ) ) ) ), 128) ), 1) )}, iUS25{1}, ithreadIdx.x23{128} ] ca_pos( 3 ), cache_op=Streaming )
} // %kernel_math 

i1 = i2 + i3 is not a set, it is a binary add.

naoyam commented 1 week ago

For inlined printing, I don't think there's any reason to treat i1 = i2 and i1 = i2 + i3 differently. For a scalar Val, toInlineString() returns a string that only consists of symbols that have no defining expressions. I think that would make sense for a scalar val defined with set too.
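
A rough sketch of that behavior, again as a standalone toy model rather than nvFuser's actual implementation: `Value`, `OpKind`, the `make*` helpers, and `inlineString` are all hypothetical, and the output format is simplified compared with the real dumps. The point it illustrates is that a scalar defined by a set is printed as the inlined form of the set's input, while a set that produces a tensor is still rejected.

```cpp
#include <cassert>
#include <memory>
#include <stdexcept>
#include <string>
#include <utility>

// Toy IR for illustration only. These are NOT nvFuser's classes.
enum class OpKind { Leaf, Cast, Mul, Set };

struct Value;
using ValuePtr = std::shared_ptr<Value>;

struct Value {
  std::string name;          // "i13", "T2", or a literal such as "7"
  bool is_tensor = false;    // true for tensor values
  OpKind op = OpKind::Leaf;  // kind of the defining expression, if any
  std::string cast_type;     // target type, used only by Cast
  ValuePtr a;                // first input of the defining expression
  ValuePtr b;                // second input, used only by Mul
};

ValuePtr makeLeaf(std::string name, bool is_tensor = false) {
  auto v = std::make_shared<Value>();
  v->name = std::move(name);
  v->is_tensor = is_tensor;
  return v;
}

ValuePtr makeCast(std::string name, std::string type, ValuePtr in) {
  auto v = makeLeaf(std::move(name));
  v->op = OpKind::Cast;
  v->cast_type = std::move(type);
  v->a = std::move(in);
  return v;
}

ValuePtr makeMul(std::string name, ValuePtr lhs, ValuePtr rhs) {
  auto v = makeLeaf(std::move(name));
  v->op = OpKind::Mul;
  v->a = std::move(lhs);
  v->b = std::move(rhs);
  return v;
}

// A set copies its input, so the output is a tensor iff the input is.
ValuePtr makeSet(std::string name, ValuePtr in) {
  auto v = makeLeaf(std::move(name), in->is_tensor);
  v->op = OpKind::Set;
  v->a = std::move(in);
  return v;
}

std::string inlineString(const Value& v) {
  switch (v.op) {
    case OpKind::Leaf:
      return v.name;
    case OpKind::Cast:
      return "(" + v.cast_type + ")(" + inlineString(*v.a) + ")";
    case OpKind::Mul:
      return "( " + inlineString(*v.a) + " * " + inlineString(*v.b) + " )";
    case OpKind::Set:
      // Reject only when the set really produces a tensor.
      if (v.is_tensor) {
        throw std::runtime_error("Tensor op can not be printed inline");
      }
      // Scalar set: print the inlined input, like any other scalar op.
      return inlineString(*v.a);
  }
  return v.name;  // not reached; keeps compilers quiet
}
```

Building the chain from this thread (`f6 = (float)(7)`, `f9 = float(2.5) * f6`, `i11 = (int64_t)(f9)`, `i13 = Set(i11)`) with these helpers, `inlineString` returns the same string for `i13` as for `i11`, and neither name appears in it. Calling it on a set whose input is a tensor leaf throws.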

liqiangxl commented 1 week ago

Sure, we can do that. It gives us an output like:

T1_g_float[ iblockIdx.x18{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) ) ), 128) ), 1) )}, iUS19{1}, ithreadIdx.x17{128} ] ca_pos( 2 ) produce_pos( 3 )
   = Set( T2_l_float[ iblockIdx.x24{( ceilDiv(( ceilDiv(( 1 * ( ( ( i1 + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) * ( ( i2 + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) + ( (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )) ) ) ) ), 128) ), 1) )}, iUS25{1}, ithreadIdx.x23{128} ] ca_pos( 3 ), cache_op=Streaming )

liqiangxl commented 1 week ago

i13 is printed inline as (nvfuser_index_t)(( ( (int64_t)(( float(2.5) * ( (float)(7) ) )) ) )). Its definition is:

f6 = (float)(7);
f9 = float(2.5) * f6;
i11 = (int64_t)(f9);
i13
   = Set( i11, cache_op=Streaming )