danielegrattarola / spektral

Graph Neural Networks with Keras and TensorFlow 2.
https://graphneural.network
MIT License

How do I call model.summary() after I have created a model? #449

Open: charlespan110 opened this issue 6 months ago

charlespan110 commented 6 months ago

Thank you so much for your open-source contribution to graph neural networks! I am having trouble visualizing my model with model.summary(). Here is my code. When I call model.summary(), it gives me an error. Could you please help me modify the code so that model.summary() works correctly? Thank you so much!

Best

import tensorflow as tf
from tensorflow.keras import Model
from spektral.data import DisjointLoader


class EEGGraphNet(Model):
    def __init__(self, num_nodes=64, num_features=128, num_classes=2, num_filters=5):
        super().__init__()
        # Graph structure learning module (layers not shown)

    def call(self, inputs):
        V, A, I = inputs  # node features, adjacency matrix, graph index of each node

        # Graph convolution (layers not shown; they produce `output`)

        return output


# Data loading and training run outside the model class, not inside call().
loader_tr = DisjointLoader(data_tr, batch_size=batch_size, epochs=epochs)
loader_val = DisjointLoader(data_val, batch_size=batch_size, epochs=epochs)
loader_te = DisjointLoader(data_te, batch_size=batch_size)

model = create_model(num_nodes, num_features, num_classes)

model.summary()

# model.fit(x_tra, y_tra, epochs=epochs, validation_split=0.1, callbacks=callbacks, batch_size=batch_size)
# The loaders already produce batches, so batch_size is not passed to fit().
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    epochs=epochs,
    validation_data=loader_val.load(),
    validation_steps=loader_val.steps_per_epoch,
    callbacks=callbacks,
)
danielegrattarola commented 6 months ago

Hi, could you post the error that you are getting?
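If the error turns out to be the one Keras raises for a subclassed model whose weights have not been created yet, a common workaround is to call the model once on a batch before calling model.summary(). The sketch below illustrates this under stated assumptions: TinyGraphNet and make_fake_batch are hypothetical names, and the layers (a Dense projection, a sparse adjacency product, and per-graph mean pooling) are stand-ins for the layers omitted from EEGGraphNet, not Spektral's own layers. The fake batch mimics disjoint mode: node features, a sparse adjacency matrix, and a graph index per node.

```python
import tensorflow as tf


# Hypothetical stand-in for EEGGraphNet: one graph convolution, mean pooling
# per graph, and a dense classifier (not Spektral's layers).
class TinyGraphNet(tf.keras.Model):
    def __init__(self, hidden=16, num_classes=2):
        super().__init__()
        self.proj = tf.keras.layers.Dense(hidden, activation="relu")
        self.classifier = tf.keras.layers.Dense(num_classes, activation="softmax")

    def call(self, inputs):
        x, a, i = inputs  # node features, sparse adjacency, graph index of each node
        h = tf.sparse.sparse_dense_matmul(a, self.proj(x))  # sum projected neighbour features
        pooled = tf.math.segment_mean(h, i)  # one row per graph (i must be sorted)
        return self.classifier(pooled)


def make_fake_batch(num_features=4):
    """Hypothetical disjoint-mode batch: two graphs with 3 and 2 nodes."""
    x = tf.random.normal((5, num_features))
    edges = [[0, 1], [1, 0], [1, 2], [2, 1], [3, 4], [4, 3]]
    a = tf.sparse.SparseTensor(
        indices=edges, values=tf.ones(len(edges)), dense_shape=(5, 5)
    )
    i = tf.constant([0, 0, 0, 1, 1])
    return x, a, i


model = TinyGraphNet()
batch = make_fake_batch()

# A subclassed model has no weights until it is called; one forward pass
# on a batch creates them, after which summary() can report parameters.
preds = model(batch)
model.summary()
```

With the real data, the equivalent step would presumably be to call the model once on the inputs of a single batch drawn from a DisjointLoader (the same (V, A, I) tuple that call() unpacks) and only then call model.summary(). Building from a real batch avoids having to describe the sparse adjacency input to model.build() by hand.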