jhossbach opened this issue 1 year ago
In this code, `value_and_grad` already returns both the loss and its gradient from a single call. We can reuse that loss value and remove the unnecessary second computation of the loss in `_compute_metrics`:
https://github.com/zincware/ZnRND/blob/36b921aae1580ee4ec64a36219db77e9f3ad27d9/znrnd/models/jax_model.py#L137-L146
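A minimal sketch of the idea, assuming a plain-JAX training step. The names `loss_fn` and `train_step`, the linear model, and the plain gradient-descent update are hypothetical illustrations and are not taken from ZnRND's `jax_model.py`. The point is that the loss returned by `jax.value_and_grad` can go straight into the metrics instead of calling the loss function a second time.

```python
import jax
import jax.numpy as jnp

# Hypothetical step size for the illustrative gradient-descent update.
LEARNING_RATE = 0.1


def loss_fn(params, inputs, targets):
    """Hypothetical mean-squared-error loss for a linear model."""
    predictions = inputs @ params["w"] + params["b"]
    return jnp.mean((predictions - targets) ** 2)


@jax.jit
def train_step(params, inputs, targets):
    # value_and_grad evaluates the loss and its gradient in one call,
    # so the loss does not have to be recomputed for the metrics.
    loss, grads = jax.value_and_grad(loss_fn)(params, inputs, targets)
    # Plain gradient descent on every leaf of the parameter pytree.
    new_params = jax.tree_util.tree_map(
        lambda p, g: p - LEARNING_RATE * g, params, grads
    )
    # Reuse the already-computed loss (it is the loss *before* the update).
    metrics = {"loss": loss}
    return new_params, metrics


if __name__ == "__main__":
    params = {"w": jnp.zeros((3,)), "b": jnp.array(0.0)}
    inputs = jnp.ones((4, 3))
    targets = jnp.ones((4,))
    new_params, metrics = train_step(params, inputs, targets)
    print(metrics["loss"])
```

Note that the reused value is the loss at the parameters *before* the update. If `_compute_metrics` is meant to report the loss after the update, reusing it changes what is logged, so that is worth checking before making the switch.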
@KonstiNik Did you resolve this at some stage during your restructuring?