NVIDIA / Fuser

A Fusion Code Generator for NVIDIA GPUs (commonly known as "nvFuser")

Make it more explicit about the rfactor flag propagation #3323

Closed naoyam closed 3 weeks ago

naoyam commented 3 weeks ago

IterDomain::split has a default value of false for the rfactor flag, which means that if an rfactor iter domain is split, it generates non-rfactor output domains. I'm not sure that's the right default behavior, so this PR removes the default value and makes the behavior more explicit.

Nothing should be affected by this change, as long as all the tests and benchmarks don't fail.
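
To make the concern above concrete, here is a minimal sketch of the idea, not the actual nvFuser code. `ToyIterDomain` and `toySplit` are hypothetical stand-ins invented for illustration; only the notion of a split that takes an rfactor flag without a default comes from the description above. The extent arithmetic (outer extent as a ceiling division) is likewise an assumption of the sketch.

```cpp
#include <cassert>
#include <cstdint>
#include <utility>

// Hypothetical toy model of an iteration domain; NOT nvfuser::IterDomain.
struct ToyIterDomain {
  int64_t extent;
  bool is_rfactor;
};

// Splits `in` into an (outer, inner) pair by `factor`.
// `rfactor_domain` deliberately has no default value, so each call site
// has to state whether the outputs are rfactor domains.
std::pair<ToyIterDomain, ToyIterDomain> toySplit(
    const ToyIterDomain& in, int64_t factor, bool rfactor_domain) {
  assert(factor > 0);
  // Assumed for this sketch: outer extent is ceilDiv(extent, factor).
  const int64_t outer_extent = (in.extent + factor - 1) / factor;
  return {ToyIterDomain{outer_extent, rfactor_domain},
          ToyIterDomain{factor, rfactor_domain}};
}
```

With a defaulted parameter, a call such as `toySplit(id, 4)` would silently produce non-rfactor outputs even when `id.is_rfactor` is true. Without the default, the caller has to choose, for example `toySplit(id, 4, id.is_rfactor)` to carry the input's flag over to both outputs.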

naoyam commented 3 weeks ago

!build

naoyam commented 3 weeks ago

Failed cases:

NvFuserScheduler_Matmul_Manual/nvfuser_splitk_TT/M:1024/N:1024/K:50304/warps:8/stages:3/splitk_factor:2/smem_epilogue:0/manual_time for 1
00:00:10 terminate called after throwing an instance of 'nvfuser::nvfError'
00:00:10   what():   INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir/internal_base_nodes.h":147, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Unexpected input iter domain. Input should not be an rfactor iter domain: rS196{32}rf
00:00:10 Exception raised from split at /opt/pytorch/nvfuser/csrc/ir/internal_base_nodes.h:147 (most recent call first):

0:00:06 -- LOG(2): Running NvFuserScheduler_Matmul_Manual/nvfuser_splitk_TT/M:136/N:184/K:175704/warps:4/stages:3/splitk_factor:2/smem_epilogue:0/manual_time for 1
00:00:06 terminate called after throwing an instance of 'nvfuser::nvfError'
00:00:06   what():   INTERNAL ASSERT FAILED at "/opt/pytorch/nvfuser/csrc/ir/internal_base_nodes.h":147, please report a bug with repro script to NVFuser at https://github.com/NVIDIA/Fuser/issues. Unexpected input iter domain. Input should not be an rfactor iter domain: rS196{32}rf
00:00:06 Exception raised from split at /opt/pytorch/nvfuser/csrc/ir/internal_base_nodes.h:147 (most recent call first):

naoyam commented 3 weeks ago

Actually, never mind. Rfactor domains should exist only between the root and logical domains, so in normal cases the rfactor flag should be propagated to the outputs.