Closed: rdspring1 closed this issue 1 year ago.
Python reproducer:
import torch
@torch.jit.script
def issue_1843(x: torch.Tensor, s: torch.Tensor):
    return x.to(torch.float) + s.to(torch.float)
n = 2
h = w = 28
out_c = 61
in_c = 50
# NHWC -> NCHW view, i.e. channels_last memory format
channels_last_permute_order = (0, 3, 1, 2)
# x => size = ([2, 61, 28, 28]) and strides = (47824, 1, 1708, 61)
x = torch.randn(n, h, w, out_c, dtype=torch.half, device='cuda').permute(channels_last_permute_order)
# s => size = [2, 50, 28, 28] and strides = (39200, 1, 1400, 50)
s = torch.randn(n, h, w, in_c, dtype=torch.half, device='cuda').permute(channels_last_permute_order)
# x_hat keeps x's strides: 50 channels inside a 61-element pitch
x_hat = x[:, :in_c, :, :]
for _ in range(5):
    output = issue_1843(x_hat, s)
TorchScript graph:
with prim::CudaFusionGroup_0 = graph(%4 : Half(2, 50, 28, 28, strides=[39200, 1, 1400, 50], requires_grad=0, device=cuda:0),
%8 : Half(2, 50, 28, 28, strides=[47824, 1, 1708, 61], requires_grad=0, device=cuda:0)):
%2 : int = prim::Constant[value=1]()
%10 : bool = prim::Constant[value=0]()
%9 : bool = prim::Constant[value=1]()
%11 : Float(2, 50, 28, 28, strides=[39200, 1, 1400, 50], requires_grad=0, device=cuda:0) = aten::_autocast_to_full_precision(%8, %9, %10)
%7 : Float(2, 50, 28, 28, strides=[39200, 1, 1400, 50], requires_grad=0, device=cuda:0) = aten::_autocast_to_full_precision(%4, %9, %10)
%3 : Float(2, 50, 28, 28, strides=[39200, 1, 1400, 50], requires_grad=0, device=cuda:0) = aten::add(%11, %7, %2) # /home/rds4/timm/timm/models/rexnet.py:97:27
Fusion:
g{(pointwise)
inputs:
T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ] __half
T2_g[ iS8{i13}, iS9{i14}, iS10{i15}, iS11{i16} ] __half
outputs:
T6_g[ iS24{i13}, iS25{i14}, iS26{i15}, iS27{i16} ] float
T3_l[ iS12{i13}, iS13{i14}, iS14{i15}, iS15{i16} ]
= __half2float(T2_g[ iS8{i13}, iS9{i14}, iS10{i15}, iS11{i16} ]);
T4_l[ iS16{i13}, iS17{i14}, iS18{i15}, iS19{i16} ]
= T3_l[ iS12{i13}, iS13{i14}, iS14{i15}, iS15{i16} ];
T1_l[ iS4{i0}, iS5{i2}, iS6{i3}, iS7{i4} ]
= __half2float(T0_g[ iS0{i0}, iS1{i2}, iS2{i3}, iS3{i4} ]);
T5_l[ iS20{i0}, iS21{i2}, iS22{i3}, iS23{i4} ]
= T1_l[ iS4{i0}, iS5{i2}, iS6{i3}, iS7{i4} ];
T6_g[ iS24{i13}, iS25{i14}, iS26{i15}, iS27{i16} ]
= T4_l[ iS16{i13}, iS17{i14}, iS18{i15}, iS19{i16} ]
+ T5_l[ iS20{i0}, iS21{i2}, iS22{i3}, iS23{i4} ];
}
It seems that the validation failure is a true failure.
Running the repro with PYTORCH_NVFUSER_DUMP=fusion_args gives:
Arguments for fusion1:
Inputs:
Half [2, 28, 28, 50] (strides = [39200, 1400, 50, 1])
Half [2, 28, 28, 50] (strides = [47824, 1708, 61, 1])
The executor failure happens with the second argument, which can't be vectorized with a vector width of 2. The executor validation checks the strides of all dimensions, as it must, whereas the pointwise scheduler only looks at the extent of the innermost domain. That extent is 50, so the scheduler decides it's safe to vectorize by 2. It isn't actually safe, because the next domain's stride is 61: every other row starts at an odd element offset, so a 2-wide access there is misaligned.
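To make the mismatch concrete, here is a rough Python paraphrase of the two checks. `scheduler_extent_only_ok` models the scheduler behaviour described above (only the innermost extent is checked), and `executor_stride_ok` paraphrases the condition printed by the executor assertion quoted later in this thread (`stride == cur_contig_stride || size == 1 || (still_rightmost && stride == 1) || (!still_rightmost && stride % word_size == 0)`). Both function names are hypothetical; the right-to-left walk and the way `still_rightmost` and `cur_contig_stride` are updated are my assumptions about executor_utils.cpp, and the expanded-broadcast case is left out.

```python
from math import prod


def contiguous_strides(sizes):
    """Row-major strides (in elements) of a densely packed tensor."""
    return [prod(sizes[i + 1:]) for i in range(len(sizes))]


def scheduler_extent_only_ok(sizes, word_size):
    # Hypothetical model of the buggy scheduler logic: only the innermost
    # extent is checked; strides are ignored entirely.
    return sizes[-1] % word_size == 0


def executor_stride_ok(sizes, strides, word_size):
    # Hypothetical paraphrase of the executor's vectorization assertion,
    # walking from the innermost (rightmost) dimension outwards.
    still_rightmost = True
    cur_contig_stride = 1
    for size, stride in zip(reversed(sizes), reversed(strides)):
        ok = (
            stride == cur_contig_stride
            or size == 1
            or (still_rightmost and stride == 1)
            or (not still_rightmost and stride % word_size == 0)
        )
        if not ok:
            return False
        # Stays True only while every dim seen so far has size 1.
        still_rightmost = still_rightmost and size == 1
        if size != 1:
            cur_contig_stride = stride * size
    return True


# Arguments from the fusion_args dump: both inputs are [2, 28, 28, 50], but the
# second is a 50-channel slice of a 61-channel buffer.
sizes = [2, 28, 28, 50]
s_strides = contiguous_strides(sizes)                # [39200, 1400, 50, 1]
x_hat_strides = contiguous_strides([2, 28, 28, 61])  # [47824, 1708, 61, 1]

print(scheduler_extent_only_ok(sizes, 2))        # True: 50 % 2 == 0
print(executor_stride_ok(sizes, s_strides, 2))   # True
print(executor_stride_ok(sizes, x_hat_strides, 2))  # False: 61 % 2 != 0
```

The extent-only rule accepts width 2 for both inputs, while the stride walk rejects `x_hat` at the domain whose stride is 61, which is exactly the disagreement between scheduler and executor.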
To fix this, the scheduler also needs to look at domain strides. If for some reason we don't want to use the actual stride values to determine the vectorization width, we would have to make a conservative decision and disable vectorization whenever the next (left) domain is not contiguous.
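A minimal sketch of both options, assuming sizes and strides are given in elements and ignoring base-pointer alignment (which a real implementation would also have to check). `pick_vector_width` collapses the innermost run of contiguous dimensions and then picks the largest power-of-two width that divides that run's extent and every remaining outer stride. `pick_vector_width_conservative` is one reading of the fallback proposed above: with only per-dimension contiguity flags, any non-trivial domain at or left of the first non-contiguous one has an unknown stride, so vectorization is disabled. All names, the `max_width=8` cap (16 bytes of half), and the flag convention (`contiguity[i]` is true when dim i is packed against dim i+1, and the innermost flag means unit stride) are assumptions for illustration, not nvFuser's actual code.

```python
def collapse_inner_contiguous(sizes, strides):
    """Merge the innermost run of mutually contiguous dims into one dim.

    Size-1 dims are dropped first because their stride never matters.
    Returns (sizes, strides) with the merged run as the last entry.
    """
    dims = [(n, st) for n, st in zip(sizes, strides) if n != 1]
    if not dims:
        return [1], [1]
    inner_size, inner_stride = dims[-1]
    i = len(dims) - 2
    # Absorb outer dims while they are packed right against the current run.
    while i >= 0 and dims[i][1] == inner_stride * inner_size:
        inner_size *= dims[i][0]
        i -= 1
    outer = dims[: i + 1]
    return [n for n, _ in outer] + [inner_size], [st for _, st in outer] + [inner_stride]


def pick_vector_width(sizes, strides, max_width=8):
    """Largest power-of-two width <= max_width that is safe given real strides."""
    sizes, strides = collapse_inner_contiguous(sizes, strides)
    if strides[-1] != 1:
        return 1  # innermost run is not unit-stride: no vectorization
    width = max_width
    while width > 1:
        if sizes[-1] % width == 0 and all(st % width == 0 for st in strides[:-1]):
            return width
        width //= 2
    return 1


def pick_vector_width_conservative(sizes, contiguity, max_width=8):
    """Contiguity-flags-only fallback that never needs actual stride values."""
    run = 1
    i = len(sizes) - 1
    while i >= 0 and contiguity[i]:
        run *= sizes[i]
        i -= 1
    # Dim i (if any) is not contiguous, so its stride is unknown, and every
    # dim to its left inherits that uncertainty.
    if any(n != 1 for n in sizes[: i + 1]):
        return 1
    width = max_width
    while width > 1 and run % width != 0:
        width //= 2
    return width


print(pick_vector_width([2, 28, 28, 50], [39200, 1400, 50, 1]))  # 8 (s)
print(pick_vector_width([2, 28, 28, 50], [47824, 1708, 61, 1]))  # 1 (x_hat)
print(pick_vector_width([4, 16], [32, 1]))                       # 8
print(pick_vector_width_conservative([4, 16], [False, True]))    # 1
```

The last two calls show what using real strides buys: rows padded to 32 elements keep every row start 8-element aligned, so width 8 is safe, while the contiguity-only rule has to give up because it cannot see the stride.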
Can we use the actual stride values, rather than just whether each domain is contiguous, when determining the vectorization strategy? https://github.com/csarofeen/pytorch/blob/devel/torch/csrc/jit/codegen/cuda/scheduler/registry.cpp#L576-L667
Pinging @shmsong @csarofeen
We need to fix the vectorization check in registry.cpp.
@zasdfgbnm could you take a look at this issue?
As of 10/4/2022, these are the TIMM benchmarks which are failing:
Hugging Face is also seeing a lot of these from our devel branch :cry:
This is long overdue. Let me see if I can get a fix today.
A smaller C++ repro:
TEST_F(NVFuserTest, FusionVectorizeStrideContiguity_CUDA) {
  std::unique_ptr<Fusion> fusion_ptr = std::make_unique<Fusion>();
  auto fusion = fusion_ptr.get();
  FusionGuard fg(fusion);
  // 2-D input: inner dim contiguous, outer dim not contiguous
  TensorView* tv0 =
      TensorViewBuilder().ndims(2).contiguity({false, true}).build();
  fusion->addInput(tv0);
  auto tv1 = set(tv0);
  fusion->addOutput(tv1);
  auto options = at::TensorOptions().dtype(kFloat).device(at::kCUDA, 0);
  // narrow() keeps the row stride, so t0 is [1000000, 16] with strides [17, 1]:
  // the inner extent 16 is vectorizable, but the odd row stride 17 is not
  at::Tensor t0 = at::randn({1000000, 17}, options).narrow(1, 0, 16);
  FusionExecutorCache fec(std::move(fusion_ptr));
  auto cg_outputs = fec.runFusionWithInputs({t0});
  fec.fusion()->print();
  testValidate(fusion, cg_outputs, {t0}, {t0}, __LINE__, __FILE__);
}
I'm still seeing this error pop up, and I'm still working on getting a clean repro for you to act on. Meanwhile, I'd like to have the issue reopened so we know what's coming.
Somehow a standalone kernel with the right inputs doesn't trigger the failure; the original error comes from a segmented graph. I tried to mimic it with a Python example (using nvprim, which should take the same code path), but that didn't produce the segmented kernel.
Here's the failed graph:
Segmented_Fusion Dump: -- fusion segments:
Segmented_Fusion{
groups:
g{0, 2, 3, 4, 5}
g{1, 6, 7}
edges:
e{ g{0, 2, 3, 4, 5}
-> g{1, 6, 7}
(T7_g[ iS14{i0}, bS15{32128} ]) }
group details:
g{(reduction)
inputs:
T0_g[ iS0{i0}, iS1{i1} ] float
i0 int64_t
outputs:
T7_g[ iS14{i0}, bS15{32128} ] float
T2_l[ iS4{i0}, iS5{i1} ]
= T0_g[ iS0{i0}, iS1{i1} ];
T4_g[ iS8{i0}, rS9{i1} ]
= reduction( T2_l[ iS4{i0}, iS5{i1} ], op = add, initial value = double(0), allreduce = false )
T5_g[ iS10{i0}, bS11{1} ]
= broadcast( T4_g[ iS8{i0}, rS9{i1} ] )
T6_l[ iS12{i0}, bS13{1} ]
= T5_g[ iS10{i0}, bS11{1} ];
T7_g[ iS14{i0}, bS15{32128} ] = expand( T6_l[ iS12{i0}, bS13{1} ], {i0, 32128} )
}
g{(transpose)
inputs:
T0_g[ iS0{i0}, iS1{i1} ] float
T1_g[ iS2{i3}, iS3{i4} ] float
T7_g[ iS14{i0}, bS15{32128} ] float
outputs:
T9_g[ iS18{i0}, iS19{i1} ] float
T3_g[ iS6{i3}, iS7{i4} ]
= expf(T1_g[ iS2{i3}, iS3{i4} ]);
T8_g[ iS16{i3}, iS17{i4} ]
= T3_g[ iS6{i3}, iS7{i4} ]
* T7_g[ iS14{i0}, bS15{32128} ];
T9_g[ iS18{i0}, iS19{i1} ]
= T0_g[ iS0{i0}, iS1{i1} ]
- T8_g[ iS16{i3}, iS17{i4} ];
}
} //Segmented_Fusion
The inputs fed to the original kernels are:
tensor dtype: float sizes: (4096, 32128, ) stride: (32128, 1, ) pointer: 0x7f2ccc000000
tensor dtype: float sizes: (4096, 32128, ) stride: (32128, 1, ) pointer: 0x7f2c2c000000
My Python example so far is:
import torch
t1 = torch.empty(4096, 32128, device="cuda")
t0 = torch.empty(4096, 32128, device="cuda")
def func(t1, t0):
    t7 = t1.sum((1,), True)
    t3 = t1.exp()
    t8 = t3 * t7
    t = t0 - t8
    return t
from torch.fx.experimental.proxy_tensor import make_fx
from torch._decomp import get_decompositions
from torch._prims.executor import execute
from torch._prims.context import TorchRefsNvfuserCapabilityMode, TorchRefsMode
# Trace func into prims that nvFuser can take
with TorchRefsNvfuserCapabilityMode():
    gm = make_fx(func)(t1, t0)
print(gm.graph)
out = execute(gm, t1, t0, executor="nvfuser")
print(out[0].shape, out[0].stride())
It produces an unsegmented fusion:
Segmented_Fusion Dump: -- fusion segments:
Segmented_Fusion{
groups:
g{0, 1, 2, 3, 4, 5, 6, 7}
edges:
group details:
g{(persistent)
inputs:
T0_g[ iS0{i0}, iS1{i1} ] float
T1_g[ iS2{i3}, iS3{i4} ] float
outputs:
T9_g[ iS18{i3}, iS19{i4} ] float
T5_l[ iS10{i0}, iS11{i1} ]
= expf(T0_g[ iS0{i0}, iS1{i1} ]);
T2_l[ iS4{i0}, iS5{i1} ]
= T0_g[ iS0{i0}, iS1{i1} ];
T3_l[ iS6{i0}, rS7{i1} ]
= reduction( T2_l[ iS4{i0}, iS5{i1} ], op = add, initial value = double(0), allreduce = false )
T4_l[ iS8{i0}, bS9{1} ]
= broadcast( T3_l[ iS6{i0}, rS7{i1} ] )
T6_l[ iS12{i0}, bS13{1} ]
= T4_l[ iS8{i0}, bS9{1} ];
T7_l[ iS14{i0}, bS15{32128} ] = expand( T6_l[ iS12{i0}, bS13{1} ], {i0, 32128} )
T8_l[ iS16{i0}, iS17{i1} ]
= T5_l[ iS10{i0}, iS11{i1} ]
* T7_l[ iS14{i0}, bS15{32128} ];
T9_g[ iS18{i3}, iS19{i4} ]
= T1_g[ iS2{i3}, iS3{i4} ]
- T8_l[ iS16{i0}, iS17{i1} ];
}
} //Segmented_Fusion
Arguments for fusion1:
Inputs:
tensor dtype: float sizes: (4096, 32128, ) stride: (32128, 1, ) pointer: 0x7fda5c000000
tensor dtype: float sizes: (4096, 32128, ) stride: (32128, 1, ) pointer: 0x7fda3c000000
Note that the nvprim path also produces wrong output. Since the graph looks fine, I'm not sure where that issue is coming from.
My next step is to look at the PrimTorch-traced FX graph and figure out whether there's anything I missed.
@jjsjann123 Is this a new error after merging https://github.com/csarofeen/pytorch/pull/2035, or is it an error that already existed before that?
This errors out both before and after merging #2035, so I guess the fix just doesn't cover this case. I think I'm really close to getting a C++ repro now. :crossed_fingers:
@zasdfgbnm Here you go: https://github.com/csarofeen/pytorch/blob/9d35779e6d226220f721b409ab7cc824830349aa/torch/csrc/jit/codegen/cuda/test/test_gpu3.cpp#L6443-L6474
There are some funny bits in how segmentation works. I'm getting a little nervous about why I can't repro this without a segmented fusion. :cold_sweat:
Thanks for the repro! Will take a look ASAP.
I've pulled the latest devel branch but am still getting this:
Arguments for fusion3:
Inputs:
tensor dtype: float sizes: (294912, 1, ) stride: (1, 1, ) pointer: 000000070CF40000
tensor dtype: float sizes: (294912, 1, ) stride: (1, 1, ) pointer: 000000070CE20000
tensor dtype: float sizes: (294912, 1, ) stride: (4, 4, ) pointer: 000000070C9A0008
tensor dtype: float sizes: (294912, 1, ) stride: (4, 4, ) pointer: 000000070C9A0000
Outputs:
Launch Parameters: BlockDim.x = 128, BlockDim.y = -1, BlockDim.z = -1, GridDim.x = -1, GridDim.y = -1, GridDim.z = -1, Smem Size = 0
The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: stride == cur_contig_stride || size == 1 || is_expanded_broadcasting || (still_rightmost && stride == 1) || (!still_rightmost && stride % word_size == 0) INTERNAL ASSERT FAILED at "C:\work\pytorch\torch\csrc\jit\codegen\cuda\executor_utils.cpp":625, please report a bug to PyTorch. Vectorization of T2_g[ iS125{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * 1 ), 2) ), 1) ), 128) )}, iS124{1}, iS122{2}, iS126{128} ] with word size 2 not possible due to invalid stride. Domain: iS125{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * 1 ), 2) ), 1) ), 128) )}, stride: 4
torchscript.txt
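The domain expression in that message can be decoded with a little arithmetic. My reading (an assumption, not stated in this thread) is that the two input dims were merged (`T0.size[0] * 1`), then split by the vector word size 2, then by an unroll factor of 1, then by 128 threads, which would match `BlockDim.x = 128` in the launch parameters. The sketch below just evaluates the nested `ceilDiv` for the 294912-element inputs; `ceil_div` is a local helper, not an nvFuser API.

```python
def ceil_div(a, b):
    """Integer ceiling division, matching the ceilDiv in the printed domain."""
    return -(-a // b)


n = 294912                        # T0.size[0] from the fusion3 arguments
merged = n * 1                    # T0.size[0] * 1 (second dim has size 1)
vec = ceil_div(merged, 2)         # split by word size 2 (iS122{2})
unrolled = ceil_div(vec, 1)       # split by 1 (iS124{1})
blocks = ceil_div(unrolled, 128)  # split by 128 (iS126{128})
print(blocks)  # 1152, the extent of the outer domain iS125
```

So the outer domain iS125 has extent 1152, and the assertion rejects it because the tensor being vectorized has an element stride of 4 there, the same stride as the last two inputs listed above, instead of unit stride.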
Same RuntimeError in detectron2/tests/test_export_torchscript.py:
RuntimeError: The following operation failed in the TorchScript interpreter.
Traceback of TorchScript (most recent call last):
RuntimeError: stride == cur_contig_stride || size == 1 || (still_rightmost && stride == 1) || (!still_rightmost && stride % word_size == 0) INTERNAL ASSERT FAILED at "../torch/csrc/jit/codegen/cuda/executor_utils.cpp":623, please report a bug to PyTorch. Vectorization of T3_g[ iS97{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * 1 ), 2) ), 1) ), 128) )}, iS96{1}, iS94{2}, iS98{128} ] with word size 2 not possible due to invalid stride. Domain: iS97{( ceilDiv(( ceilDiv(( ceilDiv(( T0.size[0] * 1 ), 2) ), 1) ), 128) )}, stride: 4
🐛 Describe the bug
Fusion fails and hits the fallback path when channels_last, AMP, and batch_size > 1 are combined.
Steps to reproduce:
PYTORCH_NVFUSER_DISABLE_FALLBACK=1 python benchmark.py --bench train --model rexnet_100 --img-size 224 -b 2 --amp --torchscript --channels-last --no-retry
Failing networks:
General Error:
Versions