egillax opened 2 weeks ago
This is causing me issues with numerical stability (NaNs at some point during training) when training a `PoincareEmbedding` (similar to the tutorial) with `torch.set_default_dtype(torch.float64)`.
The issue is here:

```python
self.value = Parameter(
    data=as_tensor(value, dtype=torch.float32),
    requires_grad=requires_grad,
)
```
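For context, here is a minimal standalone sketch of the mismatch (not the library's actual code): the explicit `dtype` overrides the global default, so the curvature ends up in a different precision than the embedding weights.

```python
import torch

torch.set_default_dtype(torch.float64)

# The explicit dtype argument overrides the global default...
curvature = torch.as_tensor(1.0, dtype=torch.float32)

# ...while tensors created without one (e.g. embedding weights) follow it.
weights = torch.zeros(10, 2)

print(curvature.dtype)  # torch.float32
print(weights.dtype)    # torch.float64
```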
Removing the `dtype` argument makes the curvature respect the torch default dtype.
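A sketch of what the fixed constructor could look like; the `Curvature` class name and surrounding module are hypothetical here, and only the `Parameter` line is taken from the snippet above:

```python
import torch
from torch import as_tensor
from torch.nn import Parameter


class Curvature(torch.nn.Module):
    """Hypothetical wrapper illustrating the suggested fix."""

    def __init__(self, value: float = 1.0, requires_grad: bool = True):
        super().__init__()
        # Dropping the hard-coded dtype lets as_tensor fall back to
        # torch.get_default_dtype(), so float64 training stays consistent.
        self.value = Parameter(
            data=as_tensor(value),
            requires_grad=requires_grad,
        )


torch.set_default_dtype(torch.float64)
print(Curvature().value.dtype)  # torch.float64
```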
Hi @egillax, thanks for bringing this to our attention! This looks like an easy fix; I'll give it a stab when I get around to it.
Feel free to open a pull request yourself in the meantime if you're interested :)