google-deepmind / jraph

A Graph Neural Network Library in Jax
https://jraph.readthedocs.io/en/latest/
Apache License 2.0

examples/higgs_detection.py doesn't learn #1

Closed phinate closed 3 years ago

phinate commented 3 years ago

It's a super cool example (particle physicist here!), but I'm not sure the system actually learns anything. The loss just seems to fluctuate around random performance on the test set, e.g. after 11000 steps with the default settings:

I1124 17:51:33.153964 4592291264 higgs_detection.py:204] step 0 loss train 0.5299999713897705 test 0.4699999988079071
I1124 17:51:36.350952 4592291264 higgs_detection.py:204] step 1000 loss train 0.4000000059604645 test 0.41999998688697815
I1124 17:51:39.533082 4592291264 higgs_detection.py:204] step 2000 loss train 0.5099999904632568 test 0.5400000214576721
I1124 17:51:42.694689 4592291264 higgs_detection.py:204] step 3000 loss train 0.4699999988079071 test 0.49000000953674316
I1124 17:51:45.841547 4592291264 higgs_detection.py:204] step 4000 loss train 0.5199999809265137 test 0.49000000953674316
I1124 17:51:48.940982 4592291264 higgs_detection.py:204] step 5000 loss train 0.550000011920929 test 0.41999998688697815
I1124 17:51:52.107565 4592291264 higgs_detection.py:204] step 6000 loss train 0.5099999904632568 test 0.47999998927116394
I1124 17:51:55.312087 4592291264 higgs_detection.py:204] step 7000 loss train 0.5299999713897705 test 0.4699999988079071
I1124 17:51:58.471485 4592291264 higgs_detection.py:204] step 8000 loss train 0.5600000023841858 test 0.550000011920929
I1124 17:52:01.599973 4592291264 higgs_detection.py:204] step 9000 loss train 0.49000000953674316 test 0.550000011920929
I1124 17:52:04.751693 4592291264 higgs_detection.py:204] step 10000 loss train 0.46000000834465027 test 0.47999998927116394
I1124 17:52:07.865066 4592291264 higgs_detection.py:204] step 11000 loss train 0.4699999988079071 test 0.5

Not sure if it was meant to work in practice, or just intended as a nice example of how to implement this kind of problem (that part is done very nicely :) )

thomaskeck commented 3 years ago

Thank you for the feedback on this example.

The GNN architecture shown in the higgs_detection example can solve this problem in theory. You can verify this by plugging the analytical solution into the edge update function; see https://github.com/deepmind/jraph/blob/master/jraph/examples/higgs_detection.py#L135
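Roughly, such a hand-coded edge update could look like the sketch below. This is only illustrative: it assumes the node features gathered per edge are particle 4-momenta of shape [n_edge, 4] as (E, px, py, pz) in GeV, and the mass window is an arbitrary choice, so it may not match the example's actual conventions.

```python
import jax.numpy as jnp

HIGGS_MASS = 125.0  # GeV (assumed units)

def analytical_edge_update(edges, sent_attributes, received_attributes, globals_):
  # sent_attributes / received_attributes: node features gathered per edge,
  # assumed here to be 4-momenta (E, px, py, pz) of the two endpoints.
  p = sent_attributes + received_attributes
  # Invariant mass of the pair: m^2 = E^2 - |p|^2.
  m2 = p[:, 0] ** 2 - jnp.sum(p[:, 1:] ** 2, axis=-1)
  inv_mass = jnp.sqrt(jnp.maximum(m2, 0.0))
  # Mark edges whose pair mass falls in an (illustrative) window around m_H.
  is_higgs_like = (jnp.abs(inv_mass - HIGGS_MASS) < 5.0).astype(jnp.float32)
  return is_higgs_like[:, None]
```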

In practice, the simple MLP and SGD optimizer used in the example unfortunately do not converge to the correct solution. Imho, one would need to tune the optimization algorithm (and maybe scale the inputs appropriately) to facilitate convergence.
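The kind of tweak I have in mind would be along these lines (purely illustrative; the optimizer choice, learning rate, and scaling scheme are guesses, not something verified on this example):

```python
import optax

# Swap plain SGD for Adam, which usually copes better with
# unscaled physics features; the learning rate is just a starting guess.
optimizer = optax.adam(learning_rate=1e-3)

def standardize(node_features):
  # Standardize node features (e.g. 4-momenta) before feeding the network.
  mean = node_features.mean(axis=0, keepdims=True)
  std = node_features.std(axis=0, keepdims=True) + 1e-6
  return (node_features - mean) / std
```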