mberr / torch-max-mem

Decorators for maximizing memory utilization with PyTorch & CUDA
https://torch-max-mem.readthedocs.io/en/latest/
MIT License

Support for Apple MPS #14

Open cthoyt opened 1 year ago

cthoyt commented 1 year ago

The improvements to PyKEEN in https://github.com/pykeen/pykeen/pull/975 don't work anymore now that we have fully externalized AMO (automatic memory optimization) to this package. Can we think about how to make sure that the CPU warnings also apply to MPS? It uses shared memory and can get out of hand in the same way.

Reminder: I have an Apple system with MPS, so I am happy to test.

mberr commented 1 year ago

@cthoyt, if you have time, you could check #15 with the example from https://github.com/mberr/torch-max-mem#-getting-started

mberr commented 1 year ago

Hang on, the current state explicitly does not warn for MPS by default 😅

cthoyt commented 1 year ago

I ran

import torch
from torch_max_mem import maximize_memory_utilization

@maximize_memory_utilization()
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )

x = torch.rand(100, 100, device="mps")
y = torch.rand(200, 100, device="mps")
z = knn(x, y, batch_size=x.shape[0])
print(z)

Results:

Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
tensor([[ 2, 17, 58, 19, 29, 11, 72, 89, 94, 58, 17, 81, 41, 24, 64, 37, 38, 77,
         76, 33, 17, 22,  7, 37,  7, 19, 72, 11, 64, 15, 15,  6, 49, 77, 22, 69,
          3, 81, 36, 89,  6, 50, 19, 82, 45, 81,  4, 37, 37, 89, 39, 30,  2, 37,
         89, 63, 48, 19, 30, 37, 93, 13, 36, 56, 53, 35, 45, 20, 88, 17, 51, 49,
         37, 89, 19, 39,  1, 71, 28, 14, 89, 41, 13, 75, 45, 83, 72, 20, 46,  3,
         53, 49, 39,  4, 85, 19, 28, 19, 88, 81, 19, 70, 56, 13, 19, 19, 89, 59,
         19, 53, 13, 37, 24,  6, 56,  3, 37,  9, 56, 85, 30, 77, 20, 89, 56, 13,
         64, 64,  1, 74,  4, 85, 22, 52, 17, 77, 64, 17, 53, 53, 47, 28, 70, 51,
         76, 51, 67, 47, 37, 24, 51, 48, 62, 45, 15, 54, 17,  7, 25, 27, 97, 17,
         35, 20, 31, 56, 36, 72, 50, 85, 53, 46, 57, 39, 13, 52, 89, 95, 50, 19,
         57, 27,  2, 22,  4, 43, 37, 85, 17,  4, 51, 37, 50, 20, 69, 77, 11, 89,
         15, 59],
        [28, 19, 37, 13, 90,  3, 64, 93, 52, 89, 45, 91, 70, 15, 57, 19, 51, 53,
         36, 63, 87, 48, 19, 64, 64, 51, 64, 14, 36, 56, 36, 19, 89, 64, 85, 13,
         65, 88,  2, 17, 10, 64, 59, 37,  5, 51, 17, 93, 38, 19, 17, 85, 13, 49,
         52, 71, 20, 81, 45, 51, 37, 17, 39, 31, 11, 78, 91, 15, 28, 35, 20, 36,
         36, 21, 53, 98,  2, 73, 69, 85, 88, 56, 89, 91, 13, 56, 37, 45, 37, 64,
         71,  4, 60, 36, 17, 81, 81, 18, 94, 56, 57, 66, 96, 29,  2, 99, 77, 17,
          0,  3, 66,  2, 69,  0, 53,  5, 17, 37, 13, 36, 45, 89, 37, 19, 96, 19,
         45, 25, 19, 13, 45, 15,  2, 98, 54, 49, 24, 30, 72, 36, 37, 57, 96, 86,
         93, 37, 78, 24, 51, 93, 77, 86, 17, 85, 30, 99, 13, 19, 48, 54,  5, 35,
         15, 75, 53, 47, 94, 13, 89, 37, 90, 91,  3, 81, 40,  3, 94, 19, 15, 69,
         59, 94, 39, 83, 85,  7,  2, 81, 47, 13, 81, 89, 64, 46, 81, 36, 88, 21,
         37, 45],
        [87, 14, 35, 26,  1, 78,  2, 37, 56, 36, 89,  3, 36, 51, 51, 36, 37, 81,
         89,  5, 13, 36, 13, 20, 65, 89, 89, 42, 89, 39, 89, 95, 87, 57, 75,  6,
          2, 85, 51, 31,  0, 37, 89, 89, 89, 64, 10, 89,  2, 53, 84, 20, 43, 67,
         26, 27, 87, 33, 96, 48, 71, 32,  2, 33,  2, 45, 62, 46, 39, 23, 19, 45,
         89,  1, 68, 87, 40,  6, 12, 22, 21, 12, 17, 24, 48, 89, 85, 82,  0, 15,
         42,  1, 89, 11, 52,  2, 55, 52, 36, 38, 67,  5, 17, 80, 17, 87, 64, 71,
          2, 33, 83, 22, 31, 15, 57, 64, 85, 15, 29, 50, 53, 13, 47,  2, 74, 45,
         20, 89, 96, 56, 49, 14, 34, 72, 89, 20,  9, 49, 10, 48, 83, 33, 63, 53,
         39, 89,  4, 37, 81, 37, 17, 81, 88, 89, 48, 15, 89,  1, 57, 64, 95, 64,
         64, 67, 22, 78, 15, 90, 85, 35, 36, 64, 54, 13, 75, 53, 11, 22, 40, 53,
         38,  3, 96, 98, 19, 64,  4,  6,  4, 21, 33, 39, 31,  4, 17, 15, 15, 85,
         61, 47]], device='mps:0')
cthoyt commented 1 year ago

Should it be possible to run knn(x, y) with no explicit batch size?

mberr commented 1 year ago

No, only if you set a default value for batch_size in def knn(x, y, batch_size, k: int = 3):
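
For example (a minimal sketch; the default value of 1_024 is an arbitrary choice), giving batch_size a default makes knn(x, y) work without an explicit batch size, and the decorator still shrinks it on OOM:

import torch
from torch_max_mem import maximize_memory_utilization

@maximize_memory_utilization()
def knn(x, y, batch_size: int = 1_024, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )

x = torch.rand(100, 100)
y = torch.rand(200, 100)
z = knn(x, y)  # no explicit batch size needed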

mberr commented 1 year ago

The result looks as expected.

If you are brave, you can try increasing the first dimension of x / y until you actually need to lower the batch size. If the behaviour on Apple is similar to the behaviour on Linux, you'll see the following three zones as you increase the memory load:

  1. everything runs smooth with a single batch
  2. the library starts decreasing the batch size until it runs through
  3. the OOM killer kills your process in the first iteration of the batch size search

(4. eventually, you'll get an OOM error while allocating the input tensors themselves)
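
For illustration, a minimal sketch of this kind of halving retry (not the library's actual implementation; run_batch and the blanket RuntimeError handling are simplifications):

def halving_search(run_batch, batch_size: int):
    # sketch only: a real implementation first checks that the error is OOM-like
    while batch_size >= 1:
        try:
            # zone 1: if this succeeds right away, a single batch was enough
            return run_batch(batch_size), batch_size
        except RuntimeError:
            # zone 2: shrink the batch size and retry
            batch_size //= 2
    # zone 3: even batch_size=1 did not fit (or the OOM killer struck earlier)
    raise MemoryError("Execution did not succeed even with batch_size=1.")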

cthoyt commented 1 year ago

After increasing the first dimension from 100 up by powers of 10, when I got to

x = torch.rand(100000, 100, device="mps")
y = torch.rand(200000, 100, device="mps")
z = knn(x, y, batch_size=x.shape[0])
print(z)

I got the following error:

Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=x.shape[0])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 419, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 320, in wrapper_maximize_memory_utilization
    raise error
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 309, in wrapper_maximize_memory_utilization
    func(*bound_arguments.args, **p_kwargs, **bound_arguments.kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 8, in knn
    [
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 9, in <listcomp>
    torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/.virtualenvs/indra/lib/python3.11/site-packages/torch/functional.py", line 1222, in cdist
    return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Invalid buffer size: 74.51 GB
mberr commented 1 year ago

Interesting; this seems to be a cdist-specific OOM-like error? 🤔

EDIT: also described here: https://discuss.pytorch.org/t/runtime-error-invalid-buffer-size-when-calculating-cosine-similarity/152088

mberr commented 1 year ago

Added that error to the handled OOM errors in a0edfbe.
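
For context, a sketch of the kind of message check this amounts to (the function name is hypothetical, not the library's API):

def is_oom_error(error: RuntimeError) -> bool:
    """Sketch: decide whether a RuntimeError should be treated as out-of-memory."""
    message = str(error)
    # "Invalid buffer size" is what cdist on MPS raised above instead of a regular OOM error
    return "Invalid buffer size" in message or "out of memory" in message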

cthoyt commented 1 year ago

Great, just pulled and re-ran with the big number. After a totally epic slowdown of my computer and funny audio noises, here's the next output:

Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=x.shape[0])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 322, in wrapper_maximize_memory_utilization
    raise error
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 311, in wrapper_maximize_memory_utilization
    func(*bound_arguments.args, **p_kwargs, **bound_arguments.kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 8, in knn
    [
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 9, in <listcomp>
    torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/.virtualenvs/indra/lib/python3.11/site-packages/torch/functional.py", line 1222, in cdist
    return _VF.cdist(x1, x2, p, None)  # type: ignore[attr-defined]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: MPS backend out of memory (MPS allocated: 119.30 MB, other allocations: 43.18 GB, max allowed: 36.27 GB). Tried to allocate 4.76 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
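
(As the message itself suggests, the MPS high watermark can be lifted via PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0; a sketch of setting it, assuming it is set before torch touches MPS, and keeping in mind the message's own warning that this may cause system failure:)

import os

# assumption: must be set before torch allocates anything on the MPS device
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

import torch  # noqa: E402

x = torch.rand(100_000, 100, device="mps")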
mberr commented 1 year ago

Re:

> Should it be possible to run knn(x, y) with no explicit batch size?

It is relatively easy to add automatic maximum batch size inference on top; you just need to make sure that this decorator is applied on top of (i.e., wraps) the maximization decorator. The actual calculation of the maximum batch size is problem-specific, though, so it is nothing to easily include here:

from __future__ import annotations

import functools

import torch

from torch_max_mem import maximize_memory_utilization

def infer_maximum_batch_size(f):
    """Fill in a missing batch_size with the largest possible value, the full first dimension of x."""

    @functools.wraps(f)
    def wrapped(x: torch.Tensor, *args, batch_size: int | None = None, **kwargs):
        if batch_size is None:
            # start with the maximum batch size; the inner decorator shrinks it on OOM
            batch_size = x.shape[0]
        return f(x, *args, batch_size=batch_size, **kwargs)

    return wrapped

@infer_maximum_batch_size
@maximize_memory_utilization()
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )

device = "cpu"
x = torch.rand(100, 100, device=device)
y = torch.rand(200, 100, device=device)
knn(x, y)

EDIT: this double decoration might be something we can look into for PyKEEN; it also allows selecting different maximum batch sizes for different devices.
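
For instance, a hypothetical device-aware variant of the wrapper above (the heuristic split is made up for illustration) could replace infer_maximum_batch_size in the stack:

from __future__ import annotations

import functools

import torch

def infer_device_specific_batch_size(f):
    """Sketch: pick a default batch size based on the device of the first argument."""

    @functools.wraps(f)
    def wrapped(x: torch.Tensor, *args, batch_size: int | None = None, **kwargs):
        if batch_size is None:
            # made-up heuristic: be more conservative on shared-memory devices (cpu / mps)
            batch_size = x.shape[0] if x.device.type == "cuda" else max(1, x.shape[0] // 8)
        return f(x, *args, batch_size=batch_size, **kwargs)

    return wrapped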

mberr commented 1 year ago

> Great, just pulled and re-ran with the big number. After a totally epic slowdown of my computer and funny audio noises, here's the next output:

ca24138

cthoyt commented 1 year ago

Okay, next error, and this one comes from torch-max-mem itself:

Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=x.shape[0])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 333, in wrapper_maximize_memory_utilization
    raise MemoryError(
MemoryError: Execution did not even succeed with ('batch_size',) all equal to 1.
mberr commented 1 year ago

I already merged https://github.com/mberr/torch-max-mem/pull/15 to bring the warnings and some of the unrelated improvements to main; the actual optimization of batch sizes does not yet seem to work (according to @cthoyt's tests), and I am not sure where the actual error lies 😕

mberr commented 1 year ago

This might be related: https://github.com/pytorch/pytorch/issues/105839

mberr commented 8 months ago

In theory, tests now run on MPS, too; however, running the example from https://github.com/mberr/torch-max-mem/issues/14#issuecomment-1732274087 does not succeed; the error seems to be related to https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773