wwu-mmll / photonai_graph

Photon Graph is an extension for the PHOTON framework that enables machine learning on graphs.
https://wwu-mmll.github.io/photonai_graph/
MIT License

Implement dgl check for graphs with different numbers of nodes #117

Open VHolstein opened 1 year ago


Is your feature request related to a problem? Please describe.
PHOTONAI-Graph currently supports only DGL graphs that all have the same number of node features. If you use DGL graphs with varying numbers of nodes or node features (e.g., the DGL mini dataset), the pipeline does not work.
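A check for this situation could inspect every graph before it enters the pipeline. The sketch below is a hypothetical helper (`check_graph_shapes` and the default feature key `"feat"` are assumptions, not part of PHOTONAI-Graph). It relies only on `num_nodes()` and the `ndata` feature mapping, which `dgl.DGLGraph` provides, and reports whether node counts or node-feature widths differ across graphs.

```python
from typing import Any, Dict, Iterable


def check_graph_shapes(graphs: Iterable[Any], feat_key: str = "feat") -> Dict[str, Any]:
    """Hypothetical helper: summarize node counts and feature widths of DGL-style graphs.

    Each graph must provide ``num_nodes()`` and an ``ndata`` mapping, as
    dgl.DGLGraph does. ``feat_key`` is an assumed name for the node features.
    """
    node_counts = []
    feat_dims = []
    for g in graphs:
        node_counts.append(g.num_nodes())
        if feat_key in g.ndata:
            # Everything after the node axis is the per-node feature shape.
            feat_dims.append(tuple(g.ndata[feat_key].shape[1:]))
        else:
            # Purely structural graphs (no stored node features) are recorded as None.
            feat_dims.append(None)
    return {
        "node_counts": node_counts,
        "varying_nodes": len(set(node_counts)) > 1,
        "missing_features": any(d is None for d in feat_dims),
        "varying_features": len({d for d in feat_dims if d is not None}) > 1,
    }
```

A pipeline could then refuse, or route to a size-agnostic architecture, any dataset for which `varying_nodes`, `missing_features`, or `varying_features` is true.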

Describe the solution you'd like
Implement a new architecture that uses an MLP to reduce the node features to a fixed size. Also add a check that allows node degree to be used within the different architectures.
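One way to read this proposal is sketched below under stated assumptions: node in-degree serves as a one-dimensional node feature when a graph has none, a single MLP with shared weights maps every node's features to a fixed width, and a mean over nodes (a readout chosen here as an assumption; the issue does not name one) yields a graph vector whose size does not depend on the node count. `in_degree_features` and `SharedNodeMLP` are hypothetical names, and NumPy stands in for the PyTorch/DGL layers a real model would use.

```python
import numpy as np


def in_degree_features(dst: np.ndarray, num_nodes: int) -> np.ndarray:
    """Use each node's in-degree as a node feature, shape (num_nodes, 1).

    ``dst`` holds the destination node of every edge.
    """
    deg = np.bincount(dst, minlength=num_nodes).astype(np.float64)
    return deg.reshape(-1, 1)


class SharedNodeMLP:
    """Hypothetical module: a two-layer MLP applied to every node, then mean pooling.

    The same weights are used for every node, so the graph-level output always
    has ``out_dim`` entries regardless of how many nodes the graph has.
    """

    def __init__(self, in_dim: int, hidden_dim: int, out_dim: int, seed: int = 0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(scale=0.1, size=(in_dim, hidden_dim))
        self.b1 = np.zeros(hidden_dim)
        self.w2 = rng.normal(scale=0.1, size=(hidden_dim, out_dim))
        self.b2 = np.zeros(out_dim)

    def node_embeddings(self, x: np.ndarray) -> np.ndarray:
        # x: (num_nodes, in_dim) -> (num_nodes, out_dim)
        h = np.maximum(x @ self.w1 + self.b1, 0.0)  # ReLU hidden layer
        return h @ self.w2 + self.b2

    def graph_embedding(self, x: np.ndarray) -> np.ndarray:
        # Averaging over the node axis gives a vector of length out_dim.
        return self.node_embeddings(x).mean(axis=0)


if __name__ == "__main__":
    mlp = SharedNodeMLP(in_dim=1, hidden_dim=8, out_dim=4)
    small = in_degree_features(np.array([1, 2, 0]), num_nodes=3)
    large = in_degree_features(np.array([1, 2, 3, 4, 0, 0]), num_nodes=5)
    # Both graphs map to vectors of the same length despite different node counts.
    print(mlp.graph_embedding(small).shape, mlp.graph_embedding(large).shape)  # (4,) (4,)
```

Because the MLP acts per node, only the input feature width `in_dim` has to be fixed; the varying node count is absorbed by the pooling step.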

Additional context
This will be important for version 2.0.

Unit Tests
Reimplement the following unit tests, which use the DGL mini dataset:

    def test_gcn_classifier_dgl(self):
        # Fit and predict on the DGL data; predictions should match the label shape.
        gcn_clf = GCNClassifierModel(nn_epochs=20)
        gcn_clf.fit(self.X_dgl, self.y)
        output = gcn_clf.predict(self.X_dgl)
        self.assertEqual(output.shape, self.y.shape)

    def test_gat_classifier_dgl(self):
        gat_clf = GATClassifierModel(nn_epochs=20)
        gat_clf.fit(self.X_dgl, self.y)
        output = gat_clf.predict(self.X_dgl)
        self.assertEqual(output.shape, self.y.shape)