Tennessee-Wallaceh closed this issue 1 year ago.
The `loss` function defined in `train_utils` (line 44) doesn't have JIT compilation applied to it. It is then used to compute the validation loss, which could significantly slow down training, particularly when the validation set is large.
A fix should be as simple as adding an `eqx.filter_jit` decorator to `loss`.
Thanks! I missed that, and including it gives a good speed-up.