Closed: cguerner closed this issue 1 year ago
Hmm, this is pretty bizarre because
1) The PyTorch docs for Tensor.where do say that a scalar should work, and this is true both for PyTorch 2.0 and the 1.13 version you're using
<img width="647" alt="Captura de pantalla 2023-09-04 a la(s) 12 28 40 a m"
src="https://github.com/EleutherAI/concept-erasure/assets/39116809/cdecd074-1fd6-492f-b71e-b6c0c24a9959">
2) It's working for me when I try it on PyTorch 2.0.
If you could make a clean repro that would be helpful. I'd also recommend you try PyTorch 2.0, since it's possible that the torch.where behavior did not match the docs in 1.13 but this was fixed in 2.0. If that's the case we might want to bump the minimum required PyTorch version to 2.0.
Came here to say the same thing: I observed this issue on a Mac M2 with PyTorch 1.13.
I can reproduce with PyTorch 1.13.
The docs going back to PyTorch 1.7.0 say that scalars are supposed to be allowed in the function torch.where, and Tensor.where is supposed to behave the same way, but it looks like there was actually a discrepancy. The function does work on PyTorch 1.13.
So I think the solution is to use the function instead of the Tensor method.
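For anyone landing here later, here is a minimal sketch of the two call forms discussed above. It assumes a reasonably recent PyTorch. The tensor values and the `mask` definition are made up for illustration and are not taken from the concept-erasure source.

```python
import torch

# Illustrative values only (not from the library): one zero entry to mask out.
L = torch.tensor([4.0, 0.0, 9.0])
mask = L > 0  # hypothetical mask; the real code may use a different threshold

# Function form: a Python scalar as the fill value.
# Reported in this thread to work on both PyTorch 1.13 and 2.0.
filled_fn = torch.where(mask, L.rsqrt(), 0.0)

# Method form with a tensor fill value, which avoids passing a scalar
# to Tensor.where (the call that failed on 1.13 in this thread).
filled_method = L.rsqrt().where(mask, torch.zeros_like(L))

# Both give 0.5, 0.0, 1/3: the inf from rsqrt(0) is replaced by the fill value.
same = torch.equal(filled_fn, filled_method)
```

Either form should sidestep the scalar problem. The function form is shorter, and the tensor-fill form does not depend on scalar support at all.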
Can @cguerner and/or @cemoody confirm that PR #7 fixes the problem in your environments?
Hi,
Reporting a bug with the latest package version concept-erasure 0.2.1, python 10.4 and torch 1.13.0+cu116.
The second argument to torch.Tensor.where() has to be a tensor, not a float. Fixed it with the following:

    Lzeros = torch.zeros(L.shape)
    W = V * L.rsqrt().where(mask, Lzeros) @ V.mT
    W_inv = V * L.sqrt().where(mask, Lzeros) @ V.mT
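As a rough, self-contained sketch of the workaround above in context: the covariance matrix `sigma`, the eigendecomposition step, and the `mask = L > 0` threshold are all assumptions made for illustration, not the library's actual code. It also swaps `torch.zeros(L.shape)` for `torch.zeros_like(L)`, which is probably safer on a CUDA build like the one reported, since the fill tensor then inherits the dtype and device of `L`.

```python
import torch

torch.manual_seed(0)

# Hypothetical full-rank covariance matrix, for illustration only.
A = torch.randn(4, 4, dtype=torch.float64)
sigma = A @ A.T + torch.eye(4, dtype=torch.float64)

# Eigendecomposition: sigma = V @ diag(L) @ V.mT
L, V = torch.linalg.eigh(sigma)
mask = L > 0  # assumed mask; the library's actual threshold may differ

# zeros_like inherits L's dtype and device, unlike torch.zeros(L.shape).
Lzeros = torch.zeros_like(L)

# (V * d) @ V.mT equals V @ diag(d) @ V.mT because * and @ bind left to right.
W = V * L.rsqrt().where(mask, Lzeros) @ V.mT     # inverse square root of sigma
W_inv = V * L.sqrt().where(mask, Lzeros) @ V.mT  # square root of sigma

# Sanity checks for the full-rank case.
eye = torch.eye(4, dtype=torch.float64)
whitened_ok = torch.allclose(W @ sigma @ W, eye, atol=1e-8)
inverse_ok = torch.allclose(W @ W_inv, eye, atol=1e-8)
```

In the full-rank case both checks should pass. With zero eigenvalues masked out, `W @ W_inv` would instead be a projection onto the retained subspace.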