pytorch / maskedtensor

MaskedTensors for PyTorch
https://pytorch.org/maskedtensor

Inconsistent sum behavior between boolean regular tensor and maskedtensor #54

Closed: lazycal closed this issue 2 years ago

lazycal commented 2 years ago

🐛 Describe the bug

Reporting a possible semantic inconsistency that might be worth looking at:

```python
import torch
from maskedtensor import masked_tensor

t = torch.tensor([True, True, True])
print(t.sum().item())                    # regular tensor: prints 3
print(masked_tensor(t, t).sum().item())  # masked tensor: prints True
```

In the above code, the regular tensor sums to 3 but the masked tensor sums to True; I expected them to be the same. (Casting to float solves my issue and I have no problem with that. Just thought it might be helpful.)
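For reference, here is a minimal sketch of the regular-tensor behavior the report compares against, using only plain PyTorch (no MaskedTensor). Summing a bool tensor promotes the result to the default integer dtype, and casting the data to float first gives a float sum:

```python
import torch

t = torch.tensor([True, True, True])

# Regular tensors: torch.sum promotes bool (and other integral) inputs to int64.
s = t.sum()
print(s.dtype, s.item())  # torch.int64 3

# Casting the data to float before reducing yields a float result instead.
f = t.to(torch.float).sum()
print(f.dtype, f.item())  # torch.float32 3.0
```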

lazycal commented 2 years ago
```python
import torch
from maskedtensor import masked_tensor, as_masked_tensor

t = torch.tensor([True, True, True])
mask = torch.tensor([True, True, True])
print(masked_tensor(t, mask).to(torch.float))
```

Actually, I just hit something even more confusing. It looks like `.to(torch.float)` does not quite work for MaskedTensor? The above code prints
```
masked_tensor(
  [True, True, True]
)
```

but I would expect something like `[1., 1., 1.]`.
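As a sketch of the expected semantics, one can model a masked tensor as a (data, mask) pair of regular tensors: converting the dtype should change only the data, while the mask stays boolean. The helper `to_dtype_pair` below is hypothetical and is not part of the maskedtensor API:

```python
import torch

def to_dtype_pair(data: torch.Tensor, mask: torch.Tensor, dtype: torch.dtype):
    """Hypothetical helper modelling what one would expect MaskedTensor.to(dtype) to do."""
    # Only the data changes dtype; the mask must remain a bool tensor.
    return data.to(dtype), mask

t = torch.tensor([True, True, True])
mask = torch.tensor([True, True, True])
data, new_mask = to_dtype_pair(t, mask, torch.float)
print(data)            # tensor([1., 1., 1.])
print(new_mask.dtype)  # torch.bool
```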
george-qi commented 2 years ago

Hi @lazycal, thanks for calling out this issue again and continuing to use MaskedTensor!

The `.to` example should be fixed by PR #55:

```python
>>> import torch
>>> from maskedtensor import masked_tensor, as_masked_tensor
>>>
>>> t = torch.tensor([True, True, True])
>>> mask = torch.tensor([True, True, True])
>>> print(masked_tensor(t, mask).to(torch.float))
masked_tensor(
  [  1.0000,   1.0000,   1.0000]
)
```

However, the first issue, with `.sum()`, happens because type promotion is not yet implemented for masked reductions. We correctly compute `torch.sum(torch.tensor([True, True, True]))` -> `torch.tensor(3)`, but at the end of the operation the masked reduction casts the result back to the input's original dtype, `bool`, which is why the result is `True`.
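The cast-back behavior described above can be illustrated with a hypothetical sketch using only regular PyTorch tensors. The function names are invented for illustration and are not MaskedTensor's actual implementation; boolean indexing stands in for however the library really applies the mask:

```python
import torch

def masked_sum_cast_back(data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical: compute the masked sum correctly, then cast back to data.dtype."""
    # Boolean indexing keeps only the unmasked (mask == True) elements.
    result = data[mask].sum()      # for bool data this is an int64 tensor
    return result.to(data.dtype)   # casting to bool turns any nonzero count into True

def masked_sum_promoted(data: torch.Tensor, mask: torch.Tensor) -> torch.Tensor:
    """Hypothetical: keep the promoted dtype, matching torch.sum on regular tensors."""
    return data[mask].sum()

t = torch.tensor([True, True, True])
mask = torch.tensor([True, True, True])
print(masked_sum_cast_back(t, mask))  # tensor(True)
print(masked_sum_promoted(t, mask))   # tensor(3)

# With one element masked out, the promoted sum counts only the two unmasked Trues.
partial = torch.tensor([True, False, True])
print(masked_sum_promoted(t, partial))  # tensor(2)
```

Keeping the promoted dtype, as in the second function, is what would make the masked result agree with the regular-tensor result from the original report.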

I've opened issue #57 to track this. I hope this answers your question; please keep the feedback coming. Really appreciate it!