Open utterances-bot opened 2 weeks ago
Hi, to track the optimization process and avoid overfitting the training set, how would you also track the test loss and accuracy at each update step? I tried using a callback, but the callback (or debug) functionality available in JAX explicitly advises against compute-intensive tasks such as computing the test loss and accuracy. Right now I can only get the performance on the test set at the end of training. Is there a way to get the test losses in the middle of training, or a workaround for this?
Hi @supreethmv,
You can use jax.value_and_grad with your validation data and targets, e.g.:
loss_val_test, grads_test = jax.value_and_grad(loss_fn)(params, test_data, test_targets)
You can use it within the for loop if you're not jitting, or within update_step_jit if you are.
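For the jitted case, here is a minimal sketch of what that could look like. It is not the demo's code: the linear model inside loss_fn, the synthetic data, the learning rate and the plain gradient-descent update are all assumptions made so the example runs with JAX alone, and in the demo the Optax update would take the place of the tree_map line. Since only the loss value is needed for monitoring, the sketch calls loss_fn directly on the test set instead of jax.value_and_grad.

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for the demo's quantum model: a plain linear model.
def loss_fn(params, data, targets):
    preds = data @ params["w"] + params["b"]
    return jnp.mean((preds - targets) ** 2)


@jax.jit
def update_step_jit(params, train_data, train_targets, test_data, test_targets):
    # Training loss and gradients, as in a usual update step.
    train_loss, grads = jax.value_and_grad(loss_fn)(params, train_data, train_targets)
    # Plain gradient descent (assumed learning rate 0.1); an Optax update would go here.
    params = jax.tree_util.tree_map(lambda p, g: p - 0.1 * g, params, grads)
    # Test loss with the updated parameters; no gradients are needed for monitoring.
    test_loss = loss_fn(params, test_data, test_targets)
    return params, train_loss, test_loss


# Synthetic data, purely for illustration.
k_train, k_test = jax.random.split(jax.random.PRNGKey(0))
true_w = jnp.array([1.0, -2.0, 0.5])
train_data = jax.random.normal(k_train, (64, 3))
test_data = jax.random.normal(k_test, (16, 3))
train_targets = train_data @ true_w
test_targets = test_data @ true_w

params = {"w": jnp.zeros(3), "b": jnp.array(0.0)}
history = []  # (train_loss, test_loss) recorded at every update step
for step in range(200):
    params, train_loss, test_loss = update_step_jit(
        params, train_data, train_targets, test_data, test_targets
    )
    history.append((float(train_loss), float(test_loss)))
```

Because the test loss is returned from the jitted function as an ordinary output, no callback is needed to read it at each step. An accuracy metric could be returned the same way.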
If you have further code-related questions feel free to post them in the PennyLane Discussion Forum! We have a section for questions related to demos.
How to optimize a QML model using JAX and Optax | PennyLane Demos
Learn how to train a quantum machine learning model using PennyLane, JAX, and Optax.
https://pennylane.ai/qml/demos/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax/