Closed by mberr 1 year ago
Okay, next error, which is first party:
```
Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=x.shape[0])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 333, in wrapper_maximize_memory_utilization
    raise MemoryError(
MemoryError: Execution did not even succeed with ('batch_size',) all equal to 1.
```
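For context, the `MemoryError` above is raised only after the decorator has exhausted its fallback options. As a hedged, pure-Python sketch (hypothetical names, not the actual `torch_max_mem` implementation), the retry loop behind such a message could look like this:

```python
def run_with_shrinking_batch(func, batch_size: int):
    """Retry ``func`` with progressively halved batch sizes.

    Hypothetical sketch of a memory-fallback loop: halve the batch size
    on every memory-related failure, and give up with a MemoryError once
    even batch_size=1 fails.
    """
    while True:
        try:
            return func(batch_size=batch_size), batch_size
        except RuntimeError as error:
            if batch_size <= 1:
                # Mirrors the message in the traceback above.
                raise MemoryError(
                    "Execution did not even succeed with ('batch_size',) all equal to 1."
                ) from error
            batch_size //= 2


# Toy function that "runs out of memory" for batches larger than 8.
def flaky(batch_size: int):
    if batch_size > 8:
        raise RuntimeError("out of memory")
    return f"ok with batch_size={batch_size}"


print(run_with_shrinking_batch(flaky, 100))  # ('ok with batch_size=6', 6)
```

Note that with such a loop, a `RuntimeError` that has nothing to do with memory (like the one below) still drives the batch size all the way down to 1 and ends in the same `MemoryError`.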
Just to be sure: could you check that when you manually pass `batch_size=1` you get the same error?
~I did manually pass `batch_size=1`. I didn't use your automated inference code.~
When manually setting `batch_size=1`, I got
```
Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=1)
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 322, in wrapper_maximize_memory_utilization
    raise error
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 311, in wrapper_maximize_memory_utilization
    func(*bound_arguments.args, **p_kwargs, **bound_arguments.kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 8, in knn
    [
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 9, in <listcomp>
    torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: selected index k out of range
```
Ah okay, that seems to be a problem with the implementation of batched knn, which does not allow `k` to be larger than the batch size; let me check that function.
Also, I think it makes sense to `raise ... from` the last `RuntimeError`, so that it appears in the traceback.
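To see why `batch_size=1` with the default `k=3` triggers `selected index k out of range`: `torch.cdist(x[start : start + 1], y)` has shape `(1, len(y))`, so `topk(..., dim=0)` selects along the batch axis of size 1, where `k=3` is out of range. A minimal reproduction (assuming only that `torch` is installed):

```python
import torch

x = torch.rand(5, 4)
y = torch.rand(7, 4)

# One batch row of pairwise distances: shape (1, 7).
d = torch.cdist(x[0:1], y)
assert d.shape == (1, 7)

# dim=0 selects along the batch axis (size 1), so k=3 is out of range.
try:
    d.topk(k=3, dim=0)
except RuntimeError as e:
    print("topk along dim=0 failed:", e)

# dim=1 selects along the y axis (size 7), which is what knn intends.
indices = d.topk(k=3, dim=1, largest=False).indices
assert indices.shape == (1, 3)
```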
There was a typo in the `knn` function; instead of
```python
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )
```
it should be
```python
def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )
```
cf. b5d6e76
Final report:
```python
import torch
from torch_max_mem import maximize_memory_utilization
from tqdm import trange


@maximize_memory_utilization()
def knn(x, y, batch_size: int, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in trange(0, x.shape[0], batch_size, unit_scale=True, unit="batch")
        ],
        dim=0,
    )


x = torch.rand(100000, 100, device="mps")
y = torch.rand(200000, 100, device="mps")
z = knn(x, y, batch_size=1)
print(z)
```
Succeeded
```
Encountered tensors on device_types={'mps'} while only ['cuda'] are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100k/100k [07:20<00:00, 227batch/s]
tensor([[  6336, 197116,   2431],
        [170591, 153052, 136859],
        [114449,  36699, 165175],
        ...,
        [185708,  47737, 140548],
        [166128,  41156, 171905],
        [121294, 166199,  96834]], device='mps:0')
```
It took around 5 minutes after the tqdm loop was done to print the final results. Switching batch size to 1,000 made it run in 18 seconds, but slowed my computer down a bit ;)
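One plausible explanation for the long pause after the progress bar finished (not confirmed in this thread) is that MPS, like CUDA, executes kernels asynchronously: the loop returns as soon as the work is queued, and `print(z)` then blocks on the device-to-host copy until all queued kernels have completed. A hedged sketch of how to observe this, assuming `torch.mps.synchronize` is available (PyTorch 2.x); it falls back to CPU on machines without MPS:

```python
import time

import torch

device = "mps" if torch.backends.mps.is_available() else "cpu"
x = torch.rand(1000, 100, device=device)

start = time.perf_counter()
d = torch.cdist(x, x)  # on MPS this may return before the kernel finishes
launch = time.perf_counter() - start

# Force pending device work to complete before reading the clock again.
if device == "mps":
    torch.mps.synchronize()
total = time.perf_counter() - start
print(f"after launch: {launch:.4f}s, after sync: {total:.4f}s")
```

If this explanation is right, the "5 minutes after the loop" were not spent printing but finishing the queued `cdist`/`topk` kernels.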
This PR treats `mps` like `cpu`, i.e., when warnings are enabled, it will warn about running memory utilization maximization on `mps` tensors. It also:

- raises the final `MemoryError` from the last encountered `RuntimeError`, so that it shows up in the traceback

As of now, the actual optimization does not seem to work for `mps`.
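For illustration, the warning quoted earlier suggests a device-type gate along these lines. This is a hypothetical re-creation, not the actual `torch_max_mem` code; the function name `check_device_types` and its signature are invented:

```python
import warnings

# Only CUDA reports free/total memory reliably enough for the optimization.
SAFE_DEVICES = frozenset({"cuda"})


def check_device_types(device_types: set, safe_devices: frozenset = SAFE_DEVICES) -> None:
    """Warn when tensors live on device types not known to be safe
    (hypothetical sketch reproducing the quoted warning text)."""
    if device_types - safe_devices:
        warnings.warn(
            f"Encountered tensors on device_types={device_types!r} while only "
            f"safe_devices={safe_devices!r} are considered safe for automatic "
            "memory utilization maximization. This may lead to undocumented "
            "crashes (but can be safe, too)."
        )


check_device_types({"mps"})   # emits the warning
check_device_types({"cuda"})  # silent
```

Under such a gate, treating `mps` like `cpu` simply means `mps` stays out of `safe_devices`, so users get a warning rather than a hard error.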