pytorch / pytorch

Tensors and Dynamic neural networks in Python with strong GPU acceleration
https://pytorch.org

MPS backend leaks memory when input sizes vary #132596

Open · llllvvuu opened this issue 1 month ago

llllvvuu commented 1 month ago

🐛 Describe the bug

Possibly similar to old issues with the CPU backend: https://github.com/pytorch/pytorch/issues/27971 https://github.com/pytorch/pytorch/issues/32037

In my case both CPU and CUDA work fine, and only MPS has the issue. Is there anything similar to LRU_CACHE_CAPACITY but for MPS?

import resource
import torch

device = torch.device("mps")  # no leak on "cpu" or "cuda"
model = torch.nn.Linear(32, 32, device=device)
for i in range(1, 8193):
    _ = model(torch.randn(8192, 32, device=device))  # memory usage fixed at ~176MB
    print(f"(good) {i} / 8192, Memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")
for i in range(1, 8193):
    _ = model(torch.randn(i, 32, device=device))  # memory growth up to ~3GB at ~1500 iters
    print(f"(bad) {i} / 8192, Memory: {resource.getrusage(resource.RUSAGE_SELF).ru_maxrss}")

Versions

PyTorch version: 2.4.0
Is debug build: False
CUDA used to build PyTorch: None
ROCM used to build PyTorch: N/A

OS: macOS 14.2 (arm64)
GCC version: Could not collect
Clang version: 15.0.0 (clang-1500.3.9.4)
CMake version: version 3.29.6
Libc version: N/A

Python version: 3.12.3 (main, Apr 15 2024, 17:43:11) [Clang 17.0.6 ] (64-bit runtime)
Python platform: macOS-14.2-arm64-arm-64bit
Is CUDA available: False
CUDA runtime version: No CUDA
CUDA_MODULE_LOADING set to: N/A
GPU models and configuration: No CUDA
Nvidia driver version: No CUDA
cuDNN version: No CUDA
HIP runtime version: N/A
MIOpen runtime version: N/A
Is XNNPACK available: True

CPU: Apple M1

Versions of relevant libraries:
[pip3] torch==2.4.0
[conda] Could not collect

cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen

tringwald commented 1 month ago

From what I've seen in this file, there does not seem to be a way to limit the cache size right now.
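
One way to narrow down where the growth happens (assuming the torch.mps counters cover the cached buffers) is to log ru_maxrss next to torch.mps.driver_allocated_memory(); if RSS keeps climbing while the driver-allocated figure stays flat, the growth is outside the MPS allocator (e.g. per-shape cached graphs/kernels) and torch.mps.empty_cache() would not reclaim it. Rough sketch:

import resource
import torch

device = torch.device("mps")
model = torch.nn.Linear(32, 32, device=device)
for i in range(1, 2049):
    _ = model(torch.randn(i, 32, device=device))
    if i % 256 == 0:
        rss = resource.getrusage(resource.RUSAGE_SELF).ru_maxrss  # bytes on macOS, kilobytes on Linux
        mps = torch.mps.driver_allocated_memory()  # bytes allocated by the Metal driver for this process
        print(f"iter {i}: ru_maxrss={rss}, mps_driver_allocated={mps}")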

skotapati commented 1 month ago

Hi @llllvvuu, thank you for reporting this issue and providing the repro. I've confirmed the memory leak is occurring; stay tuned for further updates.