pyg-team / pytorch_geometric

Graph Neural Network Library for PyTorch
https://pyg.org
MIT License

Jumping Knowledge support for Heterogeneous Graph Features #9355

Closed (m-atalla closed this 3 weeks ago)

m-atalla commented 1 month ago

🚀 The feature, motivation and pitch

Dear PyG Developers,

Currently, the Jumping Knowledge (JK) module only supports homogeneous feature inputs, that is, inputs with the type signature List[Tensor].

Supporting heterogeneous inputs would make JK a potential drop-in improvement when experimenting on heterogeneous tasks. I propose adding a HeteroJK class for heterogeneous inputs, which holds one JumpingKnowledge module per node type (for brevity's sake, I omitted comments/imports):

class HeteroJK(torch.nn.Module):
    def __init__(self, metadata: Metadata, **kwargs):
        super().__init__()
        self.jk = ModuleDict()
        for node_type in metadata[0]:
            self.jk[node_type] = JumpingKnowledge(**kwargs)

    def forward(self, dict_xs: Dict[str, List[Tensor]]) -> Dict[str, Tensor]:
        out_dict = {}
        for node_type, xs in dict_xs.items():
            out_dict[node_type] = self.jk[node_type](xs)
        return out_dict

Alternatives

Modify the existing JumpingKnowledge module so that it also accepts heterogeneous inputs. This would change the input type signature to Union[List[Tensor], Dict[str, List[Tensor]]], and the return type would change analogously to Union[Tensor, Dict[str, Tensor]].
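
The dispatch this alternative implies could look roughly like the sketch below. It is deliberately library-free: apply_jk and concat are hypothetical names, and concat is a plain-Python stand-in for a JumpingKnowledge instance so that the branching logic can be seen in isolation.

```python
from typing import Callable, Dict, List, TypeVar, Union

T = TypeVar("T")


def apply_jk(
    jk: Callable[[List[T]], T],
    xs: Union[List[T], Dict[str, List[T]]],
) -> Union[T, Dict[str, T]]:
    """Apply one aggregator to homogeneous or heterogeneous inputs."""
    if isinstance(xs, dict):
        # Heterogeneous case: aggregate each node type's layer list separately.
        return {node_type: jk(layer_xs) for node_type, layer_xs in xs.items()}
    # Homogeneous case: a single list of per-layer features.
    return jk(xs)


def concat(xs: List[List[int]]) -> List[int]:
    """Stand-in aggregator that concatenates per-layer feature lists."""
    return [value for x in xs for value in x]


homo_out = apply_jk(concat, [[1, 2], [3, 4]])
hetero_out = apply_jk(concat, {"paper": [[1], [2]], "author": [[3], [4]]})
print(homo_out)    # [1, 2, 3, 4]
print(hetero_out)  # {'paper': [1, 2], 'author': [3, 4]}
```

One design difference to keep in mind: here a single aggregator is shared across all node types, whereas HeteroJK keeps a separate module per type. That matters for modes with learnable parameters such as "lstm".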

I'm happy to help implement either approach if this feature would be a welcome addition to PyG.

Thank you for your time and consideration.

Additional context

No response

rusty1s commented 1 month ago

This sounds like a good feature to me. Feel free to send a PR :)