danielward27 / flowjax

https://danielward27.github.io/flowjax/
MIT License

Loss function in train utils isn't "jitted" #35

Closed Tennessee-Wallaceh closed 1 year ago

Tennessee-Wallaceh commented 1 year ago

The loss function defined in train_utils (line 44) doesn't have jit applied to it.

This is then used in the computation of the validation loss, which could significantly slow training, particularly when the validation set is large.

A fix should be as simple as adding an eqx.filter_jit decorator to the loss.

danielward27 commented 1 year ago

Thanks! I missed that, and including it gives a good speed-up.