mberr / torch-max-mem

Decorators for maximizing memory utilization with PyTorch & CUDA
https://torch-max-mem.readthedocs.io/en/latest/
MIT License

Refactoring towards MPS support #15

Closed: mberr closed this pull request 1 year ago

mberr commented 1 year ago

This PR treats mps like cpu, i.e., when warnings are enabled, it will warn about running memory utilization maximization on mps tensors.

It also:

As of now, the actual optimization does not seem to work for mps.
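
For illustration, a minimal sketch of the kind of device check described above. The helper name and structure here are assumptions for this sketch, not the library's actual API; the warning text mirrors the one in the logs below.

import warnings

import torch

# Assumed constant for this sketch; the logs below show safe_devices=frozenset({'cuda'}).
SAFE_DEVICES = frozenset({"cuda"})


def check_tensor_devices(*tensors: torch.Tensor) -> None:
    """Warn when any tensor lives on a device type outside the safe set."""
    device_types = {tensor.device.type for tensor in tensors}
    if not device_types.issubset(SAFE_DEVICES):
        warnings.warn(
            f"Encountered tensors on device_types={device_types} while only "
            f"safe_devices={SAFE_DEVICES} are considered safe for automatic "
            f"memory utilization maximization. This may lead to undocumented "
            f"crashes (but can be safe, too)."
        )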

cthoyt commented 1 year ago

Okay, next error, which is first party:

Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=x.shape[0])
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 333, in wrapper_maximize_memory_utilization
    raise MemoryError(
MemoryError: Execution did not even succeed with ('batch_size',) all equal to 1.

mberr commented 1 year ago

> Okay, next error, which is first party:

Just to be sure: could you check that when you manually pass batch_size=1 you get the same error?

cthoyt commented 1 year ago

~I did manually pass batch_size=1. I didn't use your automated inference code~ When manually setting batch_size=1, I got:

Encountered tensors on device_types={'mps'} while only safe_devices=frozenset({'cuda'}) are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
Traceback (most recent call last):
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 18, in <module>
    z = knn(x, y, batch_size=1)
        ^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 421, in inner
    result, self.parameter_value[h] = wrapped(*args, **kwargs)
                                      ^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 322, in wrapper_maximize_memory_utilization
    raise error
  File "/Users/cthoyt/dev/torch-max-mem/src/torch_max_mem/api.py", line 311, in wrapper_maximize_memory_utilization
    func(*bound_arguments.args, **p_kwargs, **bound_arguments.kwargs),
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 8, in knn
    [
  File "/Users/cthoyt/dev/torch-max-mem/test.py", line 9, in <listcomp>
    torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
    ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
RuntimeError: selected index k out of range

mberr commented 1 year ago

Ah okay, that seems to be a problem with the implementation of the batched knn, which does not allow k to be larger than the batch size; let me check that function.
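
To illustrate the constraint with a standalone snippet (not part of the PR): with batch_size=1 and the default k=3, the slice passed to cdist has a single row, so topk along dim=0 asks for more entries than that dimension holds:

import torch

d = torch.cdist(torch.rand(1, 100), torch.rand(200, 100))  # shape: (1, 200)
d.topk(k=3, dim=1)  # fine: dim 1 has 200 entries
d.topk(k=3, dim=0)  # RuntimeError: selected index k out of range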

Also, I think it makes sense to raise from the last RuntimeError so that it shows up in the traceback.
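
A minimal sketch of that suggestion, with assumed names rather than the actual api.py code:

def run_with_shrinking_batch_size(func, *args, batch_size: int, **kwargs):
    # Hypothetical driver loop: halve the batch size on failure, and chain the
    # last error via `raise ... from` so it appears in the traceback.
    error = None
    while batch_size >= 1:
        try:
            return func(*args, batch_size=batch_size, **kwargs)
        except RuntimeError as e:
            error = e
            batch_size //= 2
    raise MemoryError(
        "Execution did not even succeed with ('batch_size',) all equal to 1."
    ) from error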

mberr commented 1 year ago

There was a typo in the knn function; instead of

def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
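            # BUG: topk along dim=0 operates over the rows of the batch slice, so it fails when batch_size < k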
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=0, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )

it should be

def knn(x, y, batch_size, k: int = 3):
    return torch.cat(
        [
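            # fixed: topk along dim=1 selects, for each row of x, the k nearest rows of y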
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in range(0, x.shape[0], batch_size)
        ],
        dim=0,
    )

cf. b5d6e76

cthoyt commented 1 year ago

Final report:

import torch
from tqdm import trange

from torch_max_mem import maximize_memory_utilization


@maximize_memory_utilization()
def knn(x, y, batch_size: int, k: int = 3):
    return torch.cat(
        [
            torch.cdist(x[start : start + batch_size], y).topk(k=k, dim=1, largest=False).indices
            for start in trange(0, x.shape[0], batch_size, unit_scale=True, unit="batch")
        ],
        dim=0,
    )

x = torch.rand(100000, 100, device="mps")
y = torch.rand(200000, 100, device="mps")
z = knn(x, y, batch_size=1)
print(z)

Succeeded

Encountered tensors on device_types={'mps'} while only ['cuda'] are considered safe for automatic memory utilization maximization. This may lead to undocumented crashes (but can be safe, too).
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████| 100k/100k [07:20<00:00, 227batch/s]
tensor([[  6336, 197116,   2431],
        [170591, 153052, 136859],
        [114449,  36699, 165175],
        ...,
        [185708,  47737, 140548],
        [166128,  41156, 171905],
        [121294, 166199,  96834]], device='mps:0')

It took around 5 minutes after the tqdm loop finished for the final results to print. Switching the batch size to 1,000 made it run in 18 seconds, but slowed my computer down a bit ;)
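
For reference, that change is just the batch_size passed at the call site:

z = knn(x, y, batch_size=1_000)  # ~18 s instead of several minutes at batch_size=1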