Happyholic1203 opened 7 months ago
I think the exact same thing that happened in #117549 and #116769 is happening again in torch.nn.functional.linear, which is implemented in aten/src/ATen/native/mps/operations/Linear.mm:_mps_linear.
When the input size is large enough (see the reproduction code in the issue description), the MPS matmul silently returns incorrect results.
My idea is to migrate the patch for #117549 to _mps_linear:
if (use_metal_mm(self, other, output)) {
  return do_metal_mm(self, other, output);
}
That does improve correctness somewhat (the error stddev drops from roughly 2.0 to 1.4, when it should be 0), but it still doesn't produce results close to the CPU results.
I'm more than happy to submit a patch, but this is my first time debugging the PyTorch backend and it's taking quite a bit of time.
Any ideas, suggestions, or even pointers in the right direction would be very much appreciated.
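In case it helps anyone digging into this, here is a minimal sketch (not part of the original report) that sweeps the leading batch dimension to find where the MPS result starts diverging from the CPU reference; the shapes mirror the reproduction code in this issue:
import torch
import torch.nn.functional as F

torch.manual_seed(1234)
w = torch.randn(50304, 1)

# Sweep the leading batch dimension; the repro shows the error
# surfacing at batch size 9 but not at batch size 1.
for batch in range(1, 10):
    x = torch.randn(batch, 1024, 1)
    out_cpu = F.linear(x, w)
    out_mps = F.linear(x.to('mps'), w.to('mps')).to('cpu')
    print(f'batch={batch}: max abs error = {(out_cpu - out_mps).abs().max().item():.6f}')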
cc @kulinseth @malfet
Does anyone want to try it again? As of 2.4, the above-mentioned example works fine for me.
The bug is still there :(
macOS 13.6.7, MacBook Air M1 16GB, PyTorch version 2.5.0.dev20240806.
I used the minimal reproduction code from above.
I used the workaround, but it appears that torch.matmul() on MPS is >4x slower than just running nn.Linear on the CPU:
linear_out_cpu time: 0.26 sec
linear_out_mps time: 0.41 sec
linear_out_mps_split_and_stack time: 1.59 sec
matmul_out_cpu time: 1.10 sec
matmul_out_mps time: 1.14 sec
linear_out_cpu.allclose(matmul_out_cpu)=True
linear_out_cpu.allclose(matmul_out_mps)=True
=== BUG ===
linear_out_cpu.allclose(linear_out_mps)=False
(linear_out_cpu - linear_out_mps).std()=tensor(1.3945)
=== PROOF (& WORKAROUND) ===
linear_out_cpu.allclose(linear_out_mps_split_and_stack)=True
Code used:
import time
import torch
import torch.nn.functional as F
from torch import nn
torch.manual_seed(1234)
w = torch.randn(50304, 1)
x = torch.randn(9, 1024, 1)
# these 5 operations should be equivalent
start_time = time.time()
linear_out_cpu = F.linear(x, w, None)
print("linear_out_cpu time: {:.2f} sec".format(time.time() - start_time))
start_time = time.time()
linear_out_mps = F.linear(x.to('mps'), w.to('mps'), None).to('cpu')
print("linear_out_mps time: {:.2f} sec".format(time.time() - start_time))
start_time = time.time()
linear_out_mps_split_and_stack = torch.stack([F.linear(xb.to('mps'), w.to('mps')) for xb in x]).to('cpu')
print("linear_out_mps_split_and_stack time: {:.2f} sec".format(time.time() - start_time))
start_time = time.time()
matmul_out_cpu = x@w.T
print("matmul_out_cpu time: {:.2f} sec".format(time.time() - start_time))
start_time = time.time()
matmul_out_mps = (x.to('mps')@w.T.to('mps')).to('cpu')
print("matmul_out_mps time: {:.2f} sec".format(time.time() - start_time))
# All True, as expected
print(f'{linear_out_cpu.allclose(matmul_out_cpu)=}')
print(f'{linear_out_cpu.allclose(matmul_out_mps)=}')
print('=== BUG ===')
# BUG ==> Expect this to be True, got False instead
print(f'{linear_out_cpu.allclose(linear_out_mps)=}')
print(f'{(linear_out_cpu - linear_out_mps).std()=}')
print('=== PROOF (& WORKAROUND) ===')
# Splitting the batches can "workaround" the above BUG
print(f'{linear_out_cpu.allclose(linear_out_mps_split_and_stack)=}')
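For completeness, the split-and-stack workaround above can be wrapped in a small helper; this is just a sketch, and linear_mps_rowwise is a hypothetical name, not a PyTorch API:
def linear_mps_rowwise(x, w, b=None):
    # Hypothetical helper: apply F.linear one leading-batch slice at a time,
    # which sidesteps the incorrect results seen with the full batched input on MPS.
    return torch.stack([F.linear(xb, w, b) for xb in x])

# Usage (tensors already on the MPS device):
# out = linear_mps_rowwise(x.to('mps'), w.to('mps'))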
Full system specs:
PyTorch version: 2.5.0.dev20240806
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.6.7 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.11.4 (main, Jul 21 2023, 12:17:16) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.7-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1
Versions of relevant libraries:
[pip3] efficientnet_pytorch==0.7.1
[pip3] numpy==1.26.4
[pip3] pytorch-lightning==2.4.0
[pip3] pytorch-ranger==0.1.1
[pip3] rotary-embedding-torch==0.3.5
[pip3] segmentation-models-pytorch==0.3.3
[pip3] torch==2.5.0.dev20240806
[pip3] torch-audiomentations==0.11.1
[pip3] torch-optimizer==0.1.0
[pip3] torch-pitch-shift==1.2.4
[pip3] torch-stoi==0.2.1
[pip3] torchaudio==2.4.0.dev20240806
[pip3] torchmetrics==0.11.4
[pip3] torchseg==0.0.1a1
[pip3] torchvision==0.20.0.dev20240806
[conda] No relevant packages
I'm unable to reproduce this issue.
import torch
import torch.nn.functional as F
torch.manual_seed(1234)
w = torch.randn(50304, 1)
x = torch.randn(9, 1024, 1)
# these 5 operations should be equivalent
linear_out_cpu = F.linear(x, w, None)
linear_out_mps = F.linear(x.to('mps'), w.to('mps'), None).to('cpu')
linear_out_mps_split_and_stack = torch.stack([F.linear(xb.to('mps'), w.to('mps')) for xb in x]).to('cpu')
matmul_out_cpu = x@w.T
matmul_out_mps = (x.to('mps')@w.T.to('mps')).to('cpu')
# All True, as expected
print(f'{linear_out_cpu.allclose(matmul_out_cpu)=}')
print(f'{linear_out_cpu.allclose(matmul_out_mps)=}')
print('=== BUG ===')
# BUG ==> Expect this to be True, got False instead
print(f'{linear_out_cpu.allclose(linear_out_mps)=}')
print(f'{(linear_out_cpu - linear_out_mps).std()=}')
print('=== PROOF (& WORKAROUND) ===')
# Splitting the batches can "workaround" the above BUG
print(f'{linear_out_cpu.allclose(linear_out_mps_split_and_stack)=}')
yields
linear_out_cpu.allclose(matmul_out_cpu)=True
linear_out_cpu.allclose(matmul_out_mps)=True
=== BUG ===
linear_out_cpu.allclose(linear_out_mps)=True
(linear_out_cpu - linear_out_mps).std()=tensor(0.)
=== PROOF (& WORKAROUND) ===
linear_out_cpu.allclose(linear_out_mps_split_and_stack)=True
PyTorch version: 2.5.0a0+git648fc6c
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 14.6.1 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.30.1
Libc version: N/A
Python version: 3.12.4 | packaged by conda-forge | (main, Jun 17 2024, 10:13:44) [Clang 16.0.6 ] (64-bit runtime)
Python platform: macOS-14.6.1-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M3 Max
Versions of relevant libraries:
[pip3] flake8==6.1.0
[pip3] flake8-bugbear==23.3.23
[pip3] flake8-comprehensions==3.15.0
[pip3] flake8-executable==2.1.3
[pip3] flake8-logging-format==0.9.0
[pip3] flake8-pyi==23.3.1
[pip3] flake8-simplify==0.19.3
[pip3] mypy==1.10.0
[pip3] mypy-extensions==1.0.0
[pip3] numpy==2.0.1
[pip3] optree==0.12.1
[pip3] torch==2.5.0a0+git8d40458
[pip3] torch-tb-profiler==0.4.3
[pip3] torchvision==0.20.0a0+0d80848
[conda] numpy 2.0.1 pypi_0 pypi
[conda] optree 0.12.1 pypi_0 pypi
[conda] torch 2.5.0a0+git8d40458 dev_0 <develop>
[conda] torch-tb-profiler 0.4.3 pypi_0 pypi
[conda] torchfix 0.4.0 pypi_0 pypi
[conda] torchvision 0.20.0a0+0d80848 dev_0 <develop>
@hvaara, maybe it's because of the macOS version. Do you have a Mac running macOS 13.x?
@dj-nuo That could be the case, although the OP reported OS: macOS 14.2.1 (arm64) when they first experienced it. My comment was meant more as an additional data point, and also to remind myself what my finding was.
What I intend to do is to add a regression test for it in PyTorch. That way it'll be run for all future versions as well :)
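A regression test could look roughly like the sketch below (the test name and placement are assumptions on my part; the shapes come from the repro in this issue):
import unittest
import torch
import torch.nn.functional as F

class TestMPSLinearLargeBatch(unittest.TestCase):
    # Sketch of a regression test; shapes taken from the repro in this issue.
    @unittest.skipUnless(torch.backends.mps.is_available(), "requires MPS")
    def test_linear_large_batched_input(self):
        torch.manual_seed(1234)
        w = torch.randn(50304, 1)
        x = torch.randn(9, 1024, 1)
        out_cpu = F.linear(x, w)
        out_mps = F.linear(x.to('mps'), w.to('mps')).to('cpu')
        self.assertTrue(out_cpu.allclose(out_mps))

if __name__ == '__main__':
    unittest.main()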
Are you on macOS 13.6.9? Are you still able to reproduce with the latest PyTorch?
@hvaara, I just reproduced it in a clean environment with the latest PyTorch (2.5.0.dev20240822 instead of 2.5.0.dev20240806 previously):
linear_out_cpu.allclose(matmul_out_cpu)=True
linear_out_cpu.allclose(matmul_out_mps)=True
=== BUG ===
linear_out_cpu.allclose(linear_out_mps)=False
(linear_out_cpu - linear_out_mps).std()=tensor(1.3945)
=== PROOF (& WORKAROUND) ===
linear_out_cpu.allclose(linear_out_mps_split_and_stack)=True
Code used:
import torch
import torch.nn.functional as F
torch.manual_seed(1234)
w = torch.randn(50304, 1)
x = torch.randn(9, 1024, 1)
# these 5 operations should be equivalent
linear_out_cpu = F.linear(x, w, None)
linear_out_mps = F.linear(x.to('mps'), w.to('mps'), None).to('cpu')
linear_out_mps_split_and_stack = torch.stack([F.linear(xb.to('mps'), w.to('mps')) for xb in x]).to('cpu')
matmul_out_cpu = x@w.T
matmul_out_mps = (x.to('mps')@w.T.to('mps')).to('cpu')
# All True, as expected
print(f'{linear_out_cpu.allclose(matmul_out_cpu)=}')
print(f'{linear_out_cpu.allclose(matmul_out_mps)=}')
print('=== BUG ===')
# BUG ==> Expect this to be True, got False instead
print(f'{linear_out_cpu.allclose(linear_out_mps)=}')
print(f'{(linear_out_cpu - linear_out_mps).std()=}')
print('=== PROOF (& WORKAROUND) ===')
# Splitting the batches can "workaround" the above BUG
print(f'{linear_out_cpu.allclose(linear_out_mps_split_and_stack)=}')
Environment:
PyTorch version: 2.5.0.dev20240822
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A
OS: macOS 13.6.7 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.1.0.2.5)
CMake version: Could not collect
Libc version: N/A
Python version: 3.11.4 (main, Jul 21 2023, 12:17:16) [Clang 14.0.3 (clang-1403.0.22.14.1)] (64-bit runtime)
Python platform: macOS-13.6.7-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True
CPU:
Apple M1
Versions of relevant libraries:
[pip3] numpy==1.26.4
[pip3] torch==2.5.0.dev20240822
[pip3] torchaudio==2.4.0.dev20240822
[pip3] torchvision==0.20.0.dev20240822
[conda] No relevant packages
🐛 Describe the bug
Running F.linear on MPS produces a non-negligible error (stddev > 1) when the input size is large enough (see the following code snippet to reproduce). However, when we use torch.stack([F.linear(...)]) to perform effectively the same operation on MPS, there is no error, showing that there is a side effect when the input to F.linear is large enough. (In this case, a batch size of 9 causes the error to surface, while a batch size of 1 doesn't.)
Minimal code to reproduce:
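import torch
import torch.nn.functional as F
torch.manual_seed(1234)
w = torch.randn(50304, 1)
x = torch.randn(9, 1024, 1)
# these 5 operations should be equivalent
linear_out_cpu = F.linear(x, w, None)
linear_out_mps = F.linear(x.to('mps'), w.to('mps'), None).to('cpu')
linear_out_mps_split_and_stack = torch.stack([F.linear(xb.to('mps'), w.to('mps')) for xb in x]).to('cpu')
matmul_out_cpu = x@w.T
matmul_out_mps = (x.to('mps')@w.T.to('mps')).to('cpu')
# All True, as expected
print(f'{linear_out_cpu.allclose(matmul_out_cpu)=}')
print(f'{linear_out_cpu.allclose(matmul_out_mps)=}')
print('=== BUG ===')
# BUG ==> Expect this to be True, got False instead
print(f'{linear_out_cpu.allclose(linear_out_mps)=}')
print(f'{(linear_out_cpu - linear_out_mps).std()=}')
print('=== PROOF (& WORKAROUND) ===')
# Splitting the batches can "workaround" the above BUG
print(f'{linear_out_cpu.allclose(linear_out_mps_split_and_stack)=}')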
I think #109457 is related, and possibly #117826
Versions
cc @kulinseth @albanD @malfet @DenisVieriu97 @razarmehr