wujingyue closed this 3 months ago.
FWIW, https://github.com/Lightning-AI/lightning-thunder/blob/bc3925be04e7ec58d9b24fb7ac55fbe007862a65/thunder/tests/opinfos.py#L6095-L6117 could be extended to test 3D inputs. Currently, it only tests 1D and 2D input shapes.
Agreed, though we do test 3D cases in our own tests: https://github.com/NVIDIA/Fuser/blob/a873bffa702b14b54ca03ca5aea7944691982093/tests/python/pytest_input_generators.py#L1515-L1544.
We likely need more tests, such as this issue's reproducer and other larger fusions.
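For illustration, a sample-shape generator that also covers 3D inputs might look like the sketch below. It is a hypothetical, framework-free stand-in (plain shape tuples, no torch), not the actual code in `opinfos.py` or `pytest_input_generators.py`; the function names and sizes are made up.

```python
from typing import Iterator, Optional, Tuple

Shape = Tuple[int, ...]


def linear_sample_shapes(
    in_features: int = 16, out_features: int = 8
) -> Iterator[Tuple[Shape, Shape, Optional[Shape]]]:
    """Yield hypothetical (input, weight, bias) shapes for linear, covering 1D, 2D and 3D inputs."""
    leading_dims = [(), (4,), (2, 4)]  # 1D, 2D and 3D inputs respectively
    for lead in leading_dims:
        for bias in (None, (out_features,)):
            yield lead + (in_features,), (out_features, in_features), bias


def linear_output_shape(input_shape: Shape, weight_shape: Shape) -> Shape:
    """linear contracts the last input dim with weight's last dim and keeps the leading dims."""
    if input_shape[-1] != weight_shape[-1]:
        raise ValueError("in_features mismatch")
    return input_shape[:-1] + (weight_shape[0],)


for inp, w, b in linear_sample_shapes():
    print(f"{len(inp)}D input {inp}, weight {w}, bias {b} -> {linear_output_shape(inp, w)}")
```

The 3D case is the one in this issue's reproducer: an input of shape `(16, 128, 1600)` with a `(1600, 1600)` weight keeps the two leading dims.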
This issue seems related to #1659.
The segment for the transpose scheduler:
```
g{(transpose)
inputs:
  T22_g[ iS61{16}, iS62{128}, iS63{1600} ] float
  T60_g[ iS175{16}, iS176{128}, iS253{1600}, rS254{i6} ] __bfloat
outputs:
  T68_g[ iS200{16}, iS201{128}, iS202{1600} ] float
T61_l[ iS179{16}, iS180{128}, iS181{1600} ]
   = rng_uniform_range({16, 128, 1600}, double(0), double(1), __bfloat);
(58)
T62_g[ iS182{16}, iS183{128}, iS184{1600} ]
   = __bfloat2float(T61_l[ iS179{16}, iS180{128}, iS181{1600} ]);
(59)
T63_g[ iS185{16}, iS186{128}, iS187{1600} ]
   = T62_g[ iS182{16}, iS183{128}, iS184{1600} ]
   < double(0.90000000000000002);
(60)
T65_g[ iS191{16}, iS192{128}, iS193{1600} ]
   = (float)(T63_g[ iS185{16}, iS186{128}, iS187{1600} ]);
(62)
T64_g[ iS188{16}, iS189{128}, iS255{1600} ]
   = __bfloat2float(T60_g[ iS175{16}, iS176{128}, iS253{1600}, rS254{i6} ]);
(61)
T66_g[ iS194{16}, iS195{128}, iS196{1600} ]
   = T64_g[ iS188{16}, iS189{128}, iS255{1600} ]
   * T65_g[ iS191{16}, iS192{128}, iS193{1600} ];
(63)
T67_g[ iS197{16}, iS198{128}, iS199{1600} ]
   = T66_g[ iS194{16}, iS195{128}, iS196{1600} ]
   * double(1.11111);
(64)
T68_g[ iS200{16}, iS201{128}, iS202{1600} ]
   = T22_g[ iS61{16}, iS62{128}, iS63{1600} ]
   + T67_g[ iS197{16}, iS198{128}, iS199{1600} ];
(65)
}
```
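Reading the transpose segment: T61/T62 draw a bf16 uniform in [0, 1) and upcast it, T63/T65 keep elements below 0.9, T64 upcasts T60 (the bf16 input that still carries the reduction domain `rS254{i6}`), and T66 to T68 scale by 1.11111 (≈ 1 / 0.9) and add T22. In other words, dropout with p = 0.1 followed by a residual add. A minimal element-wise sketch of that math on flat Python lists (the function name is hypothetical; this ignores bf16 rounding and nvFuser's RNG):

```python
import random
from typing import List, Optional, Sequence


def dropout_residual(
    residual: Sequence[float],
    x: Sequence[float],
    p: float = 0.1,
    uniforms: Optional[Sequence[float]] = None,
    seed: int = 0,
) -> List[float]:
    """Compute residual + x * (u < 1 - p) * (1 / (1 - p)) element by element.

    residual plays the role of T22, x of T64 (float copy of T60), u of T62.
    """
    keep_prob = 1.0 - p
    scale = 1.0 / keep_prob  # 1 / 0.9 ≈ 1.11111, the constant in T67
    if uniforms is None:
        rng = random.Random(seed)
        uniforms = [rng.random() for _ in x]
    out = []
    for r, v, u in zip(residual, x, uniforms):
        mask = 1.0 if u < keep_prob else 0.0  # T63 / T65
        out.append(r + v * mask * scale)  # T66, T67, T68
    return out


# The first element is kept (0.5 < 0.9) and rescaled; the second is dropped.
print(dropout_residual([1.0, 1.0], [2.0, 2.0], uniforms=[0.5, 0.95]))
```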
FYI @Priya2698, https://github.com/Lightning-AI/lightning-thunder/tree/wjy/bug2317 is the Thunder branch that reproduces the bug.
```
$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block_fwd[thunder] -s
------------------------------ Captured log call ------------------------------
ERROR nvfuser:__init__.py:205 An error occurred while executing nvFuser FusionDefinition 1.
If you believe this is a bug or need assistance, please file an issue at https://github.com/NVIDIA/Fuser/issues/new
Here's a script to reproduce the error:
```
```python
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id1(fd : FusionDefinition) -> None :
    T0 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T1 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T5 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T6 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[0])
    T7 = fd.define_tensor(shape=[-1, -1], contiguity=[True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1], contiguity=[True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.BFloat16, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.ops.permute(T9, dims=[0, 2, 1, 3])
    T11 = fd.ops.stride_order(T10, stride_order=[3, 2, 1, 0])
    S12 = fd.define_scalar(16, dtype=DataType.Int)
    S13 = fd.define_scalar(128, dtype=DataType.Int)
    S14 = fd.define_scalar(1600, dtype=DataType.Int)
    V15 = fd.define_vector([S12, S13, S14], dtype=DataType.Int)
    T16 = fd.ops.reshape(T11, new_shape=V15)
    T17 = fd.ops.linear(T16, T1, T0)
    S18 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S19 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S20 = fd.define_scalar(16, dtype=DataType.Int)
    S21 = fd.define_scalar(128, dtype=DataType.Int)
    S22 = fd.define_scalar(1600, dtype=DataType.Int)
    V23 = fd.define_vector([S20, S21, S22], dtype=DataType.Int)
    T24 = fd.ops.uniform(S18, S19, shape=V23, dtype=DataType.BFloat16)
    S25 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T26 = fd.ops.lt(T24, S25)
    T27 = fd.ops.cast(T17, dtype=DataType.Float)
    T28 = fd.ops.cast(T26, dtype=DataType.Float)
    T29 = fd.ops.mul(T27, T28)
    S30 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T31 = fd.ops.mul(T29, S30)
    T32 = fd.ops.cast(T8, dtype=DataType.Float)
    T33 = fd.ops.add(T32, T31)
    T34, T35 = fd.ops.var_mean(T33, dims=[2], correction=0, keepdim=False)
    S36 = fd.define_scalar(16, dtype=DataType.Int)
    S37 = fd.define_scalar(128, dtype=DataType.Int)
    S38 = fd.define_scalar(1, dtype=DataType.Int)
    V39 = fd.define_vector([S36, S37, S38], dtype=DataType.Int)
    T40 = fd.ops.broadcast_in_dim(T34, shape=V39, broadcast_dims=[0, 1])
    S41 = fd.define_scalar(16, dtype=DataType.Int)
    S42 = fd.define_scalar(128, dtype=DataType.Int)
    S43 = fd.define_scalar(1, dtype=DataType.Int)
    V44 = fd.define_vector([S41, S42, S43], dtype=DataType.Int)
    T45 = fd.ops.broadcast_in_dim(T35, shape=V44, broadcast_dims=[0, 1])
    S46 = fd.define_scalar(1.00000e-05, dtype=DataType.Double)
    T47 = fd.ops.add(T40, S46)
    T48 = fd.ops.rsqrt(T47)
    S49 = fd.define_scalar(16, dtype=DataType.Int)
    S50 = fd.define_scalar(128, dtype=DataType.Int)
    S51 = fd.define_scalar(1600, dtype=DataType.Int)
    V52 = fd.define_vector([S49, S50, S51], dtype=DataType.Int)
    T53 = fd.ops.broadcast_in_dim(T45, shape=V52, broadcast_dims=[0, 1, 2])
    T54 = fd.ops.sub(T33, T53)
    S55 = fd.define_scalar(16, dtype=DataType.Int)
    S56 = fd.define_scalar(128, dtype=DataType.Int)
    S57 = fd.define_scalar(1600, dtype=DataType.Int)
    V58 = fd.define_vector([S55, S56, S57], dtype=DataType.Int)
    T59 = fd.ops.broadcast_in_dim(T48, shape=V58, broadcast_dims=[0, 1, 2])
    T60 = fd.ops.mul(T54, T59)
    S61 = fd.define_scalar(16, dtype=DataType.Int)
    S62 = fd.define_scalar(128, dtype=DataType.Int)
    S63 = fd.define_scalar(1600, dtype=DataType.Int)
    V64 = fd.define_vector([S61, S62, S63], dtype=DataType.Int)
    T65 = fd.ops.broadcast_in_dim(T3, shape=V64, broadcast_dims=[2])
    T66 = fd.ops.cast(T65, dtype=DataType.Float)
    T67 = fd.ops.mul(T60, T66)
    S68 = fd.define_scalar(16, dtype=DataType.Int)
    S69 = fd.define_scalar(128, dtype=DataType.Int)
    S70 = fd.define_scalar(1600, dtype=DataType.Int)
    V71 = fd.define_vector([S68, S69, S70], dtype=DataType.Int)
    T72 = fd.ops.broadcast_in_dim(T2, shape=V71, broadcast_dims=[2])
    T73 = fd.ops.cast(T72, dtype=DataType.Float)
    T74 = fd.ops.add(T67, T73)
    T75 = fd.ops.cast(T74, dtype=DataType.BFloat16)
    T76 = fd.ops.linear(T75, T5, T4)
    T77 = fd.ops.cast(T76, dtype=DataType.Float)
    T78 = fd.ops.mul(T77, T77)
    T79 = fd.ops.mul(T78, T77)
    S80 = fd.define_scalar(0.500000, dtype=DataType.Double)
    T81 = fd.ops.mul(S80, T77)
    S82 = fd.define_scalar(0.0447150, dtype=DataType.Double)
    T83 = fd.ops.mul(S82, T79)
    T84 = fd.ops.add(T77, T83)
    S85 = fd.define_scalar(0.797885, dtype=DataType.Double)
    T86 = fd.ops.mul(S85, T84)
    T87 = fd.ops.tanh(T86)
    S88 = fd.define_scalar(1.00000, dtype=DataType.Double)
    T89 = fd.ops.add(S88, T87)
    T90 = fd.ops.mul(T81, T89)
    T91 = fd.ops.cast(T90, dtype=DataType.BFloat16)
    T92 = fd.ops.linear(T91, T7, T6)
    S93 = fd.define_scalar(0.00000, dtype=DataType.Double)
    S94 = fd.define_scalar(1.00000, dtype=DataType.Double)
    S95 = fd.define_scalar(16, dtype=DataType.Int)
    S96 = fd.define_scalar(128, dtype=DataType.Int)
    S97 = fd.define_scalar(1600, dtype=DataType.Int)
    V98 = fd.define_vector([S95, S96, S97], dtype=DataType.Int)
    T99 = fd.ops.uniform(S93, S94, shape=V98, dtype=DataType.BFloat16)
    S100 = fd.define_scalar(0.900000, dtype=DataType.Double)
    T101 = fd.ops.lt(T99, S100)
    T102 = fd.ops.cast(T92, dtype=DataType.Float)
    T103 = fd.ops.cast(T101, dtype=DataType.Float)
    T104 = fd.ops.mul(T102, T103)
    S105 = fd.define_scalar(1.11111, dtype=DataType.Double)
    T106 = fd.ops.mul(T104, S105)
    T107 = fd.ops.add(T33, T106)
    T108 = fd.ops.cast(T107, dtype=DataType.BFloat16)
    fd.add_output(T108)
with FusionDefinition() as fd:
    nvfuser_fusion_id1(fd)
inputs = [
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((2560000,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((6400,), dtype=torch.bfloat16, device='cuda:0').as_strided((6400,), (1,)),
    torch.randn((10240000,), dtype=torch.bfloat16, device='cuda:0').as_strided((6400, 1600), (1600, 1)),
    torch.randn((1600,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600,), (1,)),
    torch.randn((10240000,), dtype=torch.bfloat16, device='cuda:0').as_strided((1600, 6400), (6400, 1)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 128, 1600), (204800, 1600, 1)),
    torch.randn((3276800,), dtype=torch.bfloat16, device='cuda:0').as_strided((16, 25, 128, 64), (204800, 8192, 64, 1)),
]
fd.execute(inputs)
```

```
Traceback (most recent call last):
  File "/opt/pytorch/nvfuser/nvfuser/__init__.py", line 145, in execute
    result = self._execute(
RuntimeError: !detect_exception_in_thread_pool.load() INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/kernel_cache.cpp":1336, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Detected exception while compiling fusion segments in parallel. Error messages from all threads are printed below.

Error from segmentation group 5: Merging IterDomains requires that their iteration types match. Outer: iS284{( ceilDiv(1600, 32) )}, Inner: rS257{i6}
Exception raised from merge at /opt/pytorch/nvfuser/csrc/ir/nodes.cpp:2558 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits
Use NVFUSER_DISABLE=parallel_compile to simplify error message.
Exception raised from compileFusionParallel at /opt/pytorch/nvfuser/csrc/kernel_cache.cpp:1336 (most recent call first):
frame #0: nvfuser::nvfCheckFail(char const*, char const*, unsigned int, std::__cxx11::basic_string<char, std::char_traits
```
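To make the error message concrete, here is a toy Python model of the check that fires: segment group 5 tries to merge an Iteration domain (`iS284`) with the Reduction domain `rS257{i6}`. The classes and `merge` below are hypothetical illustrations, not nvFuser's C++ `IterDomain::merge`; the real check may treat some combinations (e.g. broadcast domains) differently, whereas this sketch simply rejects every type mismatch.

```python
from dataclasses import dataclass
from enum import Enum


class IterType(Enum):
    ITERATION = "Iteration"
    REDUCTION = "Reduction"


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical toy IterDomain: a printable name, a symbolic extent and a type."""

    name: str
    extent: str
    iter_type: IterType


def merge(outer: IterDomain, inner: IterDomain) -> IterDomain:
    """Fuse two domains into one of extent outer * inner, rejecting mixed types."""
    if outer.iter_type is not inner.iter_type:
        raise ValueError(
            "Merging IterDomains requires that their iteration types match. "
            f"Outer: {outer.name}{{{outer.extent}}}, Inner: {inner.name}{{{inner.extent}}}"
        )
    return IterDomain(
        name=f"merge({outer.name}, {inner.name})",
        extent=f"{outer.extent} * {inner.extent}",
        iter_type=outer.iter_type,
    )


outer = IterDomain("iS284", "ceilDiv(1600, 32)", IterType.ITERATION)
inner = IterDomain("rS257", "i6", IterType.REDUCTION)
try:
    merge(outer, inner)
except ValueError as err:
    print(err)
```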
Update: never mind. I should have run `test_nanogpt_block[inference-thunder]` instead. `forward-thunder` is the forward pass in training, which is different from the previous benchmark.
~Interestingly, I'm unable to reproduce the problem after https://github.com/Lightning-AI/lightning-thunder/commit/d1b016a58a48e5c6282622de488be8c9135dd821, authored by @IvanYashchuk .~
```
$ git log
commit 81425474dad41afa1f3100efea63faa8fd062a68 (HEAD -> wjy/bug2317)
Author: Jingyue Wu <wujingyue@gmail.com>
Date:   Thu May 9 20:39:36 2024 +0000

    Unconditionally enable linear and matmul and turn off nv_enable_bookend.

commit d1b016a58a48e5c6282622de488be8c9135dd821
Author: Ivan Yashchuk <IvanYashchuk@users.noreply.github.com>
Date:   Mon Jun 3 14:56:03 2024 +0300

    Update benchmarks/targets.py: inference/forward/backward parametrization (#498)

commit 3bd0100218055ad3713466d8fc03647928f7a289
Author: Masaki Kozuki <mkozuki@nvidia.com>
Date:   Mon Jun 3 18:19:57 2024 +0900

    Remove trace dump for debug from test_tensor_parallel.py (#509)
```
```
$ pytest thunder/benchmarks/targets.py -k test_nanogpt_block[forward-thunder] -s
============================= test session starts ==============================
platform linux -- Python 3.10.12, pytest-8.1.1, pluggy-1.5.0
Test order randomisation NOT enabled. Enable with --random-order or --random-order-bucket=<bucket_type>
benchmark: 4.0.0 (defaults: timer=torch.utils.benchmark.utils.timer.timer disable_gc=False min_rounds=5 min_time=0.000005 max_time=1.0 calibration_precision=10 warmup=True warmup_iterations=100000)
rootdir: /opt/pytorch/lightning-thunder
configfile: pyproject.toml
plugins: timestamper-0.0.10, xdist-3.5.0, random-order-1.1.1, cov-4.1.0, benchmark-4.0.0, hypothesis-6.100.0, timeout-2.2.0, anyio-4.3.0, shard-0.1.2
timeout: 900.0s
timeout method: signal
timeout func_only: False
collected 644 items / 643 deselected / 1 selected
Running 1 items in this shard

thunder/benchmarks/targets.py .

------------------------------------------------------ benchmark: 1 tests ------------------------------------------------------
Name (time in ms)                         Min       Max    Mean  StdDev  Median     IQR  Outliers       OPS  Rounds  Iterations
--------------------------------------------------------------------------------------------------------------------------------
test_nanogpt_block[forward-thunder]    1.2142  123.1338  1.5354  4.2465  1.3960  0.0048      1;77  651.3001     822           1
--------------------------------------------------------------------------------------------------------------------------------

Legend:
  Outliers: 1 Standard Deviation from Mean; 1.5 IQR (InterQuartile Range) from 1st Quartile and 3rd Quartile.
  OPS: Operations Per Second, computed as 1 / Mean
============== 1 passed, 643 deselected, 2376 warnings in 11.46s ===============
```
FYI @Priya2698, I synced https://github.com/Lightning-AI/lightning-thunder/tree/wjy/bug2317. You can still reproduce the same problem using `pytest thunder/benchmarks/targets.py -k test_nanogpt_block[inference-thunder] -s`.
This happens when we segment at a reduction output. In the consumer segment, the edge is converted to an input that still has a Reduction domain. Instead, I think we should filter out Reduction domains in `convertInputRFactorsToRoots` (and adjust the stride order accordingly).
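The proposed fix has two parts: drop the Reduction domains from the segment input, then renumber the stride order so it stays a dense permutation. A minimal Python sketch of that bookkeeping, assuming the `stride_order` convention used by the repro's `define_tensor` calls (entry `i` is the stride rank of dim `i`, 0 = innermost). `IterDomain` and `drop_reduction_dims` are hypothetical names for illustration, not nvFuser's C++ code, and the `[3, 2, 1, 0]` order for T60 is assumed, not read from the dump.

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass(frozen=True)
class IterDomain:
    """Hypothetical stand-in for an nvFuser IterDomain: a name plus a reduction flag."""

    name: str
    is_reduction: bool = False


def drop_reduction_dims(
    domains: List[IterDomain], stride_order: List[int]
) -> Tuple[List[IterDomain], List[int]]:
    """Drop Reduction domains and renumber stride_order so it stays a permutation."""
    if len(domains) != len(stride_order):
        raise ValueError("one stride_order entry is needed per domain")
    kept = [i for i, d in enumerate(domains) if not d.is_reduction]
    kept_orders = [stride_order[i] for i in kept]
    # Re-rank the surviving entries to 0..len(kept)-1, preserving their relative order.
    by_order = sorted(range(len(kept_orders)), key=lambda j: kept_orders[j])
    new_order = [0] * len(kept_orders)
    for rank, j in enumerate(by_order):
        new_order[j] = rank
    return [domains[i] for i in kept], new_order


# T60_g from the transpose segment: [iS175{16}, iS176{128}, iS253{1600}, rS254{i6}].
t60 = [
    IterDomain("iS175"),
    IterDomain("iS176"),
    IterDomain("iS253"),
    IterDomain("rS254", is_reduction=True),
]
domains, order = drop_reduction_dims(t60, [3, 2, 1, 0])  # assumed stride order
print([d.name for d in domains], order)  # ['iS175', 'iS176', 'iS253'] [2, 1, 0]
```

Re-ranking, rather than just deleting the entry, matters when the reduction dim is not the innermost one: e.g. dropping the middle dim of `[2, 1, 0]` must yield `[1, 0]`, not `[2, 0]`.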
Check out `wjy/linear` and run `NVFUSER_DISABLE=parallel_compile python repro.py`.

This happened after I rebased https://github.com/Lightning-AI/lightning-thunder/tree/wjy/sharded for https://github.com/NVIDIA/Fuser/issues/2199. I suspect 3D linear isn't handled as well as reshape+2D_linear+reshape.
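Mathematically, the two forms should agree: a 3D linear on `[B, S, K]` equals reshaping to `[B*S, K]`, running a 2D linear, and reshaping back to `[B, S, N]`. The pure-Python sketch below (hypothetical helper names, nested lists instead of tensors) only checks that equivalence of the definitions; it says nothing about how nvFuser schedules either form.

```python
from typing import List, Optional

Matrix = List[List[float]]


def linear_2d(x: Matrix, weight: Matrix, bias: Optional[List[float]] = None) -> Matrix:
    """y[m][n] = sum_k x[m][k] * weight[n][k] + bias[n], i.e. F.linear semantics."""
    out = []
    for row in x:
        out_row = []
        for n, w_row in enumerate(weight):
            acc = sum(a * b for a, b in zip(row, w_row))
            out_row.append(acc + (bias[n] if bias is not None else 0.0))
        out.append(out_row)
    return out


def linear_3d(x: List[Matrix], weight: Matrix, bias: Optional[List[float]] = None) -> List[Matrix]:
    """Direct 3D linear: apply the 2D rule to each [S, K] slice of a [B, S, K] input."""
    return [linear_2d(batch, weight, bias) for batch in x]


def linear_via_reshape(x: List[Matrix], weight: Matrix, bias: Optional[List[float]] = None) -> List[Matrix]:
    """Reshape [B, S, K] -> [B*S, K], run 2D linear, reshape [B*S, N] -> [B, S, N]."""
    b, s = len(x), len(x[0])
    flat = [row for batch in x for row in batch]
    y = linear_2d(flat, weight, bias)
    return [y[i * s:(i + 1) * s] for i in range(b)]


x = [[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[0.5, -1.0], [2.0, 0.0], [-3.0, 1.0]]]  # [2, 3, 2]
w = [[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]]  # [N=3, K=2]
bias = [0.0, 1.0, -1.0]
print(linear_3d(x, w, bias) == linear_via_reshape(x, w, bias))  # True
```

So if the 3D path fails or diverges while the reshape path works, the difference comes from how the fusion handles the 3D op, not from the op's definition.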