maxvanspengler / hyperbolic_learning_library

An extension of the PyTorch library containing various tools for performing deep learning in hyperbolic space.
MIT License
137 stars, 9 forks

Curvature in PoincareBall hardcoded as float32 torch tensor #62

Open: egillax opened this issue 2 weeks ago

egillax commented 2 weeks ago

With torch.set_default_dtype(torch.float64), this is causing me numerical-stability issues (NaNs at some point during training) when training a PoincareEmbedding similar to the one in the tutorial. The curvature stays float32 while everything else is float64.

The dtype is hardcoded here:

        self.value = Parameter(
            data=as_tensor(value, dtype=torch.float32),
            requires_grad=requires_grad,
        )
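To see the effect in isolation, here is a torch-only sketch that mimics the quoted lines under a float64 default. It does not import the library, and the value 1/3 is an arbitrary example curvature chosen to expose the rounding, not a library default:

```python
import torch
from torch.nn import Parameter

torch.set_default_dtype(torch.float64)

value = 1.0 / 3.0  # arbitrary example curvature (assumption, not a library default)

# Mimics the quoted code: dtype is forced to float32 regardless of the default.
hardcoded = Parameter(torch.as_tensor(value, dtype=torch.float32), requires_grad=True)

# Freshly created floating-point tensors follow the float64 default.
x = torch.randn(4, 2)

print(hardcoded.dtype)         # torch.float32
print(x.dtype)                 # torch.float64

# Computations that involve only the curvature also stay in float32.
print(hardcoded.sqrt().dtype)  # torch.float32

# The stored curvature has been rounded to float32 precision.
print(hardcoded.item() == value)  # False
```

No dtype error is raised here; the mismatch only shows up as reduced precision in the curvature itself.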

Removing the dtype argument makes the curvature respect the torch default dtype.
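A minimal sketch of that change, using a hypothetical stand-in module. `CurvatureSketch` is not the library's class, and its optional `dtype` argument is an assumption added for illustration. Leaving `dtype` as `None` lets `torch.as_tensor` fall back to the current default dtype for Python scalars:

```python
import torch
from torch import nn
from torch.nn import Parameter


class CurvatureSketch(nn.Module):
    """Hypothetical stand-in for the library's curvature module."""

    def __init__(self, value=1.0, requires_grad=False, dtype=None):
        super().__init__()
        # With dtype=None, a Python float becomes a tensor of the current
        # torch default dtype; an existing tensor keeps its own dtype.
        self.value = Parameter(
            data=torch.as_tensor(value, dtype=dtype),
            requires_grad=requires_grad,
        )


torch.set_default_dtype(torch.float64)
c64 = CurvatureSketch(1.0)
print(c64.value.dtype)     # torch.float64

# An explicit dtype still overrides the default when wanted.
pinned = CurvatureSketch(1.0, dtype=torch.float32)
print(pinned.value.dtype)  # torch.float32

torch.set_default_dtype(torch.float32)
c32 = CurvatureSketch(1.0)
print(c32.value.dtype)     # torch.float32
```

Keeping an optional `dtype` parameter (defaulting to `None`) rather than deleting it outright would let callers still pin the precision explicitly when they need to.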

philippmwirth commented 1 week ago

Hi @egillax, thanks for bringing this to our attention! This looks like an easy fix; I'll give it a stab if I get around to it.

Feel free to open a pull request yourself in the meantime if you're interested :)