fedelopez77 / sympa

Embedding graphs in symmetric spaces

Extracting Graph Embeddings from the trained model #3

Closed sanchit-ahuja closed 2 years ago

sanchit-ahuja commented 2 years ago

Hi, I wanted to know how I can extract the graph embeddings from the trained model for a custom graph input. Thanks!

fedelopez77 commented 2 years ago

Hello,

You can do the following:

import torch
# torch.load returns the saved checkpoint dictionary; the weights sit under "model"
model = torch.load(path_to_saved_model)
embeddings = model["model"]["module.embeddings.embeds"]

Now embeddings is a tensor of shape (len(nodes), 2, n, n), where embeddings[:, 0] is the real part of each node's n x n symmetric matrix and embeddings[:, 1] is the imaginary part.
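To work with each node as a single complex matrix, the two channels can be combined. The sketch below is only an illustration of the layout described above: the helper name to_complex_matrices is hypothetical (not part of sympa), and a random tensor stands in for the one loaded from the checkpoint.

```python
import torch


def to_complex_matrices(embeddings: torch.Tensor) -> torch.Tensor:
    """Combine the real and imaginary channels into one complex tensor.

    Expects shape (num_nodes, 2, n, n) and returns shape (num_nodes, n, n).
    Hypothetical helper, not part of the sympa code base.
    """
    if embeddings.dim() != 4 or embeddings.shape[1] != 2:
        raise ValueError(f"expected (num_nodes, 2, n, n), got {tuple(embeddings.shape)}")
    real, imag = embeddings[:, 0], embeddings[:, 1]
    return torch.complex(real, imag)


# Stand-in for the loaded tensor (assumption): 5 nodes, n = 3.
num_nodes, n = 5, 3
raw = torch.randn(num_nodes, 2, n, n)
# Symmetrize the last two dimensions so the fake data matches the described layout.
raw = 0.5 * (raw + raw.transpose(-1, -2))

z = to_complex_matrices(raw)
print(z.shape)  # one n x n complex matrix per node
```

With a real checkpoint, raw would simply be the embeddings tensor extracted above.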

If you train embeddings with a model other than the Siegel spaces, such as Euclidean, hyperbolic, etc., then embeddings will be a tensor of shape (len(nodes), n).
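Since the shape depends on the model, downstream code may want to check which of the two layouts it received. A minimal sketch, assuming only the two shapes mentioned in this thread; the function name split_embeddings and the returned dictionary keys are hypothetical, and zero tensors stand in for real checkpoints.

```python
import torch


def split_embeddings(embeddings: torch.Tensor) -> dict:
    """Describe an embedding tensor by its layout.

    (num_nodes, 2, n, n) -> matrix model (real and imaginary parts)
    (num_nodes, n)       -> vector model (Euclidean, hyperbolic, ...)
    Hypothetical helper, not part of the sympa code base.
    """
    if embeddings.dim() == 4 and embeddings.shape[1] == 2:
        return {"kind": "matrix", "real": embeddings[:, 0], "imag": embeddings[:, 1]}
    if embeddings.dim() == 2:
        return {"kind": "vector", "points": embeddings}
    raise ValueError(f"unexpected embedding shape {tuple(embeddings.shape)}")


# Stand-ins for loaded tensors (assumption): 4 nodes, n = 3.
siegel_like = torch.zeros(4, 2, 3, 3)
vector_like = torch.zeros(4, 3)

print(split_embeddings(siegel_like)["kind"])  # matrix
print(split_embeddings(vector_like)["kind"])  # vector
```

Row i of either tensor corresponds to node i only if the node ordering used at training time is preserved, which is worth confirming against the preprocessing step for a custom graph.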

sanchit-ahuja commented 2 years ago

Thanks for this. This works!