bonfab closed this issue 3 years ago.
Hi @bonfab! Thank you for reporting this bug. We'll get on it and try to fix it!
One approach to fix this could be to make sure that the `qml.math.linalg.norm` and `qml.math.allclose` functions both work with the JAX jit. Once this is the case, we can modify this `default.qubit` method to use these functions instead.
I realized it might not be as trivial to fix as I first thought. Even after adapting the source code to `jax.numpy`, one still receives a `ConcretizationTypeError`. The only workaround that works for me at the moment is to comment out the check completely.
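The `ConcretizationTypeError` can be reproduced in isolation. The following is a minimal sketch, assuming only JAX is installed; `check_norm` is a hypothetical stand-in for the normalization check in `_apply_state_vector`, not PennyLane's actual code. Switching to `jax.numpy` lets the norm itself be traced, but branching on the traced comparison with a Python `if` still requires a concrete value.

```python
import jax
import jax.numpy as jnp


@jax.jit
def check_norm(state):
    # Hypothetical stand-in for the normalization check; not PennyLane's code.
    # jnp.linalg.norm accepts a tracer, so this line traces without error.
    norm = jnp.linalg.norm(state)
    # jnp.allclose returns a traced boolean, but `not` / `if` need a concrete
    # Python bool, which raises a ConcretizationTypeError under jit.
    if not jnp.allclose(norm, 1.0):
        raise ValueError("State vector is not normalized.")
    return state


try:
    check_norm(jnp.array([1.0, 0.0]))
except jax.errors.ConcretizationTypeError as err:
    # Newer JAX versions raise TracerBoolConversionError, a subclass.
    print(type(err).__name__)
```

Because the check is an ordinary Python branch, it can only run when the state is concrete; inside a jitted function it has to be skipped or moved outside the compiled code, which is consistent with the workaround of commenting it out.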
Yes @bonfab, I was also getting an error. I think Josh's approach is a good way to go. If you have any other ideas on how to fix this bug, let us know here!
Hi @bonfab, with #1683 merged, this should be resolved in the `master` branch.
Expected behavior
The `_apply_state_vector` method does not seem to have been adapted to be compatible with JAX-compiled code when setting a state vector with `qml.QubitStateVector`.
Actual behavior
Specifically, the line in `_apply_state_vector` where the norm of the state vector is calculated with `np.linalg.norm` raises a `jax._src.errors.TracerArrayConversionError`.
A solution could be to use the `jax.numpy` version instead: `jnp.linalg.norm`.
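The difference between the two norm functions can be checked directly. This is a minimal sketch, assuming only JAX and NumPy are installed; `norm_numpy` and `norm_jax` are hypothetical names used for illustration. Under `jax.jit` the argument is an abstract tracer: `np.linalg.norm` tries to convert it into a NumPy array and fails, while `jnp.linalg.norm` operates on the tracer and compiles.

```python
import jax
import jax.numpy as jnp
import numpy as np


@jax.jit
def norm_numpy(state):
    # NumPy calls asarray on the tracer, which raises TracerArrayConversionError.
    return np.linalg.norm(state)


@jax.jit
def norm_jax(state):
    # jax.numpy works on tracers, so this function compiles and runs.
    return jnp.linalg.norm(state)


state = jnp.array([1.0, 1.0]) / jnp.sqrt(2.0)

try:
    norm_numpy(state)
except jax.errors.TracerArrayConversionError as err:
    print(type(err).__name__)

print(norm_jax(state))  # approximately 1.0 (float32 rounding may apply)
```

Replacing the norm alone only removes this first error; as noted in the discussion, the subsequent comparison against one still needs a concrete value under jit.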
Additional information
No response
Source code
Tracebacks
No response
System information