mberr / torch-max-mem

Decorators for maximizing memory utilization with PyTorch & CUDA
https://torch-max-mem.readthedocs.io/en/latest/
MIT License

Re-enable tests for mac #19

Open mberr opened 2 months ago

cthoyt commented 2 months ago

The segfault is emitted by the cdist function and can be reproduced on my machine with:

import torch

x = torch.rand(100000, 100, device="mps")
y = torch.rand(200000, 100, device="mps")

# The following raises
# RuntimeError: Invalid buffer size: 74.51 GB
torch.cdist(x, y)

# The following is the smallest that causes invalid buffer size
x_batch = x[:27180]
torch.cdist(x_batch, y)

# The following is the largest that causes a segfault
x_batch = x[:27179] # size = 10,871,600 bytes (bigger than 2**32)
torch.cdist(x_batch, y)

# The following is the smallest that causes a segfault
x_batch = x[:21475] # size = 8,590,000 bytes (less than 2**32)
torch.cdist(x_batch, y)

# Small enough, all good
x_batch = x[:21474]
torch.cdist(x_batch, y)

Here's the segfault:

/AppleInternal/Library/BuildRoots/91a344b1-f985-11ee-b563-fe8bc7981bff/Library/Caches/com.apple.xbs/Sources/MetalPerformanceShaders/MPSCore/Types/MPSNDArray.mm:788: failed assertion `[MPSNDArray initWithDevice:descriptor:] Error: total bytes of NDArray > 2**32'

I was thinking I could check the element sizes with footprint = x_batch.nelement() * x_batch.element_size(), but this doesn't seem to match up with the error, so I guess I'm counting wrong.
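
For reference, that footprint on the smallest segfaulting batch comes out far below 2**32, which is why it doesn't line up (quick sketch; CPU is fine here since only the sizes matter):

import torch

x = torch.rand(100_000, 100)  # float32, same shape as above
x_batch = x[:21_475]

# 21_475 rows * 100 columns * 4 bytes per float32 element
footprint = x_batch.nelement() * x_batch.element_size()
print(footprint)          # 8590000
print(footprint > 2**32)  # False -- the input size alone cannot explain the assertion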

mberr commented 2 months ago

> I was thinking I could check the element sizes with footprint = x_batch.nelement() * x_batch.element_size(), but this doesn't seem to match up with the error, so I guess I'm counting wrong.

torch.cdist computes the pairwise distances, so its result has x.shape[0] * y.shape[0] entries, and indeed we have

21_474 * 200_000 = 4_294_800_000
2**32            = 4_294_967_296
21_475 * 200_000 = 4_295_000_000

Thanks for debugging!
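
So the failing cases are exactly those where the result has more than 2**32 entries. A pre-flight check along these lines (a sketch with an illustrative helper name, not part of this library's API) would at least turn the segfault into a catchable exception:

import torch

# Limit observed in the MPS assertion above; treating it as a count of result
# entries is an assumption based on the numbers in this thread.
MPS_MAX_ELEMENTS = 2**32

def checked_cdist(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    # cdist returns an (x.shape[0], y.shape[0]) matrix
    n_result = x.shape[0] * y.shape[0]
    if x.device.type == "mps" and n_result > MPS_MAX_ELEMENTS:
        raise ValueError(
            f"cdist result would have {n_result:,} entries, "
            f"exceeding the observed MPS limit of 2**32"
        )
    return torch.cdist(x, y)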

I'll just update the example to something small enough, although it is a bit unfortunate that this triggers a segfault rather than raising a catchable exception. I also need to add this to the documentation somewhere, so a user encountering the issue has a starting point.
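
Until then, chunking the first argument so that each partial result stays below the limit works around it. A minimal sketch (chunked_cdist is a hypothetical helper, not this library's decorator-based API):

import torch

def chunked_cdist(x: torch.Tensor, y: torch.Tensor, max_elements: int = 2**32 - 1) -> torch.Tensor:
    # largest number of rows of x whose (rows, y.shape[0]) partial result fits the limit
    chunk_size = max(1, max_elements // y.shape[0])
    # move each partial result to CPU before concatenating: for the sizes in this
    # thread, the full 100_000 x 200_000 matrix would itself exceed the MPS limit
    parts = [
        torch.cdist(x[i : i + chunk_size], y).cpu()
        for i in range(0, x.shape[0], chunk_size)
    ]
    return torch.cat(parts, dim=0)

For the tensors above this gives chunk_size = 21_474, matching the largest batch that worked in the experiment.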

mberr commented 2 months ago

Hm, looks like the segfault still occurs.

This discussion might be related: https://discuss.pytorch.org/t/segmentation-fault-with-pytorch-2-3/203381