Open nisheethlahoti opened 2 months ago
Thanks for reporting this @nisheethlahoti! I can reproduce this locally as well.
My first guess on the cause of the error was that negative axes were being converted incorrectly somewhere using ndim - 1 + dim instead of ndim + dim, but this error happens only when dim <= -3, and it seems unlikely we're handling dim=-1 or -2 separately from everything else.
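To make that reasoning concrete, here is a minimal pure-Python sketch of the two conversions. The helper names `wrap_dim` and `wrap_dim_off_by_one` are hypothetical and only for illustration; they are not the actual PyTorch or MPS code path. The point is that an `ndim - 1 + dim` mistake would give a wrong axis for every negative dim, including -1 and -2, which does not match a failure that only shows up for dim <= -3.

```python
def wrap_dim(dim: int, ndim: int) -> int:
    """Hypothetical helper: map a possibly negative axis to [0, ndim)."""
    if not -ndim <= dim < ndim:
        raise IndexError(f"dim {dim} out of range for a {ndim}-d tensor")
    return dim + ndim if dim < 0 else dim


def wrap_dim_off_by_one(dim: int, ndim: int) -> int:
    """Hypothetical buggy variant using ndim - 1 + dim for negative axes."""
    return ndim - 1 + dim if dim < 0 else dim


ndim = 5  # assumed rank, chosen so that dim=-5 is a valid axis
for dim in range(-1, -ndim - 1, -1):
    correct = wrap_dim(dim, ndim)
    buggy = wrap_dim_off_by_one(dim, ndim)
    # The buggy result is off by one for every negative dim, not just dim <= -3.
    print(f"dim={dim}: correct={correct}, off_by_one={buggy}")
```

Since the off-by-one variant already disagrees at dim=-1, that particular mistake alone would not explain why dim=-1 and dim=-2 behave correctly on MPS.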
🐛 Describe the bug
If you run the code below:
The expected output is
While the actual output is
On CPU we get the expected output. Also, relatedly, trying to do x.sum(-5) works just fine on CPU, but throws an IndexError on MPS.
Versions
FYI, I'm using micromamba, which I think the collect_env script should ideally treat the same way as conda, but doesn't.
cc @kulinseth @albanD @malfet @DenisVieriu97 @jhavukainen