Sum02dean opened 2 years ago
Should I remove the graph mask from the model's __init__() and call() methods (e.g. self.mask = GraphMasking() and, later, x = self.mask(x)), as well as any reference to masking from the BatchLoader()?
################################################################################
# Load data
################################################################################
# dataset_tr/va/te, batch_size, learning_rate, N and n_out are defined elsewhere in the script (not shown)
loader_tr = BatchLoader(dataset_tr, batch_size=batch_size)
loader_va = BatchLoader(dataset_va, batch_size=batch_size)
loader_te = BatchLoader(dataset_te, batch_size=batch_size)

################################################################################
# Build model
################################################################################
class Net(Model):
    def __init__(self):
        super().__init__()
        # self.mask = GraphMasking()
        self.conv1 = GCNConv(32, activation="relu")
        self.pool = MinCutPool(N // 2)
        self.conv2 = GCNConv(32, activation="relu")
        self.global_pool = GlobalSumPool()
        self.dense1 = Dense(n_out)

    def call(self, inputs):
        x, a = inputs
        # x = self.mask(x)
        x = self.conv1([x, a])
        x_pool, a_pool = self.pool([x, a])
        x_pool = self.conv2([x_pool, a_pool])
        output = self.global_pool(x_pool)
        output = self.dense1(output)
        return output


model = Net()
opt = Adam(learning_rate=learning_rate)  # `lr` is deprecated in tf.keras optimizers
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])
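Per the GraphMasking docstring quoted in this issue, loading with mask=True (or zero-padding graphs in general) is expected to extend each node's features with a binary column marking which nodes are real. A minimal NumPy sketch of that layout, assuming exactly the behaviour the docstring describes; pad_with_mask is a hypothetical helper, not Spektral's BatchLoader implementation:

```python
import numpy as np


def pad_with_mask(graphs_x):
    """Hypothetical helper: zero-pad node-feature matrices to a common number
    of nodes and append a binary mask column (1 = real node, 0 = padding)."""
    n_max = max(x.shape[0] for x in graphs_x)
    n_feat = graphs_x[0].shape[1]
    batch = np.zeros((len(graphs_x), n_max, n_feat + 1), dtype=np.float32)
    for i, x in enumerate(graphs_x):
        n = x.shape[0]
        batch[i, :n, :n_feat] = x  # copy the real node features
        batch[i, :n, -1] = 1.0     # mark real nodes in the last column
    return batch


# Two toy graphs with 2 and 3 nodes, 4 features per node
g1 = np.ones((2, 4), dtype=np.float32)
g2 = np.full((3, 4), 2.0, dtype=np.float32)
x_batch = pad_with_mask([g1, g2])
print(x_batch.shape)  # (2, 3, 5)
```

The padded third node of the first graph is all zeros, including its mask entry, which is what lets a downstream layer tell it apart from real nodes.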
As a workaround, I copied the GraphMasking class into my script and removed the mask=True argument from the BatchLoader() calls, while keeping the masking references in the model's __init__() and call() methods, e.g.:
from tensorflow.keras.layers import Layer


class GraphMasking(Layer):
    """
    A layer that starts the propagation of masks in a model.

    This layer assumes that the node features given as input have been extended with a
    binary mask that indicates which nodes are valid in each graph.
    The layer is useful when using a `data.BatchLoader` with `mask=True` or in general
    when zero-padding graphs so that all batches have the same size. The binary mask
    indicates with a 1 those nodes that should be taken into account by the model.

    The layer will remove the rightmost feature from the nodes and start a mask
    propagation to all subsequent layers:

    ```python
    print(x.shape)  # shape (batch, n_nodes, n_node_features + 1)
    mask = x[..., -1:]  # shape (batch, n_nodes, 1)
    x_new = x[..., :-1]  # shape (batch, n_nodes, n_node_features)
    ```
    """

    def compute_mask(self, inputs, mask=None):
        x = inputs[0] if isinstance(inputs, list) else inputs
        return x[..., -1:]

    def call(self, inputs, **kwargs):
        # Remove mask from features
        if isinstance(inputs, list):
            inputs[0] = inputs[0][..., :-1]
        else:
            inputs = inputs[..., :-1]
        return inputs
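The two methods above only slice the input: compute_mask takes the rightmost column as the mask, and call drops that column from the features. The NumPy sketch below mirrors that slicing (split_features_and_mask is a hypothetical name) and then uses a generic dense layer with a bias, not Spektral's GCNConv or GlobalSumPool, to illustrate why the mask matters: zero-padded nodes stop being zero once a bias is added, so an unmasked sum includes them.

```python
import numpy as np


def split_features_and_mask(x):
    """Mirror of the slicing done by GraphMasking.compute_mask / call."""
    mask = x[..., -1:]      # (batch, n_nodes, 1)
    features = x[..., :-1]  # (batch, n_nodes, n_node_features)
    return features, mask


# One graph: 2 real nodes and 1 padded node; 2 features + mask column
x = np.array([[[1.0, 2.0, 1.0],
               [3.0, 4.0, 1.0],
               [0.0, 0.0, 0.0]]])
features, mask = split_features_and_mask(x)

# Toy dense layer with a bias (illustrative weights, not a Spektral layer)
w = np.ones((2, 2))
b = np.array([0.5, 0.5])
h = np.maximum(features @ w + b, 0.0)  # ReLU(x W + b); padded row becomes 0.5

unmasked_sum = h.sum(axis=1)         # padded node leaks into the pooled vector
masked_sum = (h * mask).sum(axis=1)  # padded node is zeroed out before pooling
print(unmasked_sum, masked_sum)
```

Here the unmasked sum is 11.5 per channel versus 11.0 with the mask applied; the 0.5 difference is the padded node's bias. Whether a particular Spektral layer actually consumes the propagated Keras mask is not something this sketch establishes.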
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Dense
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from spektral.data import BatchLoader
from spektral.layers import GCNConv, GlobalSumPool, MinCutPool

################################################################################
# Config
################################################################################
learning_rate = 1e-3  # Learning rate
epochs = 100  # Number of training epochs
batch_size = 32  # Batch size

################################################################################
# Load data
################################################################################
# dataset_tr/va/te, N and n_out are defined elsewhere in the script (not shown)
loader_tr = BatchLoader(dataset_tr, batch_size=batch_size)
loader_va = BatchLoader(dataset_va, batch_size=batch_size)
loader_te = BatchLoader(dataset_te, batch_size=batch_size)

################################################################################
# Build model
################################################################################
class Net(Model):
    def __init__(self):
        super().__init__()
        self.mask = GraphMasking()  # local copy of the class pasted above
        self.conv1 = GCNConv(32, activation="relu")
        self.pool = MinCutPool(N // 2)
        self.conv2 = GCNConv(32, activation="relu")
        self.global_pool = GlobalSumPool()
        self.dense1 = Dense(n_out)

    def call(self, inputs):
        x, a = inputs
        x = self.mask(x)
        x = self.conv1([x, a])
        x_pool, a_pool = self.pool([x, a])
        x_pool = self.conv2([x_pool, a_pool])
        output = self.global_pool(x_pool)
        output = self.dense1(output)
        return output


model = Net()
opt = Adam(learning_rate=learning_rate)  # `lr` is deprecated in tf.keras optimizers
model.compile(optimizer=opt, loss="categorical_crossentropy", metrics=["acc"])

################################################################################
# Fit model
################################################################################
model.fit(
    loader_tr.load(),
    steps_per_epoch=loader_tr.steps_per_epoch,
    epochs=epochs,
    validation_data=loader_va.load(),  # pass the generator, not the loader object
    validation_steps=loader_va.steps_per_epoch,
    callbacks=[EarlyStopping(patience=10, restore_best_weights=True)],
)

################################################################################
# Evaluate model
################################################################################
print("Testing model")
loss, acc = model.evaluate(loader_te.load(), steps=loader_te.steps_per_epoch)
print("Done. Test loss: {}. Test acc: {}".format(loss, acc))
With this temporary workaround the model seems to learn, but I have no way of knowing what is actually happening under the hood, i.e. whether masking is happening or not.
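Reading the GraphMasking code quoted above, the layer always strips the rightmost feature and treats it as the mask, whether or not the loader appended one. So if mask=True is removed from the BatchLoader calls while GraphMasking stays in the model, the last real node feature would appear to be consumed as a (generally non-binary) mask. A rough check, assuming you can inspect the node-feature array of one batch before it reaches the model; last_column_is_binary_mask is a hypothetical helper:

```python
import numpy as np


def last_column_is_binary_mask(x):
    """Hypothetical sanity check: True if x[..., -1] holds only 0s and 1s, as an
    appended mask column would. Real features can also pass by coincidence
    (e.g. one-hot node labels), so only a False result is conclusive."""
    last = x[..., -1]
    return bool(np.all((last == 0.0) | (last == 1.0)))


# No mask column: GraphMasking would still drop the last (real) feature
x_no_mask = np.array([[[0.2, 0.7], [0.9, 0.4]]])
# Padded batch with a mask column appended
x_with_mask = np.array([[[0.2, 0.7, 1.0], [0.0, 0.0, 0.0]]])

print(last_column_is_binary_mask(x_no_mask))    # False
print(last_column_is_binary_mask(x_with_mask))  # True
```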
Hi,
What version of Spektral is this happening on? I cannot reproduce the issue on 1.0.8.
Do you still get the issue if you install from source?
Thanks
Hi @Sum02dean
I don't know if this helps now, but you can import GraphMasking from spektral.layers.base.
Is GraphMasking now redundant? I am making modifications to my code based on the following recommendation: https://github.com/danielegrattarola/spektral/issues/283. However, some of this code no longer seems usable.
Also I get an error here:
Many thanks,
D