DeepGraphLearning / ULTRA

A foundation model for knowledge graph reasoning
MIT License

Integrating a language model with ULTRA #9

Open daniel4x opened 8 months ago

daniel4x commented 8 months ago

Hi @migalkin, first of all, kudos for your work!!!! (both ULTRA and NodePiece 😄).

I'm curious to hear your thoughts about integrating a language model (LM) with ULTRA. Previously, with other KG models such as NodePiece, it was straightforward to integrate a language model to enrich the graph embeddings with textual embeddings. I used to concatenate the entity's textual and graph representations, then optionally apply additional layers to match the desired dimensions.

example:

# code from the pykeen framework + modifications
x_e, x_r = self.entity_representations[0](), self.relation_representations[0]()
indices = torch.arange(self.text_representation.weight.data.shape[0])
x_e = self.merge_model(self.text_representation(indices), x_e)  # concat + linear layer

# Perform message passing and get updated states
for layer in self.gnn_encoder:
    x_e, x_r = layer(
        x_e=x_e,
        x_r=x_r,
        edge_index=getattr(self, f"{mode}_edge_index"),
        edge_type=getattr(self, f"{mode}_edge_type"),
    )
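For reference, the `merge_model` mentioned above could be a minimal concat + linear module along these lines (a sketch; `ConcatMerge` and the dimensions are my own illustrative choices, not pykeen API):

```python
import torch
import torch.nn as nn

class ConcatMerge(nn.Module):
    """Hypothetical merge_model: concatenate text and graph embeddings,
    then project back down to the graph embedding dimension."""
    def __init__(self, text_dim: int, graph_dim: int):
        super().__init__()
        self.linear = nn.Linear(text_dim + graph_dim, graph_dim)

    def forward(self, text_emb: torch.Tensor, graph_emb: torch.Tensor) -> torch.Tensor:
        return self.linear(torch.cat([text_emb, graph_emb], dim=-1))

merge = ConcatMerge(text_dim=768, graph_dim=64)
x_text = torch.randn(100, 768)   # e.g. frozen sentence-encoder vectors per entity
x_graph = torch.randn(100, 64)   # entity embeddings from the KG model
x_e = merge(x_text, x_graph)     # shape: (100, 64), fed to the GNN encoder
```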

So far, this worked well and boosted the model's performance by ~50% when used with TransE and by up to ~30% with NodePiece on my datasets.

With ULTRA I guess that I have some additional work to do :)... I started by understanding how the entity representation is "generated" on the fly: https://github.com/DeepGraphLearning/ULTRA/blob/33c6e6b8e522aed3d33f6ce5d3a1883ca9284718/ultra/models.py#L166-L174C4

I understand that from that point only the tail representations are used to feed the MLP.

I replaced the MLP with my own MLP to match the dimension of the concatenation of both representations. Then, I tried to concatenate both, the output from ULTRA and the textual entity representation. As far as I understand, due to this "late" concatenation, only the tail entity's textual representation will be used. When tested, I got (almost) the same results with/without the textual representation.

Not sure what I expect to hear :), but I hope you may have an idea for combining both representations.

migalkin commented 8 months ago

Hi!

I understand that from that point only the tail representations are used to feed the MLP.

Those aren't really tail representations anymore because message passing updates all node states and starts with initial node states called boundary (where you can append LLM features):

https://github.com/DeepGraphLearning/ULTRA/blob/33c6e6b8e522aed3d33f6ce5d3a1883ca9284718/ultra/models.py#L137-L140

In the GNN layer code, we use those boundary states together with the current message (e.g., with sum aggregation): https://github.com/DeepGraphLearning/ULTRA/blob/33c6e6b8e522aed3d33f6ce5d3a1883ca9284718/ultra/layers.py#L192-L194

That is, in each GNN layer we actually do have an interaction function of (initial) head and (current) node states.

Adding other entity/relation features seems quite straightforward, I see two possible ways:

  1. Early interaction - you'd need to re-train the model from scratch because all weight matrices will be of different dimensions

  2. Late interaction - you can freeze the main GNN models and only change the final MLP by adding LLM features to the output: https://github.com/DeepGraphLearning/ULTRA/blob/33c6e6b8e522aed3d33f6ce5d3a1883ca9284718/ultra/models.py#L199-L200

This is a less expressive way but won't require re-training the model from scratch.
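The late-interaction option could be sketched roughly as follows (dimensions and names here are illustrative, not ULTRA's actual ones):

```python
import torch
import torch.nn as nn

# Sketch of late interaction: per-node features from a *frozen* GNN are
# concatenated with frozen LLM entity features just before the final MLP,
# which is the only trainable part.
feature_dim, lm_dim, num_nodes, batch = 64, 768, 50, 4

llm_features = nn.Embedding.from_pretrained(torch.randn(num_nodes, lm_dim), freeze=True)
final_mlp = nn.Sequential(
    nn.Linear(feature_dim + lm_dim, feature_dim),
    nn.ReLU(),
    nn.Linear(feature_dim, 1),
)

gnn_output = torch.randn(batch, num_nodes, feature_dim)          # frozen GNN's node states
lm = llm_features.weight.unsqueeze(0).expand(batch, -1, -1)      # (batch, num_nodes, lm_dim)
score = final_mlp(torch.cat([gnn_output, lm], dim=-1)).squeeze(-1)  # (batch, num_nodes)
```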

In any case, if your graphs have >10k nodes, I'd recommend projecting the LLM features (usually 768d or more, depending on the LLM) down to a smaller dimension (32/64d) in order to fit the full-batch GNN layer onto a GPU.
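The down-projection can be as simple as a single trainable linear layer (a sketch; the 768 → 64 sizes are just the example numbers from above):

```python
import torch
import torch.nn as nn

# Project high-dimensional LLM features down before feeding them to the GNN,
# so the full-batch (num_nodes x dim) node-state tensor fits on a GPU.
lm_features = torch.randn(10_000, 768)   # one frozen 768d vector per node
project = nn.Linear(768, 64)             # trained jointly with the rest of the model
small = project(lm_features)             # (10_000, 64): 12x smaller per node state
```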

daniel4x commented 8 months ago

Just an update, I tested all three suggested methods. I'll add a side branch later if you are interested, along with an example of language model integration for future reference.

Generally, the pre-trained embeddings were added to EntityNBFNet's __init__:

https://github.com/DeepGraphLearning/ULTRA/blob/04d5c13a440a1b72be3a0208fcc92e7242cab7a5/ultra/models.py#L106-L127

I slightly modified the code and added the following:

        if lm_vectors is not None:
            # can decide whether to freeze or not...
            self.lm_vectors = nn.Embedding.from_pretrained(lm_vectors, freeze=True)
            self.merge_linear = nn.Linear(feature_dim, 64)

Per your 1st suggestion, training from scratch looks like the following:

        # .....original code.....
        boundary.scatter_add_(1, index.unsqueeze(1), query.unsqueeze(1))

        # interaction of boundary with lm vectors
        if self.lm_vectors is not None:
            lm_vectors = self.lm_vectors(h_index)  # mistake - see @migalkin
            lm_vectors = lm_vectors.unsqueeze(1).expand(-1, data.num_nodes, -1)
            boundary = torch.cat([boundary, lm_vectors], dim=-1)
            boundary = self.merge_linear(boundary)

The merge_linear may not be the best option and can be replaced with any other interaction that fits into the conv layers.

migalkin commented 8 months ago

Those lines

lm_vectors = self.lm_vectors(h_index)
lm_vectors = lm_vectors.unsqueeze(1).expand(-1, data.num_nodes, -1)

would take only lm features of head nodes in the batch and copy them to all nodes in the graph - is this what you want?

If you want to initialize each node with its own LM feature, then you don't need to call the embedding layer on h_index; just take all of its weights and repeat them along the batch dimension, like self.lm_vectors.weight.repeat(bs, 1, 1), or pass all node indices as the input, like self.lm_vectors(torch.arange(data.num_nodes)).repeat(bs, 1, 1)
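The difference between the two indexing variants can be seen in a small standalone sketch (toy sizes; `h_index` and `lm_vectors` play the same roles as in the snippet above):

```python
import torch
import torch.nn as nn

num_nodes, lm_dim, bs = 6, 8, 2
lm_vectors = nn.Embedding.from_pretrained(torch.randn(num_nodes, lm_dim), freeze=True)
h_index = torch.tensor([0, 3])  # head entity index per batch element

# Variant from the snippet above: takes only the *head* node's feature
# and copies it to every node in the graph.
head_only = lm_vectors(h_index).unsqueeze(1).expand(-1, num_nodes, -1)

# Intended behavior: every node keeps its own feature, repeated over the batch.
per_node = lm_vectors.weight.repeat(bs, 1, 1)  # (bs, num_nodes, lm_dim)
# equivalently: lm_vectors(torch.arange(num_nodes)).repeat(bs, 1, 1)
```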

dhall1995 commented 7 months ago

@daniel4x Have you made this into a separate branch? I would be really interested to see your code and hear which of the integration methods performed best for you.

In my use case I have a mixture of LLM features (edges) and a set of pre-trained embedding features for each of the node types. Your experience that node features offer significant performance benefits tallies with mine, so it would be great to integrate this into my code.