mieszkokl opened 1 year ago
You are right, 1e-12 is too small for torch.HalfTensor:
>>> torch.HalfTensor([1e-12])
tensor([0.], dtype=torch.float16)
However, the only way to run into this problem is if the tensors out0 and out1 have a norm smaller than 1e-12, which is a very unlikely scenario and could hint at a possible bug in your code. That being said, I think your fix is reasonable and we'd welcome your pull request.
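For reference, the failure is easy to reproduce with a zero vector (a minimal sketch; the all-zeros input just stands in for any tensor whose norm underflows below 1e-12):

>>> import torch
>>> import torch.nn.functional as F
>>> F.normalize(torch.zeros(1, 3, dtype=torch.float16), dim=1)
tensor([[nan, nan, nan]], dtype=torch.float16)

Here the norm is 0, the eps clamp of 1e-12 also rounds to 0 in float16, and 0/0 yields NaN.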
When training with half precision, I noticed that the normalization in NTXentLoss can give NaN values. In the forward method there is code that normalizes the out0 and out1 embeddings, roughly as sketched below.
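(A sketch based on the out0/out1 tensors discussed in this thread, not the exact source:)

import torch.nn.functional as F

out0 = F.normalize(out0, dim=1)  # uses the default eps=1e-12
out1 = F.normalize(out1, dim=1)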
It uses the torch.nn.functional.normalize function with the default epsilon of 1e-12, which becomes 0 in half precision. As a result we have a division by zero and NaN in the output.
The way to solve it is to add an optional normalization epsilon parameter to the NTXentLoss initializer and use it when calling torch.nn.functional.normalize. Please let me know if there is any mistake in my understanding. If it's okay with you, I can propose a pull request.
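A minimal sketch of what I mean (the eps name, its default value, and the rest of the signature are just my suggestion, not existing API):

import torch
import torch.nn.functional as F

class NTXentLoss(torch.nn.Module):
    # Sketch: only the pieces relevant to the eps proposal are shown.
    def __init__(self, temperature: float = 0.5, eps: float = 1e-6):
        super().__init__()
        self.temperature = temperature
        # Default chosen to stay representable in float16
        # (subnormals reach down to roughly 6e-8), unlike 1e-12.
        self.eps = eps

    def forward(self, out0: torch.Tensor, out1: torch.Tensor):
        # Use the configurable eps instead of normalize's 1e-12 default.
        out0 = F.normalize(out0, dim=1, eps=self.eps)
        out1 = F.normalize(out1, dim=1, eps=self.eps)
        # ... the rest of the NT-Xent computation stays unchanged ...

Existing float32 users would see no change with the default, and half-precision users could pass, e.g., NTXentLoss(eps=1e-6) to avoid the underflow.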