king-p3nguin closed this issue 4 months ago.
Could you try with simplify_sequence='R'? That will turn off all the dynamic-shape simplifications, which are inherently not compatible with jax.jit.
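Why a value-dependent ("dynamic shape") operation cannot be traced by jax.jit can be shown without quimb. In the sketch below, drop_small_entries is a hypothetical stand-in for a simplification whose output shape depends on the array's values; it is not quimb's code. Boolean-mask indexing works eagerly but raises NonConcreteBooleanIndexError under jit, the same error reported in this issue, while a fixed-shape jnp.where version traces fine.

```python
import jax
import jax.numpy as jnp


def drop_small_entries(x, atol=1e-12):
    # Hypothetical value-dependent "simplification": the output length
    # depends on how many entries of x exceed atol in magnitude.
    return x[jnp.abs(x) > atol]


def zero_small_entries(x, atol=1e-12):
    # Fixed-shape alternative: same shape as x, small entries set to 0.
    return jnp.where(jnp.abs(x) > atol, x, 0.0)


x = jnp.array([1.0, 0.0, 2.0, 0.0])

# Eagerly, x holds concrete values, so boolean-mask indexing works.
print(drop_small_entries(x))  # [1. 2.]

# Under jit, x is an abstract tracer of shape (4,). The result shape
# would depend on its values, so JAX refuses to trace it.
try:
    jax.jit(drop_small_entries)(x)
    jit_failed = False
except jax.errors.NonConcreteBooleanIndexError:
    jit_failed = True
print("boolean-mask indexing failed under jit:", jit_failed)

# The fixed-shape version traces and runs.
print(jax.jit(zero_small_entries)(x))  # [1. 0. 2. 0.]
```

The fixed-shape version gives up the size reduction that made the value-dependent one attractive; that trade-off is the general JAX constraint, not something specific to quimb.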
Changing circ.amplitude("1" * n) to circ.amplitude("1" * n, simplify_sequence='R') worked! Thank you.
What happened?
When I used qtn.Circuit.amplitude() in a loss function decorated with jax.jit, it threw NonConcreteBooleanIndexError.
What did you expect to happen?
qtn.Circuit.amplitude() should be compatible with jax.jit.
Minimal Complete Verifiable Example
Relevant log output
Anything else we need to know?
Changing
https://github.com/jcmgray/quimb/blob/6e522e6bd83f1e65bbee9ca256162c26b2833ae5/quimb/tensor/array_ops.py#L380-L382
to
works, but in this case it throws TracerBoolConversionError. The function find_antidiag_axes() is not compatible with jax.jit because its return type depends on the input values (NoneType or tuple), so more changes may be necessary.
Environment
OS: Windows WSL (Ubuntu 22.04.4 LTS)
Python: 3.11.7
jax: 0.4.25
quimb: https://github.com/jcmgray/quimb/tree/6e522e6bd83f1e65bbee9ca256162c26b2833ae5/
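The find_antidiag_axes() problem can be reproduced in miniature. The sketch below uses a hypothetical helper, first_nonzero_or_none, which, like find_antidiag_axes(), returns None or a value depending on the array's contents; it is not quimb's code. Under jax.jit the Python if statement receives a traced boolean and raises TracerBoolConversionError. One jit-compatible pattern, assumed here rather than taken from quimb, is to always return the same structure (a found flag plus an index) and have the caller select with jnp.where instead of testing for None.

```python
import jax
import jax.numpy as jnp


def first_nonzero_or_none(x):
    # Hypothetical stand-in with a value-dependent return type:
    # None if every entry is zero, otherwise a Python int index.
    mask = x != 0
    if not jnp.any(mask):  # a Python `if` needs a concrete bool
        return None
    return int(jnp.argmax(mask.astype(jnp.int32)))


def first_nonzero_fixed(x):
    # jit-friendly variant: always returns the same structure,
    # a scalar bool `found` and a scalar int `index`.
    mask = x != 0
    found = jnp.any(mask)
    # argmax returns the first maximal position, i.e. the first nonzero;
    # for an all-zero x it returns 0, so `found` is needed to tell apart.
    index = jnp.argmax(mask.astype(jnp.int32))
    return found, index


@jax.jit
def first_nonzero_or_sentinel(x):
    # Where the original caller would test `result is None`, a traced
    # caller selects on the `found` flag instead (here: -1 as sentinel).
    found, index = first_nonzero_fixed(x)
    return jnp.where(found, index, -1)


x = jnp.array([0.0, 3.0, 0.0])
z = jnp.zeros(3)

# Eager calls work because the arrays hold concrete values.
print(first_nonzero_or_none(x), first_nonzero_or_none(z))  # 1 None

# Under jit the Python `if` sees a tracer and cannot convert it to bool.
try:
    jax.jit(first_nonzero_or_none)(x)
    jit_failed = False
except jax.errors.TracerBoolConversionError:
    jit_failed = True
print("value-dependent return type failed under jit:", jit_failed)

print(first_nonzero_or_sentinel(x), first_nonzero_or_sentinel(z))  # 1 -1
```

With this pattern the set of outputs is fixed at trace time, so every caller that currently branches on None would also need rewriting, which is consistent with the note that more changes may be necessary.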