openai / transformer-debugger


Compatibility with MPS backend #16

Open · shehper opened this issue 5 months ago

shehper commented 5 months ago

While running inference on my Mac (macOS 13.1), I received the following error:

RuntimeError: MPS does not support cumsum_out_mps op with int64 input. Support has been added in macOS 13.3

I received this error because of the cumsum call in prep_pos_from_pad_and_prev_lens; cumsum is also used in other places in the repository.
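
For reference, the limitation is easy to reproduce in isolation (an illustrative snippet, not code from the repository):

```python
import torch

# On macOS < 13.3, the cumsum call below raises:
# RuntimeError: MPS does not support cumsum_out_mps op with int64 input
if torch.backends.mps.is_available():
    x = torch.arange(5, device="mps")  # torch.arange defaults to int64
    print(x.cumsum(dim=0))
```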

The same error arises when running various tests on the MPS backend, as mentioned near the function get_default_device. I have confirmed that this failure is also caused by macOS versions earlier than 13.3 being unable to compute cumsum on MPS.

[Screenshot of the failing MPS tests, 2024-03-19]

Should we modify get_default_device to return torch.device("mps", 0) only when the macOS version is >= 13.3? We could then remove the current workaround that skips running pytest on this backend.
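
A minimal sketch of what that check could look like; the helper name _mps_cumsum_supported and the CUDA/MPS/CPU fallback order are illustrative assumptions, and the real get_default_device in the repository may be structured differently:

```python
import platform

import torch


def _mps_cumsum_supported() -> bool:
    # Illustrative helper (not in the repo): cumsum on int64 inputs
    # requires macOS >= 13.3 on the MPS backend.
    release = platform.mac_ver()[0]  # e.g. "13.6.1"; "" on non-macOS
    if not release:
        return False
    parts = [int(p) for p in release.split(".")[:2]]
    while len(parts) < 2:
        parts.append(0)
    return tuple(parts) >= (13, 3)


def get_default_device() -> torch.device:
    # Sketch only: prefer CUDA, then MPS (if the OS is new enough), then CPU.
    if torch.cuda.is_available():
        return torch.device("cuda", 0)
    if torch.backends.mps.is_available() and _mps_cumsum_supported():
        return torch.device("mps", 0)
    return torch.device("cpu")
```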

If this seems like a useful change, I will be happy to submit a pull request.

Thank you!

shehper commented 5 months ago

I dug in a bit more and am sharing my findings here.

After updating macOS to 13.6, the error shared above disappeared. However, test_interactive_model.py still fails on the MPS backend; the failure comes from the torch.einsum call here. The error is copied below.

failed assertion `[MPSNDArrayDescriptor sliceDimension:withSubrange:] error: subRange.start (2) is not less than length of dimension[1] (2)'

Other people have faced similar errors when working with tensors on the MPS backend. See PyTorch issues #96153 and #113586, both still open.

For what it's worth, replacing torch.einsum("td,td->t", residual, direction)[:, None] with

torch.sum(torch.mul(residual, direction), dim=1, keepdim=True)

solves the issue. In summary, the solution involves:

  1. Checking whether the macOS version is earlier than 13.3, and falling back to the cpu device if so.
  2. Replacing the torch.einsum call with the expression in the block above.

With these changes, we can remove the workaround in get_default_device.
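
As a quick sanity check that the replacement computes the same per-token dot product, here is a standalone comparison (tensor names and shapes are illustrative, not taken from the repository):

```python
import torch

# residual and direction stand in for the tensors passed to torch.einsum
# in the repository; the shapes here are arbitrary.
residual = torch.randn(4, 8)
direction = torch.randn(4, 8)

einsum_out = torch.einsum("td,td->t", residual, direction)[:, None]
sum_mul_out = torch.sum(torch.mul(residual, direction), dim=1, keepdim=True)

# Both produce a per-token dot product of shape (t, 1).
assert torch.allclose(einsum_out, sum_mul_out)
```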

WuTheFWasThat commented 4 months ago

thanks for looking into this! PR is welcome