This is not a CUDAGraphs issue.
Even with CUDAGraphs disabled, it runs into the same error, but with a slightly different stack-trace:
$ python main.py --gpu 0 /home/soumith/dataset/imagenet
/home/soumith/code/vision/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension:
warn(f"Failed to load image Python extension: {e}")
/home/soumith/code/examples/imagenet/main.py:102: UserWarning: You have chosen a specific GPU. This will completely disable data parallelism.
warnings.warn('You have chosen a specific GPU. This will completely '
Use GPU: 0 for training
=> creating model 'resnet18'
make_fallback(aten.unfold): a decomposition exists, we should switch to it
make_fallback(aten.unfold_backward): a decomposition exists, we should switch to it
Traceback (most recent call last):
File "/home/soumith/code/examples/imagenet/main.py", line 515, in <module>
main()
File "/home/soumith/code/examples/imagenet/main.py", line 123, in main
main_worker(args.gpu, ngpus_per_node, args)
File "/home/soumith/code/examples/imagenet/main.py", line 282, in main_worker
train(train_loader, model, criterion, optimizer, epoch, device, args)
File "/home/soumith/code/examples/imagenet/main.py", line 329, in train
output = model(images)
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 137, in __call__
return self.forward(*args, **kwargs)
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 134, in forward
return optimized_forward(*args, **kwargs)
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/home/soumith/code/vision/torchvision/models/resnet.py", line 284, in forward
def forward(self, x: Tensor) -> Tensor:
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 856, in forward
return compiled_f(
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 847, in new_func
return compiled_fn(args)
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 230, in g
return f(*args)
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 475, in compiled_function
return CompiledFunction.apply(*remove_dupe_args(args))
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 442, in forward
fw_outs = call_func_with_args(
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 255, in call_func_with_args
out = normalize_as_list(f(args))
File "/home/soumith/code/pytorch/torch/_inductor/compile_fx.py", line 179, in run
return model(new_inputs_to_cuda)
File "/tmp/torchinductor_soumith/yz/cyzv2xzkmvwv33lxnmvd7lvgj4sq7l75r2jp76hekwqzumu2ovoo.py", line 1791, in call
assert_size_stride(buf56, (256, 128, 28, 28), (100352, 1, 3584, 128))
AssertionError: expected size 128==128, stride 784==1 at dim=1
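To make the stride complaint concrete, here is a minimal sketch (my own illustration, not part of the log): the expected strides in the assert describe a channels_last buffer of that shape, while stride 784 at dim=1 is what a plain contiguous buffer has, which is exactly the mismatch the error reports.

import torch

# Same shape as the failing buffer; compare the two memory layouts.
contig = torch.empty(256, 128, 28, 28)
chlast = contig.to(memory_format=torch.channels_last)

print(contig.stride())  # (100352, 784, 28, 1)   -> actual: stride 784 at dim=1
print(chlast.stride())  # (100352, 1, 3584, 128) -> expected: stride 1 at dim=1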
Okay, I bisected the issue to https://github.com/pytorch/pytorch/pull/87049. cc: @anijain2305
Reverting that commit fixes things and the example runs correctly.
Minifier doesn't do anything:
$ TORCHDYNAMO_REPRO_AFTER="aot" python main.py --gpu 0 /home/soumith/dataset/imagenet
/home/soumith/code/vision/torchvision/io/image.py:13: UserWarning: Failed to load image Python extension:
warn(f"Failed to load image Python extension: {e}")
/home/soumith/code/examples/imagenet/main.py:102: UserWarning: You have chosen a specific GPU. This will completely disable data parallelism.
warnings.warn('You have chosen a specific GPU. This will completely '
Use GPU: 0 for training
=> creating model 'resnet18'
make_fallback(aten.unfold): a decomposition exists, we should switch to it
make_fallback(aten.unfold_backward): a decomposition exists, we should switch to it
Writing minified repro to /tmp/minifier_soumith/minifier_launcher.py
Copying minified repro from /tmp/minifier_soumith/minifier_launcher.py to /home/soumith/code/pytorch/minifier_launcher.py for convenience
Traceback (most recent call last):
File "/home/soumith/code/examples/imagenet/main.py", line 515, in <module>
main()
File "/home/soumith/code/examples/imagenet/main.py", line 123, in main
main_worker(args.gpu, ngpus_per_node, args)
File "/home/soumith/code/examples/imagenet/main.py", line 282, in main_worker
train(train_loader, model, criterion, optimizer, epoch, device, args)
File "/home/soumith/code/examples/imagenet/main.py", line 329, in train
output = model(images)
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 137, in __call__
return self.forward(*args, **kwargs)
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 134, in forward
return optimized_forward(*args, **kwargs)
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/home/soumith/code/vision/torchvision/models/resnet.py", line 284, in forward
def forward(self, x: Tensor) -> Tensor:
File "/home/soumith/code/pytorch/torch/_dynamo/eval_frame.py", line 157, in _fn
return fn(*args, **kwargs)
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 856, in forward
return compiled_f(
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 847, in new_func
return compiled_fn(args)
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 230, in g
return f(*args)
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 475, in compiled_function
return CompiledFunction.apply(*remove_dupe_args(args))
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 442, in forward
fw_outs = call_func_with_args(
File "/home/soumith/code/pytorch/functorch/_src/aot_autograd.py", line 255, in call_func_with_args
out = normalize_as_list(f(args))
File "/home/soumith/code/pytorch/torch/_dynamo/debug_utils.py", line 444, in deferred_for_real_inputs
raise e
File "/home/soumith/code/pytorch/torch/_dynamo/debug_utils.py", line 430, in deferred_for_real_inputs
return compiled_fn(real_inputs)
File "/home/soumith/code/pytorch/torch/_inductor/compile_fx.py", line 179, in run
return model(new_inputs_to_cuda)
File "/home/soumith/code/pytorch/torch/_inductor/compile_fx.py", line 196, in run
compiled_fn = cudagraphify_impl(model, new_inputs, static_input_idxs)
File "/home/soumith/code/pytorch/torch/_inductor/compile_fx.py", line 254, in cudagraphify_impl
model(list(static_inputs))
File "/tmp/torchinductor_soumith/yz/cyzv2xzkmvwv33lxnmvd7lvgj4sq7l75r2jp76hekwqzumu2ovoo.py", line 1791, in call
assert_size_stride(buf56, (256, 128, 28, 28), (100352, 1, 3584, 128))
AssertionError: expected size 128==128, stride 784==1 at dim=1
Actually, the minifier works -- I didn't know that I should run the minifier_launcher.py.
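For the record, the step I had missed was just running the generated launcher from the path printed above:

$ python /tmp/minifier_soumith/minifier_launcher.py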
Here's the minified repro:
import torch
from torch import tensor, device
import torch.fx as fx
from torch._dynamo.testing import rand_strided
from math import inf
from torch.fx.experimental.proxy_tensor import make_fx

# torch version: 1.14.0a0+git240bba7
# torch cuda version: 11.6
# torch git version: 240bba7ac85b6163c7c75a168019cd0b6d1c6aa0

# CUDA Info:
# nvcc: NVIDIA (R) Cuda compiler driver
# Copyright (c) 2005-2022 NVIDIA Corporation
# Built on Tue_Mar__8_18:18:20_PST_2022
# Cuda compilation tools, release 11.6, V11.6.124
# Build cuda_11.6.r11.6/compiler.31057947_0

# GPU Hardware Info:
# NVIDIA GeForce RTX 3090 : 1

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, arg21_1, relu_4):
        convolution_7 = torch.ops.aten.convolution.default(relu_4, arg21_1, None, [2, 2], [0, 0], [1, 1], False, [0, 0], 1); relu_4 = arg21_1 = None
        return (convolution_7,)

args = [((128, 64, 1, 1), (64, 1, 1, 1), torch.float32, 'cuda'), ((256, 64, 56, 56), (200704, 3136, 56, 1), torch.float32, 'cuda')]
args = [rand_strided(sh, st, dt, dev) for (sh, st, dt, dev) in args]
mod = make_fx(Repro().to(device="cuda"))(*args)

from torch._inductor.compile_fx import compile_fx_inner
from torch._dynamo.debug_utils import same_two_models

compiled = compile_fx_inner(mod, args)
compiled(args)
I created another minified repro -- same error:
$ cat repro.py
from math import inf
import torch
from torch import tensor, device
import torch.fx as fx
import torch._dynamo
from torch._dynamo.testing import rand_strided
from torch._dynamo.debug_utils import run_fwd_maybe_bwd
from torch._dynamo.debug_utils import same_two_models

args = [((256, 64, 56, 56), (200704, 3136, 56, 1), torch.float32, 'cuda', True)]
args = [rand_strided(sh, st, dt, dev).requires_grad_(rg) for (sh, st, dt, dev, rg) in args]

from torch.nn import *

class Repro(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.self_self_layer2_0_downsample_0 = Conv2d(64, 128, kernel_size=(1, 1), stride=(2, 2), bias=False)

    def forward(self, self_self_layer1_1_relu_1):
        self_self_layer2_0_downsample_0 = self.self_self_layer2_0_downsample_0(self_self_layer1_1_relu_1); self_self_layer1_1_relu_1 = None
        return (self_self_layer2_0_downsample_0,)

mod = Repro().cuda()
opt_mod = torch._dynamo.optimize("inductor")(mod)

with torch.cuda.amp.autocast(enabled=False):
    ref = run_fwd_maybe_bwd(mod, args)
    res = run_fwd_maybe_bwd(opt_mod, args)
This is the generated inductor code:
from ctypes import c_void_p, c_long
import torch
import random
from torch import empty_strided, as_strided, device
from torch._inductor.codecache import AsyncCompile

aten = torch.ops.aten
assert_size_stride = torch._C._dynamo.guards.assert_size_stride
async_compile = AsyncCompile()

import triton
import triton.language as tl
from torch._inductor.triton_ops.autotune import grid
from torch._C import _cuda_getCurrentRawStream as get_cuda_stream

kernel0 = async_compile.triton('''
import triton
import triton.language as tl
from torch._inductor.ir import ReductionHint
from torch._inductor.triton_ops.autotune import pointwise
from torch._inductor.utils import instance_descriptor

@pointwise(size_hints=[16384, 4096], filename=__file__, meta={'signature': {0: '*fp32', 1: '*fp32', 2: 'i32', 3: 'i32', 4: 'i32'}, 'device': 0, 'configs': [instance_descriptor(divisible_by_16=(0, 1), equal_to_1=())], 'constants': {}})
@triton.jit
def kernel(in_ptr0, out_ptr0, ks0, xnumel, ynumel, XBLOCK : tl.constexpr, YBLOCK : tl.constexpr):
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.reshape(tl.arange(0, XBLOCK), [XBLOCK, 1])
    xmask = xindex < xnumel
    yoffset = tl.program_id(1) * YBLOCK
    yindex = yoffset + tl.reshape(tl.arange(0, YBLOCK), [1, YBLOCK])
    ymask = yindex < ynumel
    x3 = xindex
    y2 = yindex
    x0 = xindex % 64
    x1 = (xindex // 64)
    tmp0 = tl.load(in_ptr0 + (y2 + (x3*(ks0*ks0))), xmask & ymask)
    tl.store(out_ptr0 + (x0 + (64*y2) + (64*x1*(ks0*ks0)) + tl.zeros([XBLOCK, YBLOCK], tl.int32)), tmp0, xmask & ymask)
''')

async_compile.wait(globals())
del async_compile

def call(args):
    primals_1, primals_2 = args
    args.clear()
    primals_1_size = primals_1.size()
    s0 = primals_1_size[0]
    s1 = primals_1_size[1]
    primals_2_size = primals_2.size()
    s2 = primals_2_size[0]
    s3 = primals_2_size[2]
    buf1 = empty_strided((s2, 64, s3, s3), (64*(s3*s3), 1, 64*s3, 64), device='cuda', dtype=torch.float32)
    kernel0_xnumel = 64*s2
    kernel0_ynumel = s3*s3
    stream0 = get_cuda_stream(0)
    kernel0.run(primals_2, buf1, s3, kernel0_xnumel, kernel0_ynumel, grid=grid(kernel0_xnumel, kernel0_ynumel), stream=stream0)
    buf0 = aten.convolution(buf1, primals_1, None, (2, 2), (0, 0), (1, 1), False, (0, 0), 1)
    assert_size_stride(buf0, (s2, 128, 28, 28), (100352, 1, 3584, 128))
    return (buf0, primals_1, primals_2, )

if __name__ == "__main__":
    from torch._dynamo.testing import rand_strided
    from torch._inductor.utils import print_performance
    primals_1 = rand_strided((128, 64, 1, 1), (64, 1, 1, 1), device='cuda', dtype=torch.float32)
    primals_2 = rand_strided((256, 64, 56, 56), (200704, 3136, 56, 1), device='cuda', dtype=torch.float32)
    print_performance(lambda: call([primals_1, primals_2]))
It fails at the assert_size_stride call.
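As a minimal sketch of that failure mode (my own illustration, using assert_size_stride exactly as the generated file imports it): checking a plain-contiguous buffer against channels_last strides raises the same AssertionError.

import torch
from torch._C._dynamo.guards import assert_size_stride

# Plain contiguous tensor standing in for the convolution output (buf0).
buf = torch.empty(256, 128, 28, 28)

# Expected channels_last strides from the generated wrapper ->
# AssertionError: expected size 128==128, stride 784==1 at dim=1
assert_size_stride(buf, (256, 128, 28, 28), (100352, 1, 3584, 128))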
I see the bug. Here:

buf1 is size, stride of torch.Size([256, 64, 56, 56]), (200704, 1, 3584, 64)
primals_1 is size, stride of torch.Size([128, 64, 1, 1]), (64, 1, 1, 1)
buf0, it turns out, is not in channels_last and is a regular contiguous output: torch.Size([256, 128, 28, 28]), (100352, 784, 28, 1)

So the convolution in my install is not respecting channels_last.
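A minimal way to check this directly (my own sketch, not from the thread, assuming the same CUDA build without cuDNN): feed a channels_last input and weight through the same 1x1 stride-2 convolution and look at the output strides.

import torch
import torch.nn.functional as F

# channels_last input and weight matching the minified repro's shapes
x = torch.randn(256, 64, 56, 56, device='cuda').to(memory_format=torch.channels_last)
w = torch.randn(128, 64, 1, 1, device='cuda').to(memory_format=torch.channels_last)

out = F.conv2d(x, w, stride=2)
print(out.stride())                                          # on this cuDNN-less build: (100352, 784, 28, 1)
print(out.is_contiguous(memory_format=torch.channels_last))  # on this cuDNN-less build: False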
Okay, so I think I figured it out: my install doesn't have any cuDNN.
print(torch.__config__.show())
PyTorch built with:
- GCC 9.4
- C++ Version: 201402
- Intel(R) oneAPI Math Kernel Library Version 2022.1-Product Build 20220311 for Intel(R) 64 architecture applications
- Intel(R) MKL-DNN v2.6.0 (Git Hash 52b5f107dd9cf10910aaa19cb47f3abf9b349815)
- OpenMP 201511 (a.k.a. OpenMP 4.5)
- LAPACK is enabled (usually provided by MKL)
- NNPACK is enabled
- CPU capability usage: AVX2
- CUDA Runtime 11.6
- NVCC architecture flags: -gencode;arch=compute_86,code=sm_86
- Magma 2.6.1
- Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.6, CXX_COMPILER=/usr/bin/c++, CXX_FLAGS= -Wno-deprecated -fvisibility-inlines-hidden -DUSE_PTHREADPOOL -fopenmp -DNDEBUG -DUSE_KINETO -DUSE_FBGEMM -DUSE_QNNPACK -DUSE_PYTORCH_QNNPACK -DUSE_XNNPACK -DSYMBOLICATE_MOBILE_DEBUG_HANDLE -DEDGE_PROFILER_USE_KINETO -O2 -fPIC -Wno-narrowing -Wall -Wextra -Werror=return-type -Werror=non-virtual-dtor -Wno-missing-field-initializers -Wno-type-limits -Wno-array-bounds -Wno-unknown-pragmas -Wunused-local-typedefs -Wno-unused-parameter -Wno-unused-function -Wno-unused-result -Wno-strict-overflow -Wno-strict-aliasing -Wno-error=deprecated-declarations -Wno-stringop-overflow -Wno-psabi -Wno-error=pedantic -Wno-error=redundant-decls -Wno-error=old-style-cast -fdiagnostics-color=always -faligned-new -Wno-unused-but-set-variable -Wno-maybe-uninitialized -fno-math-errno -fno-trapping-math -Werror=format -Werror=cast-function-type -Wno-stringop-overflow, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.14.0, USE_CUDA=ON, USE_CUDNN=OFF, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=ON, USE_NNPACK=ON, USE_OPENMP=ON, USE_ROCM=OFF,
In this case, channels_last is not respected. Your PR doesn't check for this case, I think.
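As an aside (my own addition, not from the thread), the missing cuDNN can also be confirmed at runtime:

import torch

print(torch.backends.cudnn.is_available())  # False on this build (USE_CUDNN=OFF)
print(torch.backends.cudnn.version())       # None when PyTorch was built without cuDNN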
See https://github.com/pytorch/torchdynamo/issues/1687 for original context.
Now it's failing on latest pytorch master. First I ran into a parallel compile issue, for which I put up a patch: https://github.com/pytorch/pytorch/pull/87174. After that is applied, it still fails, with a different CUDAGraphs error.