EleutherAI / concept-erasure

Erasing concepts from neural representations with provable guarantees
MIT License

torch.where bug #6

Closed cguerner closed 1 year ago

cguerner commented 1 year ago

Hi,

Reporting a bug with the latest package version (concept-erasure 0.2.1), Python 3.10.4, and torch 1.13.0+cu116.

[Screenshot, 2023-08-29, 15:52:50]

The second argument to torch.Tensor.where() has to be a tensor, not a float. I fixed it with the following:

    Lzeros = torch.zeros(L.shape)
    W = V * L.rsqrt().where(mask, Lzeros) @ V.mT
    W_inv = V * L.sqrt().where(mask, Lzeros) @ V.mT

norabelrose commented 1 year ago

Hmm this is pretty bizarre because

1) The PyTorch docs for Tensor.where do say that a scalar should work, and this is true both for PyTorch 2.0 and the 1.13 version you're using.

[Screenshot, 2023-09-04, 12:28:40 AM]

2) It's working for me when I try it on PyTorch 2.0.

[Screenshot, 2023-09-04, 12:29:53 AM]

If you could make a clean repro that would be helpful. I'd also recommend you try PyTorch 2.0, since it's possible that the torch.where behavior did not match the docs in 1.13 but this was fixed in 2.0. If that's the case we might want to bump the minimum required PyTorch version to 2.0.
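A clean repro along these lines might look like the sketch below. It is only an illustration: the tensor values and the helper name `method_accepts_scalar` are made up for this example and are not taken from the package. It tries the Tensor method with a Python float as the second argument and, separately, the `torch.where` function with the same arguments, so the two can be compared on a given PyTorch version.

```python
import torch

# Hypothetical stand-ins for the eigenvalues and mask; not the package's real data.
L = torch.tensor([0.0, 1.0, 4.0])
mask = L > 0


def method_accepts_scalar() -> bool:
    """Return True if Tensor.where accepts a Python float as `other`."""
    try:
        L.rsqrt().where(mask, 0.0)
        return True
    except (TypeError, RuntimeError):
        # The reported failure on 1.13 is expected to land here.
        return False


# The function form, which is the variant being compared against the method.
function_result = torch.where(mask, L.rsqrt(), 0.0)

print("torch version:", torch.__version__)
print("Tensor.where accepts a scalar:", method_accepts_scalar())
print("torch.where result:", function_result)
```

Running this under both 1.13 and 2.0 and pasting the printed output would show whether the method and the function disagree in a given environment.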

cemoody commented 1 year ago

Came here to say the same thing -- observed this issue on a Mac M2 with PyTorch 1.13.

norabelrose commented 1 year ago

I can reproduce with PyTorch 1.13.

[Screenshot, 2023-09-04, 9:16:22 AM]

The docs going back to PyTorch 1.7.0 say that scalars are allowed in the function torch.where, and Tensor.where is supposed to behave the same way, but it looks like the two did not match in practice. The function does work on PyTorch 1.13:

[Screenshot, 2023-09-04, 9:23:34 AM]

So I think the solution is to use the function instead of the Tensor method.
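As a rough sketch of what that change could look like, here are the two lines from the original report rewritten to call `torch.where` as a function. The surrounding function `whitening_pair`, the way `L`, `V`, and `mask` are computed, and the eigenvalue cutoff are assumptions made so the example runs on its own; they are not necessarily what the package or PR #7 does.

```python
import torch


def whitening_pair(sigma: torch.Tensor):
    """Hypothetical helper: return (W, W_inv) for a symmetric PSD matrix `sigma`.

    W approximates sigma^(-1/2) and W_inv approximates sigma^(1/2), with
    near-zero eigenvalues zeroed out instead of inverted.
    """
    L, V = torch.linalg.eigh(sigma)  # eigenvalues in ascending order

    # Assumed cutoff for "numerically zero" eigenvalues (illustrative only).
    mask = L > L[-1] * sigma.shape[-1] * torch.finfo(L.dtype).eps
    L = L.clamp_min(0.0)  # guard against tiny negative eigenvalues

    # Function form of where: the scalar 0.0 is passed as the third argument.
    W = V * torch.where(mask, L.rsqrt(), 0.0) @ V.mT
    W_inv = V * torch.where(mask, L.sqrt(), 0.0) @ V.mT
    return W, W_inv


# Rank-deficient example: one eigenvalue is exactly zero.
sigma = torch.tensor([[2.0, 0.0], [0.0, 0.0]], dtype=torch.float64)
W, W_inv = whitening_pair(sigma)
print(W @ W_inv)  # projector onto the non-degenerate subspace
```

Compared with the `Lzeros` workaround, this avoids allocating a separate zeros tensor, so there is no device or dtype of that extra tensor to keep in sync with `L`.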

norabelrose commented 1 year ago

Can @cguerner and/or @cemoody confirm that PR #7 fixes the problem in your environments?