microsoft / tf2-gnn

TensorFlow 2 library implementing Graph Neural Networks
MIT License
369 stars, 73 forks

fix(graph_task_model): remove epoch loss scaling with number of samples #46

Closed: megstanley closed this 3 years ago

megstanley commented 3 years ago

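Previously, graph_task_model.run_one_epoch() computed the epoch loss as

$$\frac{1}{N_1 + N_2 + \cdots}\left(\frac{l_1}{N_1} + \frac{l_2}{N_2} + \cdots\right),$$

whereas we want

$$\frac{1}{N_1 + N_2 + \cdots}\left(\frac{l_1}{N_1}N_1 + \frac{l_2}{N_2}N_2 + \cdots\right) = \frac{l_1 + l_2 + \cdots}{N_1 + N_2 + \cdots},$$

where $l_i$ is the summed loss of batch $i$ and $N_i$ is the number of samples in that batch. The discrepancy arises because compute_task_metrics must always return the loss per sample, $l_i/N_i$, for the gradient computation, so the epoch-level aggregation has to multiply each batch's loss back by $N_i$ before dividing by the total number of samples.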
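A minimal sketch of the two aggregations, assuming plain Python lists of per-sample losses instead of TensorFlow tensors. The function names epoch_loss_old and epoch_loss_fixed and the batch values are hypothetical illustrations, not part of tf2-gnn. It shows that re-weighting each batch's per-sample loss by $N_i$ recovers the mean loss over all samples in the epoch.

```python
from typing import Sequence


def epoch_loss_old(batches: Sequence[Sequence[float]]) -> float:
    """Previous aggregation: (l_1/N_1 + l_2/N_2 + ...) / (N_1 + N_2 + ...)."""
    total_loss = 0.0
    total_samples = 0
    for sample_losses in batches:
        n = len(sample_losses)
        mean_loss = sum(sample_losses) / n  # per-sample loss, as used for the gradient
        total_loss += mean_loss             # bug: not re-weighted by batch size
        total_samples += n
    return total_loss / total_samples


def epoch_loss_fixed(batches: Sequence[Sequence[float]]) -> float:
    """Intended aggregation: (l_1/N_1 * N_1 + l_2/N_2 * N_2 + ...) / (N_1 + N_2 + ...)."""
    total_loss = 0.0
    total_samples = 0
    for sample_losses in batches:
        n = len(sample_losses)
        mean_loss = sum(sample_losses) / n
        total_loss += mean_loss * n         # undo the per-sample normalisation
        total_samples += n
    return total_loss / total_samples


if __name__ == "__main__":
    # Hypothetical per-sample losses for two batches of unequal size (N_1 = 4, N_2 = 2).
    batches = [[1.0, 2.0, 2.0, 3.0], [1.0, 2.0]]
    flat_mean = sum(sum(b) for b in batches) / sum(len(b) for b in batches)
    print(epoch_loss_old(batches), epoch_loss_fixed(batches), flat_mean)
```

With these batches ($l_1 = 8$, $N_1 = 4$, $l_2 = 3$, $N_2 = 2$), the fixed version equals the flat per-sample mean $11/6$, while the old version returns $3.5/6$.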