Closed DarkMnDragon closed 1 year ago
Thanks for the report, here's a possible repro of that error:
import torch

a = torch.tensor([1.], requires_grad=True)
c = a.clone()
v = c[:]  # v is a view of c, so the in-place op below goes through CopySlices
b = torch.tensor(1., requires_grad=True)

class InplaceFunc(torch.autograd.Function):
    @staticmethod
    def forward(ctx, x, other):
        ctx.mark_dirty(x)
        return x.mul_(2)

    @staticmethod
    def backward(ctx, grad):
        # Returns None (an undefined gradient) for `other`, even though b requires grad
        return grad, None

out = InplaceFunc.apply(v, b)
torch.autograd.grad(out, inputs=(a, b))
What's happening here is that
One thing we could do is just remove the error, since we should be treating undefined tensors as zeros anyway. If we don't wish to allow this, we should at least be making the internal assert a normal torch check.
Properly treating undefined gradients as zeros sounds fair to me. I guess it just happens that our in-place ops (the ones that get put inside CopySlices) are well behaved and never do this. A custom Function can, though, so we should support it.
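Until undefined gradients are handled as zeros here, one possible user-side workaround is to return an explicit zero tensor instead of None from the custom backward. The following is only a sketch under that assumption, not the fix discussed above: the class name `InplaceFuncZeros` is made up for illustration, and the backward also scales the incoming gradient by 2 to match `mul_(2)`, which the repro did not do.

```python
import torch

class InplaceFuncZeros(torch.autograd.Function):
    """Hypothetical variant of the repro that never returns an undefined gradient."""

    @staticmethod
    def forward(ctx, x, other):
        ctx.mark_dirty(x)             # x is modified in place
        ctx.save_for_backward(other)  # kept only to build a matching zero gradient
        return x.mul_(2)

    @staticmethod
    def backward(ctx, grad):
        (other,) = ctx.saved_tensors
        # Explicit zeros for `other` instead of None, so every returned
        # gradient is a defined tensor.
        return grad * 2, torch.zeros_like(other)

a = torch.tensor([1.], requires_grad=True)
c = a.clone()
v = c[:]  # view of c, same setup as the repro
b = torch.tensor(1., requires_grad=True)

out = InplaceFuncZeros.apply(v, b)
grad_a, grad_b = torch.autograd.grad(out, inputs=(a, b))
print(grad_a, grad_b)
```

The cost of this approach is an extra allocation for the zero tensor, which is exactly what treating undefined tensors as zeros inside autograd would let custom Functions avoid.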
🐛 Describe the bug
The Code is:
Versions
PyTorch version: 2.0.1
Is debug build: False
CUDA used to build PyTorch: 11.8
ROCM used to build PyTorch: N/A

OS: Microsoft Windows 11 Home China
GCC version: Could not collect
Clang version: Could not collect
CMake version: Could not collect
Libc version: N/A

Python version: 3.9.16 (main, May 17 2023, 17:49:16) [MSC v.1916 64 bit (AMD64)] (64-bit runtime)
Python platform: Windows-10-10.0.22621-SP0
Is CUDA available: True
CUDA runtime version: 12.1.105
CUDA_MODULE_LOADING set to: LAZY
GPU models and configuration: GPU 0: NVIDIA GeForce RTX 3080 Ti Laptop GPU
Nvidia driver version: 536.99
cuDNN version: Could not collect
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Architecture=9
CurrentClockSpeed=2000
DeviceID=CPU0
Family=198
L2CacheSize=14336
L2CacheSpeed=
Manufacturer=GenuineIntel
MaxClockSpeed=2000
Name=12th Gen Intel(R) Core(TM) i7-12800HX
ProcessorType=3
Revision=

Versions of relevant libraries:
[pip3] gpytorch==1.11
[pip3] numpy==1.25.0
[pip3] torch==2.0.1
[pip3] torch-fidelity==0.3.0
[pip3] torchaudio==2.0.2
[pip3] torchvision==0.15.2
[conda] blas 1.0 mkl
[conda] gpytorch 1.11 pypi_0 pypi
[conda] mkl 2023.1.0 h8bd8f75_46356
[conda] mkl-service 2.4.0 py39h2bbff1b_1
[conda] mkl_fft 1.3.6 py39hf11a4ad_1
[conda] mkl_random 1.2.2 py39hf11a4ad_1
[conda] numpy 1.25.0 py39h055cbcc_0
[conda] numpy-base 1.25.0 py39h65a83cf_0
[conda] pytorch 2.0.1 py3.9_cuda11.8_cudnn8_0 pytorch
[conda] pytorch-cuda 11.8 h24eeafa_5 pytorch
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] torch-fidelity 0.3.0 pypi_0 pypi
[conda] torchaudio 2.0.2 pypi_0 pypi
[conda] torchvision 0.15.2 pypi_0 pypi
cc @ezyang @gchanan @zou3519 @albanD @gqchen @pearu @nikitaved @soulitzer @Lezcano @Varal7