Can you try this again with TOT? It is not failing for me locally.
I'm running pretty close to ToT:
$ python3 -m pip freeze | grep -i nvfuse
nvfuser @ git+https://github.com/NVIDIA/Fuser.git@db95e48689ed640cff577c87ca3b0913c2d6989f
Are you perhaps using ampere? Now that I look, I also can't reproduce on my ampere-based workstation, only Hopper-based nodes. Apologies for not including originally; let me edit the bug to be clearer.
Ah, that's it. Thanks for the pointer! Will look into it more soon.
I have a smaller repro. This repro schedules with inner persistent scheduler without segmenting.
# CUDA devices:
# 0: NVIDIA H100 80GB HBM3
# torch version: 2.5.0a0+git8927fc2
# nvfuser version: 0.2.8+gitdd6886f
import torch
from nvfuser import FusionDefinition, DataType
def nvfuser_fusion_id0(fd : FusionDefinition) -> None :
    S0 = fd.define_scalar(None, dtype=DataType.Int)
    T1 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T2 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T3 = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float, is_cpu=False, stride_order=[0])
    T4 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T5 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T6 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[None, True, None, None], dtype=DataType.Float, is_cpu=False, stride_order=[3, 2, 1, 0])
    T7 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T8 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T9 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T10 = fd.define_tensor(shape=[-1, -1, -1, -1], contiguity=[True, True, True, True], dtype=DataType.Bool, is_cpu=False, stride_order=[3, 2, 1, 0])
    T11 = fd.define_tensor(shape=[-1, -1, -1, -1, 1], contiguity=[True, True, True, True, None], dtype=DataType.Float, is_cpu=False, stride_order=[4, 3, 2, 1, 0])
    T12 = fd.ops.sum(T11, dims=[4], keepdim=False, dtype=DataType.Null)
    T13 = fd.ops.set(T12)
    T14 = fd.ops.set(T12)
    T15 = fd.ops.sum(T14, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    S16 = fd.define_scalar(1, dtype=DataType.Int)
    S17 = fd.define_scalar(288, dtype=DataType.Int)
    S18 = fd.define_scalar(1, dtype=DataType.Int)
    S19 = fd.define_scalar(1, dtype=DataType.Int)
    V20 = fd.define_vector([S16, S17, S18, S19], dtype=DataType.Int)
    T21 = fd.ops.broadcast_in_dim(T15, shape=V20, broadcast_dims=[1])
    T22 = fd.ops.set(T12)
    T23 = fd.ops.sum(T22, dims=[0, 2, 3], keepdim=False, dtype=DataType.Null)
    T24 = fd.ops.broadcast_in_dim(T23, shape=V20, broadcast_dims=[1])
    S25 = fd.define_scalar(288, dtype=DataType.Int)
    V26 = fd.define_vector([S25], dtype=DataType.Int)
    T27 = fd.ops.reshape(T24, new_shape=V26)
    S28 = fd.define_scalar(288, dtype=DataType.Int)
    V29 = fd.define_vector([S28], dtype=DataType.Int)
    T30 = fd.ops.reshape(T21, new_shape=V29)
    S31 = fd.define_scalar(-0.500000, dtype=DataType.Double)
    T32 = fd.ops.mul(S31, T30)
    S33 = fd.define_scalar(3.00000, dtype=DataType.Double)
    T34 = fd.ops.pow(T3, S33)
    T35 = fd.ops.mul(T32, T34)
    T36 = fd.ops.broadcast_in_dim(T27, shape=V20, broadcast_dims=[1])
    S37 = fd.define_scalar(2, dtype=DataType.Int)
    S38 = fd.define_scalar(288, dtype=DataType.Int)
    S39 = fd.define_scalar(120, dtype=DataType.Int)
    S40 = fd.define_scalar(160, dtype=DataType.Int)
    V41 = fd.define_vector([S37, S38, S39, S40], dtype=DataType.Int)
    T42 = fd.ops.broadcast_in_dim(T36, shape=V41, broadcast_dims=[0, 1, 2, 3])
    S43 = fd.define_scalar(2.60417e-05, dtype=DataType.Double)
    T44 = fd.ops.mul(S43, T42)
    T45 = fd.ops.broadcast_in_dim(T35, shape=V20, broadcast_dims=[1])
    T46 = fd.ops.broadcast_in_dim(T45, shape=V41, broadcast_dims=[0, 1, 2, 3])
    T47 = fd.ops.broadcast_in_dim(T2, shape=V20, broadcast_dims=[1])
    S48 = fd.define_scalar(2.00000, dtype=DataType.Double)
    T49 = fd.ops.mul(S48, T46)
    T50 = fd.ops.sub(T1, T47)
    T51 = fd.ops.mul(T49, T50)
    S52 = fd.ops.cast(S0, dtype=DataType.Double)
    S53 = fd.ops.reciprocal(S52)
    T54 = fd.ops.mul(T51, S53)
    T55 = fd.ops.add(T44, T54)
    T56 = fd.ops.add(T13, T55)
    T57 = fd.ops.cast(T56, dtype=DataType.Half)
    fd.add_output(T57)
with FusionDefinition() as fd:
    nvfuser_fusion_id0(fd)
inputs = [
    38400,
    torch.randn((11059200,), dtype=torch.float32, device='cuda:0').as_strided((2, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((288,), (1,)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((288,), (1,)),
    torch.randn((11059200,), dtype=torch.float32, device='cuda:0').as_strided((2, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((2, 288, 120, 160), (0, 1, 0, 0)),
    torch.randn((288,), dtype=torch.float32, device='cuda:0').as_strided((2, 288, 120, 160), (0, 1, 0, 0)),
    torch.randint(0, 2, (11059200,), dtype=torch.bool, device='cuda:0').as_strided((2, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randint(0, 2, (11059200,), dtype=torch.bool, device='cuda:0').as_strided((2, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randint(0, 2, (11059200,), dtype=torch.bool, device='cuda:0').as_strided((2, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randint(0, 2, (11059200,), dtype=torch.bool, device='cuda:0').as_strided((2, 288, 120, 160), (5529600, 19200, 160, 1)),
    torch.randn((11059200,), dtype=torch.float32, device='cuda:0').as_strided((2, 288, 120, 160, 1), (5529600, 19200, 160, 1, 1)),
]
fd.execute(inputs)
The scheduler params on H100 are:
===== Persistent Kernel Properties ========
inner_most_dimension_numel: 19200
total_reduction_numel: 38400
total_iteration_numel: 288
max_persistent_buffer_size: 153600
n_tensor_inputs: 2
max_input_dtype_size: 4
max allowed vectorize_factor: 4
project_persistent_buffers: 0
===== Reduction Parameters ========
Tag: Shared Memory Inner Persistent Heuristic.
Red On Fastest Dim
Persistent Kernel
Batches per block: 10
Iteration Domain: blockIdx.x /
Inner Reduction Domain: cross block - threadIdx.x / pad to warp / persistent batch - 10 / vectorize / factor 4
Launch Parameters: BlockDim.x = -1, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = int, maxrregcount = 255, enable_magic_zero = 1, enable_ptxas_verbose = 0
====================================
===== Persistent Kernel Properties ========
inner_most_dimension_numel: 19200
total_reduction_numel: 38400
total_iteration_numel: 288
max_persistent_buffer_size: 153600
n_tensor_inputs: 2
max_input_dtype_size: 4
max allowed vectorize_factor: 4
project_persistent_buffers: 0
===== Reduction Parameters ========
Tag: Shared Memory Inner Persistent Heuristic.
Red On Fastest Dim
Persistent Kernel
Batches per block: 10
Iteration Domain: blockIdx.x /
Inner Reduction Domain: cross block - threadIdx.x / pad to warp / persistent batch - 10 / vectorize / factor 4
Launch Parameters: BlockDim.x = -1, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = int, maxrregcount = 255, enable_magic_zero = 1, enable_ptxas_verbose = 0
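For reference, these numbers line up with the repro's (2, 288, 120, 160) float32 input reduced over dims [0, 2, 3]; the arithmetic below is mine, not part of the scheduler dump:

# How the printed heuristic properties follow from the repro's shapes.
inner_most_dimension_numel = 120 * 160                  # 19200
total_reduction_numel = 2 * 120 * 160                    # 38400 (dims 0, 2, 3)
total_iteration_numel = 288                              # the kept channel dim
max_persistent_buffer_size = total_reduction_numel * 4   # float32 bytes -> 153600

print(inner_most_dimension_numel, total_reduction_numel,
      total_iteration_numel, max_persistent_buffer_size)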
cc @liqiangxl in case anything pops out at you about this fusion
BTW I just checked and the above smaller repro also errors on A100 with these params:
===== Persistent Kernel Properties ========
inner_most_dimension_numel: 19200
total_reduction_numel: 38400
total_iteration_numel: 288
max_persistent_buffer_size: 153600
n_tensor_inputs: 2
max_input_dtype_size: 4
max allowed vectorize_factor: 4
project_persistent_buffers: 0
===== Reduction Parameters ========
Tag: Shared Memory Inner Persistent Heuristic.
Red On Fastest Dim
Persistent Kernel
Batches per block: 10
Iteration Domain: blockIdx.x /
Inner Reduction Domain: cross block - threadIdx.x / pad to warp / persistent batch - 10 / vectorize / factor 4
Launch Parameters: BlockDim.x = -1, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = int, maxrregcount = 255, enable_magic_zero = 1, enable_ptxas_verbose = 0
====================================
===== Persistent Kernel Properties ========
inner_most_dimension_numel: 19200
total_reduction_numel: 38400
total_iteration_numel: 288
max_persistent_buffer_size: 153600
n_tensor_inputs: 2
max_input_dtype_size: 4
max allowed vectorize_factor: 4
project_persistent_buffers: 0
===== Reduction Parameters ========
Tag: Shared Memory Inner Persistent Heuristic.
Red On Fastest Dim
Persistent Kernel
Batches per block: 10
Iteration Domain: blockIdx.x /
Inner Reduction Domain: cross block - threadIdx.x / pad to warp / persistent batch - 10 / vectorize / factor 4
Launch Parameters: BlockDim.x = -1, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
Compile Parameters: index_type = int, maxrregcount = 255, enable_magic_zero = 1, enable_ptxas_verbose = 0
====================================
Now that I look, I also can't reproduce on my ampere-based workstation, only Hopper-based nodes.
Ah, I think this is maybe a geforce vs tesla type of thing. Maybe the additional smem or higher SM counts could explain H100/A100 having the issue vs my 3090Ti and your workstation device not hitting the error.
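If it helps narrow that down, a quick way to compare the devices involved is a plain torch query (nothing nvFuser-specific; the attributes below are standard torch.cuda device properties):

import torch

# Print compute capability and SM count of the local GPU; H100 (sm_90) and
# A100 (sm_80) datacenter parts have many more SMs and more shared memory per
# SM than a GeForce 3090 Ti (sm_86), which could steer the heuristic differently.
props = torch.cuda.get_device_properties(0)
print(props.name, f"sm_{props.major}{props.minor}", "SMs:", props.multi_processor_count)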
CC @naoyam as the error is happening in loop promotion.
The fusion (Jacob's repro) generates a 3D inner persistent kernel, then uses shared memory persistence, which is not supported yet. I added a WAR fix to disable the use of shared memory for 3D inner persistent kernels. Will add this support and enable it later. Not sure why the error happens in loop promotion.
The original fusion (Tom's repro) works fine with #2754 on both A100 and H100.
Thanks for the fast help, @liqiangxl! I can confirm that your #2754 fixes this issue. 🎉 !
However, the larger model (that this reproducer came from) dies a second later with the very similar error message Unsupported loop structure. Two loops are mapped together. bS201{1} and bS197{1}. See attached next.py. Do you want to look at the next issue as part of #2754, or would you rather get #2754 in and I'll file a new issue for the subsequent loop structure issue?
Thanks for double checking. #2754 fixes a real issue so we need to get it in, but it is not a fix for #2685. For this loop structure issue, we need more work, and I don't think you need to create a new issue.
Ack, understood. Thanks for your continued efforts! I'll look forward to the loop structure fix.
Hi @tfogal, would you mind doing another test of #2759? It is still a WAR fix. If you are busy, I can help with the test.
I'm very excited so I had already started it running beforehand :-)
#2759 does indeed fix the issue. The larger program now dies with:
RuntimeError: frontier.size() == logical.size() INTERNAL ASSERT FAILED at "/tmp/pip-req-build-uqgrzxrb/csrc/executor.cpp":901, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues.
Exception raised from transformOutputFromAllocationToLogical at /tmp/pip-req-build-uqgrzxrb/csrc/executor.cpp:901 (most recent call first):
which sounds like a separate problem. I can file a new issue for that one.
Thanks for the check. Let me know when the new issue is created and I'll check whether it is related to #2759 or something else.
Thanks! I filed #2760.
@naoyam is trying to root cause this issue. The WAR doesn't seem sufficient to me given we don't know what the actual issue is. @tfogal, if the WARs provided help you make forward progress, great. @liqiangxl, please don't merge them in unless we know the root cause. It also doesn't seem like your WAR would last very long if we continued to try a failing case, given the line I commented on.
I have a smaller repro. This repro schedules with inner persistent scheduler without segmenting.
Here're all the tensors after scheduling:
T47_l[ iblockIdx.x243{288}, ithreadIdx.x296{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS293{10}, iUS295{1}, iV292{4}, bS246{1} ] ca_pos( 4 )
T11_s[ iblockIdx.x161{288}, ithreadIdx.x288{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS285{10}, iUS287{1}, iS284{4} ] ca_pos( 1 ) produce_pos( 4 )
T12_l[ iblockIdx.x165{288}, ithreadIdx.x328{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS325{10}, iUS327{1}, iS324{4} ] ca_pos( 5 ) produce_pos( 1 )
T17_l[ iblockIdx.x169{288}, ithreadIdx.x280{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS277{10}, iUS279{1}, iS276{4} ] ca_pos( 5 ) produce_pos( 1 )
T49_l[ iblockIdx.x260{288}, ithreadIdx.x270{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}rf_p, rS267{10}rf, rUS269{1}rf, rS266{4}rf ] ca_pos( 2 ) produce_pos( 5 )
T18_l[ iblockIdx.x272{288}, rthreadIdx.x271{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p ] ca_pos( 1 ) produce_pos( 2 )
T19_l[ iblockIdx.x176{288}, bthreadIdx.x488{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS485{10}, bUS487{1}, bS484{4} ] ca_pos( 1 ) produce_pos( 1 )
T20_l[ iblockIdx.x76{288}, bthreadIdx.x480{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS477{10}, bUS479{1}, bS476{4} ] ca_pos( 1 ) produce_pos( 1 )
T27_l[ iblockIdx.x89{288}, bthreadIdx.x472{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS469{10}, bUS471{1}, bS468{4} ] ca_pos( 1 ) produce_pos( 1 )
T28_l[ iblockIdx.x93{288}, bthreadIdx.x464{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS461{10}, bUS463{1}, bS460{4} ] ca_pos( 1 ) produce_pos( 1 )
T29_l[ iblockIdx.x97{288}, bthreadIdx.x456{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) ) ex ( ceilDiv(( ceilDiv(( ceilDiv(( 2 * ( 120 * 160 ) ), 4) ), 10) ), 1) )}_p, bS453{10}, bUS455{1}, bS452{4} ] ca_pos( 1 ) produce_pos( 1 ) = expand( T28_l[ iblockIdx.x93{288}, bthreadIdx.x464{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS461{10}, bUS463{1}, bS460{4} ] ca_pos( 1 ) produce_pos( 1 ), {2, 288, 120, 160} )
T30_l[ iblockIdx.x101{288}, bthreadIdx.x448{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) ) ex ( ceilDiv(( ceilDiv(( ceilDiv(( 2 * ( 120 * 160 ) ), 4) ), 10) ), 1) )}_p, bS445{10}, bUS447{1}, bS444{4} ] ca_pos( 1 ) produce_pos( 1 )
T13_l[ iblockIdx.x178{288}, ithreadIdx.x312{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS309{10}, iUS311{1}, iS308{4} ] ca_pos( 5 ) produce_pos( 1 )
T50_l[ iblockIdx.x506{288}, ithreadIdx.x516{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}rf_p, rS513{10}rf, rUS515{1}rf, rS512{4}rf ] ca_pos( 2 ) produce_pos( 5 )
T14_l[ iblockIdx.x518{288}, rthreadIdx.x517{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p ] ca_pos( 1 ) produce_pos( 2 )
T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] produce_pos( 2 )
T16_l[ iblockIdx.x60{288}, bS59{1}, bS61{1}, bS62{1} ] ca_pos( 1 )
T22_l[ iblockIdx.x80{288} ] ca_pos( 1 ) produce_pos( 1 )
T23_l[ iblockIdx.x81{288} ] ca_pos( 1 ) produce_pos( 1 )
T46_l[ iblockIdx.x241{288} ] ca_pos( 1 )
T24_l[ iblockIdx.x187{288} ] ca_pos( 1 ) produce_pos( 1 )
T25_l[ iblockIdx.x83{288} ] ca_pos( 1 ) produce_pos( 1 )
T31_l[ iblockIdx.x105{288}, bthreadIdx.x440{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS437{10}, bUS439{1}, bS436{4} ] ca_pos( 1 ) produce_pos( 1 )
T32_l[ iblockIdx.x109{288}, bthreadIdx.x432{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS429{10}, bUS431{1}, bS428{4} ] ca_pos( 1 ) produce_pos( 1 )
T33_l[ iblockIdx.x113{288}, bthreadIdx.x424{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS421{10}, bUS423{1}, bS420{4} ] ca_pos( 1 ) produce_pos( 1 )
T34_l[ iblockIdx.x117{288}, bthreadIdx.x416{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) ) ex ( ceilDiv(( ceilDiv(( ceilDiv(( 2 * ( 120 * 160 ) ), 4) ), 10) ), 1) )}_p, bS413{10}, bUS415{1}, bS412{4} ] ca_pos( 1 ) produce_pos( 1 ) = expand( T33_l[ iblockIdx.x113{288}, bthreadIdx.x424{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS421{10}, bUS423{1}, bS420{4} ] ca_pos( 1 ) produce_pos( 1 ), {2, 288, 120, 160} )
T37_l[ iblockIdx.x129{288}, bthreadIdx.x408{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) ) ex ( ceilDiv(( ceilDiv(( ceilDiv(( 2 * ( 120 * 160 ) ), 4) ), 10) ), 1) )}_p, bS405{10}, bUS407{1}, bS404{4} ] ca_pos( 1 ) produce_pos( 1 )
T44_l[ iblockIdx.x237{288}, ithreadIdx.x392{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS389{10}, iUS391{1}, iV388{4} ] ca_pos( 4 )
T45_l[ iblockIdx.x240{288} ] ca_pos( 1 )
T35_l[ iblockIdx.x190{288}, bthreadIdx.x384{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS381{10}, bUS383{1}, bS380{4} ] ca_pos( 1 ) produce_pos( 1 )
T36_l[ iblockIdx.x125{288}, bthreadIdx.x376{( ceilDiv(( ceilDiv(( ceilDiv(( 1 * ( 1 * 1 ) ), 4) ), 10) ), 1) )}_p, bS373{10}, bUS375{1}, bS372{4} ] ca_pos( 1 ) produce_pos( 1 )
T38_l[ iblockIdx.x133{288}, ithreadIdx.x368{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS365{10}, iUS367{1}, iS364{4} ] ca_pos( 5 ) produce_pos( 4 )
T39_l[ iblockIdx.x137{288}, ithreadIdx.x360{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS357{10}, iUS359{1}, iS356{4} ] ca_pos( 5 ) produce_pos( 5 )
d163 = (double)(i0);
d165 = reciprocal(d163);
T40_l[ iblockIdx.x141{288}, ithreadIdx.x352{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS349{10}, iUS351{1}, iS348{4} ] ca_pos( 5 ) produce_pos( 5 )
T41_l[ iblockIdx.x145{288}, ithreadIdx.x344{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS341{10}, iUS343{1}, iS340{4} ] ca_pos( 5 ) produce_pos( 5 )
T42_l[ iblockIdx.x149{288}, ithreadIdx.x336{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS333{10}, iUS335{1}, iS332{4} ] ca_pos( 5 ) produce_pos( 5 )
T48_l[ iblockIdx.x153{288}, ithreadIdx.x496{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS493{10}, iUS495{1}, iS492{4} ] ca_pos( 4 ) produce_pos( 5 )
T43_g[ iblockIdx.x248{288}, ithreadIdx.x504{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p, iS501{10}, iUS503{1}, iV500{4} ] ca_pos( 4 ) produce_pos( 4 )
All tensors appear to be scheduled consistently, except for T15, which has a broadcast domain at the outermost position: T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] produce_pos( 2 ). This results in the self-mapping validation error. See #2762 for a simplified repro. As mentioned there, I believe it's a false alarm, but just removing the validation doesn't seem to work. Also, the repro seems to work fine if the new indexer is used.
I'm not sure why T15 is scheduled that way. If it's ordered as [ iblockIdx.x185{288}, bS55{1}, bS57{1}, bS58{1} ], I believe even the legacy indexer should work.
@liqiangxl, is it because of #2754?
Suppose it's indeed due to #2754, disabling the scheduler seems to be the first thing we should do. Although it should not result in an error, since it doesn't fail with the new indexer, I don't think it's worthwhile to fix the legacy indexer.
Also, as a clarification, the original error doesn't have anything specific to loop promotion. The problem is self mapping, which in this case should be a false alarm.
Actually, the smem support doesn't seem to be directly related. #2754 still results in generating a similar scheduled fusion pattern with Tom's next.py:
T50_l[ bS197{1}, iblockIdx.x198{288}, bS199{1}, bS200{1} ] produce_pos( 2 )
whereas the rest of the tensors look like:
T51_l[ iblockIdx.x202{288}, bS201{1}, bS203{1}, bS204{1} ] ca_pos( 1 )
This seems to indicate there's indeed some problem with the inner persistent scheduler, regardless of whether smem is used.
@liqiangxl What exactly did you mean by smem persistent not being supported yet?
(1) For 3D reduction we use the following schedule, which is not set in smem persistent.
if (rparams->schedule_3D) {
rparams->batches_per_block_outer_reduction =
batches_per_block_outer_reduction;
rparams->block_dim_outer_reduction = ParallelType::TIDz;
rparams->cross_block_outer_reduction = true;
rparams->unroll_factor_outer_reduction = outer_reduction_unroll_factor;
}
(2) You are right, this issue is not related to smem persistent; disabling it fixes the original issue only because it changed the segmentation results and bypassed the pattern which causes the error.
(3) Further investigation found that the inconsistent schedule of T15 comes from propagateTransformation. After it, T15 is changed to T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] and we have:
T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] produce_pos( 2 )
= broadcast( T14_l[ iblockIdx.x518{288}, rthreadIdx.x517{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p ]
T16_l[ iblockIdx.x60{288}, bS59{1}, bS61{1}, bS62{1} ] ca_pos( 1 )
= Set( T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] produce_pos( 2 ), cache_op=Streaming )
T15 after propagateTransformation
T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] produce_pos( 2 )
logical domain : (bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1})
allocation domain : (iblockIdx.x185{288}, bS55{1}, bS57{1}, bS58{1})
contiguity: t n n n
loop domain : (bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1})
During propagateTransformation
TransformPropagator::propagateP2C
from: T14_l[ iS182{288}, rS320{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}, rS317{10}, rS319{1}, rS316{4} ] @ 5
to: T15_l[ bS55{1}, iS185{288}, bS57{1}, bS58{1} ]
replay skipped. result position: 2
(1) For 3D reduction we use the following schedule, which is not set in smem persistent.
So, does that mean we are currently generating invalid heuristics? I'm not sure what consequence it would have, but it sounds like that needs to be addressed quickly.
(3) Further investigation found that the inconsistent schedule of T15 comes from propagateTransformation.
I'll look into it.
Seems due to getMatchedLeafPosWithoutReplayCasP; see https://github.com/NVIDIA/Fuser/pull/2764.
Here's what I believe is the root cause of the issue. (Thanks @zasdfgbnm for the discussion on transform propagation.)
This is an inner normalization fusion, but in addition to the usual normalization pattern, there's also a broadcast domain that's just squeezed without concretization. As usual, we use the reduction tensor as the reference tensor, but it knows nothing about the squeezed, non-concretized broadcast domain. While that may not matter in some cases, in the case of the original repro the non-concretized broadcast domain results in the inconsistent ordering seen below:
T14_l[ iblockIdx.x518{288}, rthreadIdx.x517{( ceilDiv(( ceilDiv(( ceilDiv(( i1 * ( i3 * i4 ) ), 4) ), 10) ), 1) )}_p ] ca_pos( 1 ) produce_pos( 2 )
T15_l[ bS55{1}, iblockIdx.x185{288}, bS57{1}, bS58{1} ] produce_pos( 2 )
T16_l[ iblockIdx.x60{288}, bS59{1}, bS61{1}, bS62{1} ] ca_pos( 1 )
If T15 were ordered as [ iblockIdx.x185{288}, bS55{1}, bS57{1}, bS58{1} ], the inner normalization scheduler should just be able to schedule it.
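For what it's worth, the repro contains exactly this kind of pattern: T15 ([288]) is broadcast to [1, 288, 1, 1] (T21) and then reshaped straight back to [288] (T30), so those broadcast domains are squeezed away without ever being concretized. A stripped-down sketch of just that pattern, reusing the frontend ops from the repro (it is only meant to show the pattern, not to be a standalone failing case):

import torch
from nvfuser import FusionDefinition, DataType

# Broadcast a [288] tensor to [1, 288, 1, 1] and immediately reshape it back
# to [288]: the three size-1 broadcast domains are squeezed away and never get
# concretized against a real iteration domain.
with FusionDefinition() as fd:
    t = fd.define_tensor(shape=[-1], contiguity=[True], dtype=DataType.Float,
                         is_cpu=False, stride_order=[0])
    s1a = fd.define_scalar(1, dtype=DataType.Int)
    s288a = fd.define_scalar(288, dtype=DataType.Int)
    s1b = fd.define_scalar(1, dtype=DataType.Int)
    s1c = fd.define_scalar(1, dtype=DataType.Int)
    v4 = fd.define_vector([s1a, s288a, s1b, s1c], dtype=DataType.Int)
    tb = fd.ops.broadcast_in_dim(t, shape=v4, broadcast_dims=[1])  # [1, 288, 1, 1]
    s288b = fd.define_scalar(288, dtype=DataType.Int)
    v1 = fd.define_vector([s288b], dtype=DataType.Int)
    tr = fd.ops.reshape(tb, new_shape=v1)                          # back to [288]
    fd.add_output(tr)

fd.execute([torch.randn(288, device='cuda:0')])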
I thought #2765 could work around the issue, but, while it's sufficient for this particular fusion, we could come up with other fusions that might not work as we would like. In fact, since such broadcast domains are not represented by the reference, it doesn't seem well defined how they should be transformed, and I'm not sure if there's a generic algorithm to find the right transformation.
Instead of trying to transform those non-concretized broadcast domains, I think it makes more sense to just ignore or remove them from the fusion. It would solve the propagation issue without changing the semantics of the original user fusion. To detect such safe-to-remove broadcast domains, I think we can use the Permissive graph: if a Permissive group has no non-broadcast domain, the group should only contain non-concretized broadcast domains.
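As a toy sketch of that detection idea (plain Python over made-up domain names, not the actual nvFuser IdModel/ValGraph API): a group is safe to drop exactly when every domain in it is a broadcast.

# Hypothetical model: each Permissive-map group is a set of iteration-domain
# names, with broadcast domains spelled with a leading 'b' as in the dumps above.
# A group with no non-broadcast member corresponds to broadcasts that are never
# concretized, so removing them doesn't change the fusion's semantics.
def non_concretized_groups(groups, is_broadcast):
    return [g for g in groups if all(is_broadcast(d) for d in g)]

groups = [
    {"bX0", "bX1"},   # broadcasts only: never concretized anywhere
    {"iY0", "bY1"},   # mapped with a real iteration domain: concretized
]
print(non_concretized_groups(groups, lambda d: d.startswith("b")))
# -> [{'bX0', 'bX1'}]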
As long as #2765 doesn't result in other failures, I believe it's a strict improvement, so we could move forward with it as a short-term workaround.
We should consider a pre-segmentation pass to remove non-concretized broadcast domains.
https://github.com/NVIDIA/Fuser/pull/2765 seems like a great generic improvement, whether or not it does justice here.
Discussed in a Slack thread with @naoyam: he's trying out another approach which we're hopeful will work, which is to simply move any non-concretized broadcasts to the innermost dimension in scheduling. These dimensions can be found easily with the permissive ID graph. If we move them to the innermost dimension in pre-scheduling and the reference does not have these dimensions represented, then they won't get moved during propagation.
@naoyam is trying this out, and it should be straightforward relative to the other approaches.
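A toy sketch of that reordering (plain Python over the domain names from the T15 dump above, not nvFuser code; push_non_concretized_inner is a made-up helper):

# Reorder a loop domain, given as an ordered list of domain names, so that
# non-concretized broadcasts end up innermost (rightmost). If the reference
# tensor never represents these domains, propagation then leaves them alone.
def push_non_concretized_inner(loop_domain, is_non_concretized):
    kept = [d for d in loop_domain if not is_non_concretized(d)]
    moved = [d for d in loop_domain if is_non_concretized(d)]
    return kept + moved

# T15's problematic loop domain from the dump above.
t15 = ["bS55", "iS185", "bS57", "bS58"]
print(push_non_concretized_inner(t15, lambda d: d.startswith("b")))
# -> ['iS185', 'bS55', 'bS57', 'bS58'], the ordering suggested earlier in the thread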
The given program crashes on Hopper with the error:
Full program
Interestingly the program runs just fine on Ampere.