pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

Silent error in torch.Tensor.sum on mps #132732

Open nisheethlahoti opened 2 months ago

nisheethlahoti commented 2 months ago

🐛 Describe the bug

If you run the code below:

```python
import torch

x = torch.ones(1, 2, 1, 1, 1, device="mps")
print(x.sum(-3))
```

The expected output is:

```
tensor([[[[1.]],

         [[1.]]]], device='mps:0')
```

However, the actual output is:

```
tensor([[[[2.]],

         [[0.]]]], device='mps:0')
```

On CPU, we get the expected output. Relatedly, x.sum(-5) works fine on CPU but raises an IndexError on MPS.
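For reference, here is the shape bookkeeping the CPU result follows, as a pure-Python sketch. `wrap_dim` and `reduced_shape` are hypothetical helpers, not PyTorch APIs; they assume the usual convention that a valid dim lies in [-ndim, ndim - 1] and that a negative dim counts back from the end. For the 5-D tensor above, -3 maps to dim 2, which has size 1, so the sum should only drop that axis and leave both values at 1; and -5 maps to dim 0, so x.sum(-5) should be valid.

```python
def wrap_dim(dim: int, ndim: int) -> int:
    # Hypothetical helper: valid dims lie in [-ndim, ndim - 1];
    # a negative dim counts back from the end.
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-D tensor")
    return dim + ndim if dim < 0 else dim


def reduced_shape(shape, dim):
    # Shape of x.sum(dim) with keepdim=False: the reduced axis is dropped.
    d = wrap_dim(dim, len(shape))
    return shape[:d] + shape[d + 1:]


shape = (1, 2, 1, 1, 1)  # shape of x in the repro above
d = wrap_dim(-3, len(shape))
print(d)                         # 2
print(shape[d])                  # 1: summing over a size-1 axis just drops it
print(reduced_shape(shape, -3))  # (1, 2, 1, 1), so both values should stay 1.
print(wrap_dim(-5, len(shape)))  # 0, so x.sum(-5) should be valid
```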

Versions

PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.6 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: Could not collect
Libc version: N/A

Python version: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.6-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU:
Apple M1 Max

Versions of relevant libraries:
[pip3] numpy==2.0.1
[pip3] pytorch-lightning==2.3.0
[pip3] torch==2.4.0
[pip3] torchaudio==2.4.0
[pip3] torchmetrics==1.4.0.post0
[pip3] torchvision==0.19.0
[conda] Could not collect

FYI, I'm using micromamba, which I think the collect_env script should ideally group together with conda, but it doesn't.

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

jhavukainen commented 2 months ago

Thanks for reporting this @nisheethlahoti! I can reproduce this locally as well.

nisheethlahoti commented 2 months ago

My first guess at the cause was that negative axes were being converted incorrectly somewhere, using ndim - 1 + dim instead of ndim + dim. However, the error happens only when dim <= -3, and it seems unlikely that dim=-1 and dim=-2 are handled separately from everything else.
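The reasoning above can be checked with a tiny pure-Python sketch. `correct_wrap` and `off_by_one_wrap` are illustrative names, not code from PyTorch: the first applies ndim + dim, the second the hypothesized ndim - 1 + dim. A uniform off-by-one would shift every negative dim, including -1 and -2, so on its own it would not explain a failure that shows up only for dim <= -3.

```python
def correct_wrap(dim: int, ndim: int) -> int:
    # Standard conversion: a negative dim counts back from ndim.
    return ndim + dim


def off_by_one_wrap(dim: int, ndim: int) -> int:
    # The hypothesized bug: an extra "- 1" in the conversion.
    return ndim - 1 + dim


ndim = 5  # x in the repro is 5-D
for dim in range(-1, -ndim - 1, -1):
    print(dim, correct_wrap(dim, ndim), off_by_one_wrap(dim, ndim))
# -1 4 3
# -2 3 2
# -3 2 1
# -4 1 0
# -5 0 -1
```

Every row differs, so if this were the bug, dim=-1 and dim=-2 would be affected too.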