Closed by phinate 3 years ago
Thank you for the feedback on this example.
The GNN architecture shown in the higgs_detection example can solve this problem in theory. You can verify this by plugging the analytical solution into the edge update function; see https://github.com/deepmind/jraph/blob/master/jraph/examples/higgs_detection.py#L135
In practice, the simple MLP and SGD setup used for the example unfortunately does not converge to the correct solution. IMHO, one would need to tune the optimization algorithm (and maybe scale the inputs appropriately) to facilitate convergence.
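For the input-scaling part, here is a rough sketch of what I mean. It is not taken from the example itself: the `standardize` helper and the made-up feature array are purely illustrative assumptions, standing in for whatever node features the example feeds to the network.

```python
import numpy as np


def standardize(features, eps=1e-8):
    """Scale each feature column to zero mean and unit variance.

    Hypothetical helper, not part of jraph or the higgs_detection example.
    Returns the scaled features together with the statistics used, so the
    same transform can be reapplied to held-out data.
    """
    mean = features.mean(axis=0, keepdims=True)
    std = features.std(axis=0, keepdims=True)
    return (features - mean) / (std + eps), mean, std


# Made-up stand-in for input features whose columns live on very
# different scales (an assumption for illustration only).
rng = np.random.default_rng(0)
raw = rng.normal(loc=[50.0, 0.0, 1e3], scale=[20.0, 1.0, 300.0], size=(1024, 3))

scaled, mean, std = standardize(raw)
```

The statistics should be computed on the training set only and then reused for the test set, e.g. `(test_features - mean) / (std + 1e-8)`. Whether this alone is enough for convergence here is untested; swapping plain SGD for an adaptive optimizer such as Adam and tuning the learning rate would be the other thing to try.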
It's a super cool example (particle physicist here!), but I'm not sure whether the system actually learns anything. The loss just seems to fluctuate around random performance on the test set, e.g. after 11000 steps with default settings:
Not sure if it was meant to work in practice, or just a nice example of a problem implementation (that part is done very nicely :) )