neo4j / graph-data-science

Source code for the Neo4j Graph Data Science library of graph algorithms.
https://neo4j.com/docs/graph-data-science/current/

[GraphSage] Allow "patience" with "tolerance" when training #252

Open giahung24 opened 1 year ago

giahung24 commented 1 year ago

Is your feature request related to a problem? Please describe.

I had a training run stop after just 2 iterations because the loss did not decrease (loss delta < tolerance). This is not very effective as an "early stopping" strategy.

Describe the solution you would like

Training should be a bit more patient, e.g. only stopping if the loss does not improve for n consecutive iterations. This would mean introducing a new hyper-parameter 'patience': the number of iterations with no improvement after which training is stopped. Currently, this 'patience' is effectively 1.
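A minimal sketch of what such a streak-based check could look like (the class and method names below are illustrative, not the actual GDS API):

```java
// Illustrative sketch only; names are hypothetical, not GDS internals.
final class PatienceStopper {
    private final double tolerance;
    private final int patience;
    private int streak = 0;
    private double bestLoss = Double.MAX_VALUE;

    PatienceStopper(double tolerance, int patience) {
        this.tolerance = tolerance;
        this.patience = patience;
    }

    /** Returns true if training should stop after observing this loss. */
    boolean shouldStop(double loss) {
        if (bestLoss - loss > tolerance) {
            bestLoss = loss;   // meaningful improvement: reset the streak
            streak = 0;
        } else {
            streak++;          // improvement within tolerance: count it
        }
        return streak >= patience;
    }

    public static void main(String[] args) {
        // Loss plateaus briefly, then improves again.
        double[] losses = {10.0, 9.99, 9.98, 8.0, 7.0, 6.99, 6.98, 6.97};
        for (int patience : new int[]{1, 3}) {
            PatienceStopper stopper = new PatienceStopper(0.1, patience);
            int iterations = 0;
            for (double loss : losses) {
                iterations++;
                if (stopper.shouldStop(loss)) break;
            }
            // patience=1 stops after 2 iterations (as in the report above);
            // patience=3 survives the plateau and keeps training.
            System.out.println("patience=" + patience
                    + " stopped after " + iterations + " iterations");
        }
    }
}
```

With patience=1 the run above aborts on the first plateaued iteration, exactly the behavior reported; with patience=3 it rides out the short plateau and benefits from the later improvement.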

giahung24 commented 1 year ago

[image] For example: 10 minutes of batch preparation, then the algorithm exits after 2 iterations...

FlorentinD commented 1 year ago

Hey @giahung24 , thanks for your feedback!

Adding a patience parameter sounds like a great idea. We actually already have most of the logic in place for our Logistic Regression training, where we already expose a patience parameter.

We will add it to our backlog and let you know when we have added this feature. For now, you could try increasing the tolerance as a workaround.

Alternatively, would you be interested in implementing this yourself? I think it could make a good first issue. The main logic is inside org.neo4j.gds.embeddings.graphsage.GraphSageModelTrainer#trainEpoch, and the patience logic could be reused via org.neo4j.gds.ml.gradientdescent.StreakStopper.
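To make the suggested change concrete, here is a hedged sketch of how an epoch loop might consult a streak-based stopper instead of aborting on a single unproductive iteration. All names below are illustrative; the real GraphSageModelTrainer#trainEpoch and StreakStopper differ:

```java
import java.util.function.DoubleSupplier;

// Illustrative sketch only; not the actual GDS trainEpoch implementation.
final class TrainEpochSketch {

    /**
     * Runs up to maxIterations, stopping early once the loss fails to
     * improve by more than `tolerance` for `patience` consecutive iterations.
     * Returns the number of iterations actually run.
     */
    static int run(DoubleSupplier lossPerIteration, int maxIterations,
                   double tolerance, int patience) {
        double previousLoss = Double.MAX_VALUE;
        int unproductiveStreak = 0;
        int iteration = 0;
        while (iteration < maxIterations && unproductiveStreak < patience) {
            iteration++;
            double loss = lossPerIteration.getAsDouble();
            if (previousLoss - loss > tolerance) {
                unproductiveStreak = 0;   // real progress: reset the streak
            } else {
                unproductiveStreak++;     // delta within tolerance: count it
            }
            previousLoss = loss;
        }
        return iteration;
    }
}
```

The only behavioral change versus a single-delta check is replacing "stop on the first unproductive iteration" with "stop after `patience` unproductive iterations in a row", which is also how the existing Logistic Regression patience parameter is typically understood.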