cthoyt opened this issue 1 year ago
@cthoyt, if you have time, you could check #15 with the example from https://github.com/mberr/torch-max-mem#-getting-started
Hang on, the current state explicitly does not warn for mps by default 😅
I ran
import torch

from torch_max_mem import maximize_memory_utilization


@maximize_memory_utilization()
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )


x = torch.rand(100, 100, device="mps")
y = torch.rand(200, 100, device="mps")
z = knn(x, y, batch_size=x.shape[0])
print(z)
Results:
Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
tensor([[ 2, 17, 58, 19, 29, 11, 72, 89, 94, 58, 17, 81, 41, 24, 64, 37, 38, 77,
76, 33, 17, 22, 7, 37, 7, 19, 72, 11, 64, 15, 15, 6, 49, 77, 22, 69,
3, 81, 36, 89, 6, 50, 19, 82, 45, 81, 4, 37, 37, 89, 39, 30, 2, 37,
89, 63, 48, 19, 30, 37, 93, 13, 36, 56, 53, 35, 45, 20, 88, 17, 51, 49,
37, 89, 19, 39, 1, 71, 28, 14, 89, 41, 13, 75, 45, 83, 72, 20, 46, 3,
53, 49, 39, 4, 85, 19, 28, 19, 88, 81, 19, 70, 56, 13, 19, 19, 89, 59,
19, 53, 13, 37, 24, 6, 56, 3, 37, 9, 56, 85, 30, 77, 20, 89, 56, 13,
64, 64, 1, 74, 4, 85, 22, 52, 17, 77, 64, 17, 53, 53, 47, 28, 70, 51,
76, 51, 67, 47, 37, 24, 51, 48, 62, 45, 15, 54, 17, 7, 25, 27, 97, 17,
35, 20, 31, 56, 36, 72, 50, 85, 53, 46, 57, 39, 13, 52, 89, 95, 50, 19,
57, 27, 2, 22, 4, 43, 37, 85, 17, 4, 51, 37, 50, 20, 69, 77, 11, 89,
15, 59],
[28, 19, 37, 13, 90, 3, 64, 93, 52, 89, 45, 91, 70, 15, 57, 19, 51, 53,
36, 63, 87, 48, 19, 64, 64, 51, 64, 14, 36, 56, 36, 19, 89, 64, 85, 13,
65, 88, 2, 17, 10, 64, 59, 37, 5, 51, 17, 93, 38, 19, 17, 85, 13, 49,
52, 71, 20, 81, 45, 51, 37, 17, 39, 31, 11, 78, 91, 15, 28, 35, 20, 36,
36, 21, 53, 98, 2, 73, 69, 85, 88, 56, 89, 91, 13, 56, 37, 45, 37, 64,
71, 4, 60, 36, 17, 81, 81, 18, 94, 56, 57, 66, 96, 29, 2, 99, 77, 17,
0, 3, 66, 2, 69, 0, 53, 5, 17, 37, 13, 36, 45, 89, 37, 19, 96, 19,
45, 25, 19, 13, 45, 15, 2, 98, 54, 49, 24, 30, 72, 36, 37, 57, 96, 86,
93, 37, 78, 24, 51, 93, 77, 86, 17, 85, 30, 99, 13, 19, 48, 54, 5, 35,
15, 75, 53, 47, 94, 13, 89, 37, 90, 91, 3, 81, 40, 3, 94, 19, 15, 69,
59, 94, 39, 83, 85, 7, 2, 81, 47, 13, 81, 89, 64, 46, 81, 36, 88, 21,
37, 45],
[87, 14, 35, 26, 1, 78, 2, 37, 56, 36, 89, 3, 36, 51, 51, 36, 37, 81,
89, 5, 13, 36, 13, 20, 65, 89, 89, 42, 89, 39, 89, 95, 87, 57, 75, 6,
2, 85, 51, 31, 0, 37, 89, 89, 89, 64, 10, 89, 2, 53, 84, 20, 43, 67,
26, 27, 87, 33, 96, 48, 71, 32, 2, 33, 2, 45, 62, 46, 39, 23, 19, 45,
89, 1, 68, 87, 40, 6, 12, 22, 21, 12, 17, 24, 48, 89, 85, 82, 0, 15,
42, 1, 89, 11, 52, 2, 55, 52, 36, 38, 67, 5, 17, 80, 17, 87, 64, 71,
2, 33, 83, 22, 31, 15, 57, 64, 85, 15, 29, 50, 53, 13, 47, 2, 74, 45,
20, 89, 96, 56, 49, 14, 34, 72, 89, 20, 9, 49, 10, 48, 83, 33, 63, 53,
39, 89, 4, 37, 81, 37, 17, 81, 88, 89, 48, 15, 89, 1, 57, 64, 95, 64,
64, 67, 22, 78, 15, 90, 85, 35, 36, 64, 54, 13, 75, 53, 11, 22, 40, 53,
38, 3, 96, 98, 19, 64, 4, 6, 4, 21, 33, 39, 31, 4, 17, 15, 15, 85,
61, 47]], device='mps:0')
Should it be possible to run knn(x, y) with no explicit batch size?
No; only if you set a default value for batch_size in def knn(x, y, batch_size, k: int = 3).
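A minimal sketch of that option, assuming a fixed default is acceptable (the 1024 below is an arbitrary placeholder, not a recommendation):

import torch

from torch_max_mem import maximize_memory_utilization


# Hypothetical variant: giving batch_size a default makes knn(x, y) callable without
# passing a batch size explicitly; the decorator can still shrink the value on OOM.
@maximize_memory_utilization()
def knn(x, y, batch_size: int = 1024, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )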
The result looks as expected: mps is not considered a safe device to run AMO on. If you are brave, you can try to increase the first dimension of x / y until you actually need to lower the batch size. If the behaviour on Apple is similar to the behaviour on Linux, you'll see the following three zones while increasing the maximum memory load:
1. the full batch size goes through unchanged,
2. the initial batch size fails and AMO reduces it until the computation succeeds,
3. even a batch size of 1 fails and you get a MemoryError
(4. you'll get an OOM error while allocating the input tensors)
After increasing the first dimension from 100 up by powers of 10, when I got to
x = torch.rand(100000, 100, device="mps")
y = torch.rand(200000, 100, device="mps")
z = knn(x, y, batch_size=x.shape[0])
print(z)
I got the following error:
Traceback (most recent call last):
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
z = knn(x, y, batch_size=x.shape[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 419, in inner
result, self.parameter_value[h] = wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 320, in wrapper_maximize_memory_utilization
raise error
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 309, in wrapper_maximize_memory_utilization
func(*bound_arguments.args, **p_kwargs, **bound_arguments.kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 8, in knn
[
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 9, in <listcomp>
torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/.virtualenvs/indra/lib/python3.11/site-packages/torch/functional.py", line 1222, in cdist
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: Invalid buffer size: 74.51 GB
Interesting; this seems to be a cdist-specific OOM-like error? 🤔
EDIT: also described here: https://discuss.pytorch.org/t/runtime-error-invalid-buffer-size-when-calculating-cosine-similarity/152088
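If one wanted to treat this like an ordinary OOM, a rough sketch (the is_oom_like_error helper is made up, not current torch-max-mem behaviour) could match on the message, the same way cuda OOM errors are usually detected:

def is_oom_like_error(error: BaseException) -> bool:
    """Guess whether a RuntimeError signals memory exhaustion (hypothetical helper)."""
    if not isinstance(error, RuntimeError):
        return False
    message = str(error)
    # "Invalid buffer size" is what MPS raised for the oversized cdist allocation above;
    # "MPS backend out of memory" is the regular MPS OOM message.
    return any(
        needle in message
        for needle in ("CUDA out of memory", "MPS backend out of memory", "Invalid buffer size")
    )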
Great, just pulled and re-ran with the big number. After a totally epic slowdown of my computer and funny audio noises, here's the next output:
Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
z = knn(x, y, batch_size=x.shape[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
result, self.parameter_value[h] = wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 322, in wrapper_maximize_memory_utilization
raise error
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 311, in wrapper_maximize_memory_utilization
func(*bound_arguments.args, **p_kwargs, **bound_arguments.kwargs),
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 8, in knn
[
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 9, in <listcomp>
torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/.virtualenvs/indra/lib/python3.11/site-packages/torch/functional.py", line 1222, in cdist
return _VF.cdist(x1, x2, p, None) # type: ignore[attr-defined]
^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: MPS backend out of memory (MPS allocated: 119.30 MB, other allocations: 43.18 GB, max allowed: 36.27 GB). Tried to allocate 4.76 MB on private pool. Use PYTORCH_MPS_HIGH_WATERMARK_RATIO=0.0 to disable upper limit for memory allocations (may cause system failure).
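For completeness, the environment variable mentioned in the error message has to be set before torch initializes the MPS allocator; a sketch (and, as the message warns, disabling the watermark can bring the whole machine down):

import os

# Must be set before the first MPS allocation, so do it before importing torch.
os.environ["PYTORCH_MPS_HIGH_WATERMARK_RATIO"] = "0.0"

import torch  # noqa: E402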
Re: "Should it be possible to run knn(x, y) with no explicit batch size?"
It is relatively easy to add automatic maximum batch size inference on top; you just need to make sure to apply it after (i.e., outside of) the maximization decorator. The actual calculation of the maximum batch size is problem-specific though, so there is nothing to easily include here:
from __future__ import annotations

import functools

import torch

from torch_max_mem import maximize_memory_utilization


def infer_maximum_batch_size(f):
    @functools.wraps(f)
    def wrapped(x: torch.Tensor, *args, batch_size: int | None = None, **kwargs):
        if batch_size is None:
            batch_size = x.shape[0]
        return f(x, *args, batch_size=batch_size, **kwargs)

    return wrapped


@infer_maximum_batch_size
@maximize_memory_utilization()
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )


device = "cpu"
x = torch.rand(100, 100, device=device)
y = torch.rand(200, 100, device=device)
knn(x, y)
EDIT: this double decoration is maybe something we can look into for PyKEEN; it also allows selecting different maximum batch sizes for different devices.
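A rough sketch of what that device-dependent variant could look like, with made-up per-device defaults (not an existing torch-max-mem feature):

from __future__ import annotations

import functools

import torch


def infer_maximum_batch_size_per_device(f):
    # Hypothetical starting points per device type; None means "use the full batch".
    defaults = {"cuda": None, "mps": 4096, "cpu": 1024}

    @functools.wraps(f)
    def wrapped(x: torch.Tensor, *args, batch_size: int | None = None, **kwargs):
        if batch_size is None:
            batch_size = defaults.get(x.device.type) or x.shape[0]
        return f(x, *args, batch_size=batch_size, **kwargs)

    return wrapped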
Okay, next error, which is first party:
Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
z = knn(x, y, batch_size=x.shape[0])
^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
result, self.parameter_value[h] = wrapped(*args, **kwargs)
^^^^^^^^^^^^^^^^^^^^^^^^
File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 333, in wrapper_maximize_memory_utilization
raise MemoryError(
MemoryError: Execution did not even succeed with ('batch_size',) all equal to 1.
I already merged https://github.com/mberr/torch-max-mem/pull/15 to bring the warnings and some of the unrelated improvements to main; the actual optimization of batch sizes does not yet seem to work (according to @cthoyt's tests); I am not sure where the actual error lies 😕
This might be related: https://github.com/pytorch/pytorch/issues/105839
In theory, tests now run on mps, too; however, running the example from https://github.com/mberr/torch-max-mem/issues/14#issuecomment-1732274087 does not succeed; the error seems to be related to https://discuss.pytorch.org/t/mps-back-end-out-of-memory-on-github-action/189773
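A sketch of how the CI side could be worked around in the meantime (the skip condition below is an assumption, not what the test suite does today):

import os

import pytest
import torch

# Skip MPS tests when the backend is unavailable, or on CI runners where the
# MPS allocator is known to run out of memory (see the linked discussion).
skip_mps = pytest.mark.skipif(
    not torch.backends.mps.is_available() or os.environ.get("CI") == "true",
    reason="MPS unavailable or known to OOM on hosted CI runners",
)


@skip_mps
def test_cdist_topk_on_mps():
    x = torch.rand(100, 100, device="mps")
    y = torch.rand(200, 100, device="mps")
    indices = torch.cdist(x, y).topk(k=3, dim=0, largest=False).indices
    assert indices.shape == (3, 200)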
The improvements to PyKEEN in https://github.com/pykeen/pykeen/pull/975 don't work now that we have fully externalized AMO to this package. Can we think about how to make sure that the CPU warnings also apply to MPS, since it has shared memory and can get out of hand in the same way?
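A minimal sketch of what such a check could look like (the helper name and the treatment of mps as a shared-memory device are assumptions, not existing torch-max-mem API):

import torch


def uses_shared_memory(device: torch.device) -> bool:
    """Guess whether a device draws from host RAM (hypothetical helper).

    On Apple Silicon, mps uses the same unified memory as the CPU, so maxing out
    GPU allocations can starve the rest of the system just like on cpu.
    """
    return device.type in {"cpu", "mps"}


if uses_shared_memory(torch.device("mps")):
    # Wherever the CPU warning is emitted today, the same warning could fire here.
    print("warning: memory maximization on a shared-memory device may slow down the whole system")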
Reminder: I have an Apple system with MPS, so I am happy to test.