PennyLaneAI / comments


qml/demos/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax/ #26

Open utterances-bot opened 2 weeks ago

utterances-bot commented 2 weeks ago

How to optimize a QML model using JAX and Optax | PennyLane Demos

Learn how to train a quantum machine learning model using PennyLane, JAX, and Optax.

https://pennylane.ai/qml/demos/tutorial_How_to_optimize_QML_model_using_JAX_and_Optax/

supreethmv commented 2 weeks ago

Hi, to track the optimization process and avoid overfitting the training set, how would you also track the test loss and accuracy at each update step? I tried a callback, but the callback (and debug) functionality in JAX explicitly advises against doing compute-intensive work inside it, such as computing the test loss and accuracy. So I can only get the performance on the test set at the end of training, with no way to get the test losses in the middle of training. Is there a workaround for this?

CatalinaAlbornoz commented 2 weeks ago

Hi @supreethmv, you can use jax.value_and_grad with your validation data and targets, e.g.:

loss_val_test, grads_test = jax.value_and_grad(loss_fn)(params, test_data, test_targets)

You can use it within the for loop if you're not jitting, or within update_step_jit if you are.
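Below is a minimal sketch of the jitted variant. It rests on several assumptions. A small linear classifier with a sigmoid output stands in for the demo's PennyLane QNode so the code runs with JAX alone, and plain gradient descent replaces the demo's Optax optimizer. The toy data, `LEARNING_RATE`, `accuracy_fn`, `make_data`, and `train` are all hypothetical. The test metrics are computed inside `update_step_jit` and returned alongside the training loss, so no callback is needed. Because only the value of the test loss is monitored, the sketch calls `loss_fn` directly instead of `jax.value_and_grad`. The gradient on the test data is never used for the update, so computing it only adds cost.

```python
import jax
import jax.numpy as jnp

# Assumed step size for the plain gradient-descent update below.
LEARNING_RATE = 0.5


def model(params, x):
    # Stand-in for the demo's QNode: a linear model with a sigmoid output.
    return jax.nn.sigmoid(x @ params["w"] + params["b"])


def loss_fn(params, data, targets):
    # Mean squared error between predictions and 0/1 targets.
    preds = model(params, data)
    return jnp.mean((preds - targets) ** 2)


def accuracy_fn(params, data, targets):
    # Fraction of samples whose thresholded prediction matches the target.
    preds = model(params, data)
    return jnp.mean((preds > 0.5) == (targets > 0.5))


@jax.jit
def update_step_jit(params, data, targets, test_data, test_targets):
    # Training loss (before the update) and gradients on the training set only.
    loss_val, grads = jax.value_and_grad(loss_fn)(params, data, targets)
    # Plain gradient descent in place of the demo's Optax optimizer.
    params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    # Test metrics with the updated params: forward passes only, no gradients.
    loss_val_test = loss_fn(params, test_data, test_targets)
    acc_test = accuracy_fn(params, test_data, test_targets)
    return params, loss_val, loss_val_test, acc_test


def make_data(key, n):
    # Hypothetical toy data: the label is 1 when the two features sum to > 0.
    x = jax.random.normal(key, (n, 2))
    y = (x.sum(axis=1) > 0).astype(jnp.float32)
    return x, y


def train(num_steps=100):
    k_train, k_test = jax.random.split(jax.random.PRNGKey(0))
    train_x, train_y = make_data(k_train, 80)
    test_x, test_y = make_data(k_test, 20)
    params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
    history = {"train_loss": [], "test_loss": [], "test_acc": []}
    for step in range(num_steps):
        params, loss_val, loss_val_test, acc_test = update_step_jit(
            params, train_x, train_y, test_x, test_y
        )
        # Store the metrics as Python floats, outside the jitted function.
        history["train_loss"].append(float(loss_val))
        history["test_loss"].append(float(loss_val_test))
        history["test_acc"].append(float(acc_test))
        if step % 20 == 0:
            print(
                f"step {step}: train loss {history['train_loss'][-1]:.4f}, "
                f"test loss {history['test_loss'][-1]:.4f}, "
                f"test acc {history['test_acc'][-1]:.2f}"
            )
    return params, history


if __name__ == "__main__":
    train()
```

Evaluating the test set at every step adds one forward pass per step. For a larger test set or a slower model, calling a separately jitted evaluation function every N steps from the Python loop is a common compromise.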

If you have further code-related questions, feel free to post them in the PennyLane Discussion Forum! We have a section for questions related to demos.