erikhuck opened 1 year ago
🐛 Describe the bug

Using an LSTM jumping knowledge with the GAT model results in a warning when evaluating the model (not when training it). It appears that this can be fixed by calling `flatten_parameters()` on the internal LSTM module.
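The warning in question is presumably PyTorch's cuDNN notice that the RNN module weights are no longer in a single contiguous chunk of memory and should be compacted via `flatten_parameters()`. A minimal sketch of a workaround, assuming the PyG layout in which `jk='lstm'` creates a `JumpingKnowledge` module whose inner `torch.nn.LSTM` is stored in its `lstm` attribute (the helper name is mine):

```python
import torch
from torch_geometric.nn import JumpingKnowledge

def compact_jk_lstms(model: torch.nn.Module) -> None:
    """Re-flatten every LSTM living inside a JumpingKnowledge block."""
    for module in model.modules():
        # JumpingKnowledge(mode='lstm') stores its LSTM in `module.lstm`;
        # other modes ('cat', 'max') have no LSTM, hence the hasattr guard.
        if isinstance(module, JumpingKnowledge) and hasattr(module, 'lstm'):
            module.lstm.flatten_parameters()
```

Calling `compact_jk_lstms(model)` after moving the model to the GPU and before evaluation should silence the warning.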
My model:

And here's the code for instantiating the model:

```python
model = GraphAttentionNetwork(
    device=device,
    n_atom_features=dataset.n_atom_features,
    n_bond_features=dataset.n_bond_features,
    gat_hidden_channels=64,
    gat_n_layers=4,
    gat_out_channels=256,
    gat_act_func='relu',
    gat_norm='batch_norm',
    aggregation='max-add-mean',
    gat_dropout=0.1,
    n_heads=10,
    leaky_relu_slope=0.2,
    mlp_hidden_channels=128,
    mlp_n_layers=3,
    mlp_act_func='relu',
    mlp_norm='batch_norm',
    mlp_dropout=0.1,
)
```
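For comparison, a hypothetical minimal setup (not the code above) that exercises the same LSTM jumping knowledge path through PyG's built-in `GAT`:

```python
import torch
from torch_geometric.nn import GAT

# Hypothetical stand-in for GraphAttentionNetwork: with jk='lstm', the
# built-in GAT routes per-layer outputs through JumpingKnowledge(mode='lstm'),
# which is the module that owns the internal nn.LSTM.
model = GAT(in_channels=32, hidden_channels=64, num_layers=4,
            out_channels=256, heads=8, jk='lstm', dropout=0.1)
# (heads=8 rather than 10 here because the built-in GAT requires
#  hidden_channels to be divisible by the number of heads)

x = torch.randn(8, 32)                                   # 8 nodes, 32 features
edge_index = torch.tensor([[0, 1, 2, 3], [1, 2, 3, 0]])  # toy 4-edge graph
out = model(x, edge_index)                               # shape [8, 256]
```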
Prediction code (run after training for a number of epochs):
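The snippet itself is omitted here; as a sketch only, an evaluation loop of this general shape, with the workaround applied (`test_loader` and `device` are assumed names, and the real model's forward signature may differ):

```python
import torch

model.eval()
# Workaround: re-compact the JumpingKnowledge LSTM weights before inference,
# assuming the model exposes the module as `model.jk` with an `lstm` attribute.
model.jk.lstm.flatten_parameters()

with torch.no_grad():
    for batch in test_loader:  # assumed: a torch_geometric.loader.DataLoader
        batch = batch.to(device)
        preds = model(batch.x, batch.edge_index)
```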
Environment