to24toro closed 1 year ago
Thanks for the PR @to24toro !
I think the code you've added to Dynamics looks good; however, there is a core problem right now that prevents this from being useful. Your test:

```python
def jit_func(amp, sigma):
    return get_samples(Gaussian(duration=5, amp=amp, sigma=sigma))

jit_samples = jax.jit(jit_func, static_argnums=(0, 1))(0.983, 2)
self.assertAllClose(jit_samples, self.gauss_get_waveform_samples, atol=1e-7, rtol=1e-7)
```

sets `static_argnums=(0, 1)`. For this to be useful, however, the test should work without `static_argnums`, i.e. you should just be able to do:

```python
jit_samples = jax.jit(jit_func)(0.983, 2)
```
If you try to do this now with the current `main` branch of terra, you get a JAX error coming from this line:

```
...
File /opt/anaconda3/envs/devEnv310/lib/python3.10/site-packages/qiskit/pulse/library/symbolic_pulses.py:436, in SymbolicPulse.__init__(self, pulse_type, duration, parameters, name, limit_amplitude, envelope, constraints, valid_amp_conditions)
    431 # TODO remove this.
    432 # This is due to convention in IBM Quantum backends where "amp" is treated as a
    433 # special parameter that must be defined in the form [real, imaginary].
    434 # this check must be removed because Qiskit pulse should be backend agnostic.
    435 if "amp" in parameters and not isinstance(parameters["amp"], ParameterExpression):
--> 436     parameters["amp"] = complex(parameters["amp"])
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `complex` function. If trying to convert the data type of a value, try using `x.astype(complex)` or `jnp.array(x, complex)` instead.
The error occurred while tracing the function jit_func at /var/folders/v5/5t1xdchn2ws5l1nj9h02vm0m0000gn/T/ipykernel_3597/1345434633.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'amp'.
```
This is because, in this case, `parameters["amp"]` is a JAX tracer, and calling `complex(x)` on a JAX tracer `x` raises this `ConcretizationTypeError`.
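A minimal reproduction of the failure, independent of terra and assuming only `jax`: the builtin `complex()` needs a concrete Python number, while `jnp.asarray(x, dtype=complex)` (one of the alternatives the error message suggests) stays traceable under `jit`. The function names below are hypothetical.

```python
import jax
import jax.numpy as jnp


def to_complex_builtin(amp):
    # Builtin complex() asks the tracer for a concrete value, which fails under jit
    return complex(amp)


def to_complex_jax(amp):
    # jnp.asarray keeps the value abstract, so the cast is traced like any other op
    return jnp.asarray(amp, dtype=complex)


try:
    jax.jit(to_complex_builtin)(0.983)
    builtin_failed = False
except jax.errors.ConcretizationTypeError:
    builtin_failed = True

print(builtin_failed)  # True
converted = jax.jit(to_complex_jax)(0.983)
```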
With this error being raised, there is unfortunately no benefit to having `get_samples` execute using JAX: the only benefit of JAX is being able to `jit`/`grad`/etc. If we can't truly `jit` a computation, then there is no reason to use JAX over NumPy. Since this error is being raised by terra, it will be necessary to make changes to terra before this can work.
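To illustrate why keeping `static_argnums` defeats the purpose, here is a small pure-JAX sketch (the functions are hypothetical stand-ins, not Dynamics code). A static argument is baked into the compiled function as a constant, so every new value triggers a fresh trace and compilation, and `jax.grad` cannot differentiate with respect to it. The Python-side counter below runs only while JAX is tracing, so it counts compilations.

```python
import jax
import jax.numpy as jnp

trace_count = {"static": 0, "traced": 0}


def scaled_ramp_static(amp, sigma):
    trace_count["static"] += 1  # Python side effect: executes only during tracing
    return amp * jnp.arange(5.0) / sigma


def scaled_ramp_traced(amp, sigma):
    trace_count["traced"] += 1
    return amp * jnp.arange(5.0) / sigma


f_static = jax.jit(scaled_ramp_static, static_argnums=(0, 1))
f_traced = jax.jit(scaled_ramp_traced)

for amp in (0.1, 0.2, 0.3):
    f_static(amp, 2.0)  # new static value -> retrace and recompile
    f_traced(amp, 2.0)  # same abstract shape/dtype -> reuse compiled function

print(trace_count)  # {'static': 3, 'traced': 1}
```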
To move forward with this PR, I think it makes sense to figure out whatever changes are necessary in terra to make the above test case work with the `static_argnums` keyword argument removed, make those changes, and then merge this PR into Dynamics once they are released in terra. @nkanazawa1989 any comments on this thought process?
The special casing of `amp` must be removed from terra (`duration` and `amp` are the only non-float parameters, and unfortunately Python is not statically typed). Currently https://github.com/Qiskit/qiskit-terra/pull/9002 is trying to introduce a float-valued `(amp, angle)` representation to symbolic pulses, which makes all parameters float-type except for `duration`. This approach allows us to remove the builtin type casting, and we can eventually remove `static_argnums`.
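A minimal sketch of why a purely real-valued parameterization helps, assuming the complex amplitude is split into a real magnitude `amp` and a real phase `angle`. The envelope here is a simple unlifted Gaussian written for illustration; `gaussian_samples` is hypothetical and is not terra's pulse formula. The complex amplitude is assembled with traceable `jnp` operations, so no builtin cast is needed, and only `duration` must stay static because it fixes the number of samples (an array shape).

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def gaussian_samples(duration, amp, angle, sigma):
    # duration determines the output shape, so it is the only static argument
    t = jnp.arange(duration) + 0.5  # sample at the midpoint of each time step
    envelope = jnp.exp(-((t - duration / 2) ** 2) / (2 * sigma**2))
    # amp and angle are real floats; the complex amplitude is built with jnp ops,
    # which remain abstract under jit (no complex() cast on a tracer)
    return amp * jnp.exp(1j * angle) * envelope


samples = gaussian_samples(5, 0.983, 0.0, 2.0)
```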
Feel free to put this on hold until the above PR is merged, or merge this PR without the JIT test and add the test later in a separate (follow-up) PR.
Okay cool thanks.
I think it makes sense to hold off on this PR, as I don't think the required functionality can truly be verified until terra is at a point where `static_argnums` can be removed.
Fair enough. Let's merge https://github.com/Qiskit/qiskit-terra/pull/9002 first.
Hey @to24toro, there are two more tests I think it'd be good to add:

1. A variant of the `jit` test: change it so that at the end, instead of `jit`, you do `jax.jit(jax.grad(jit_func))(0.1)`. You'll have to change the function so that it returns a real scalar (as `grad` only works on real scalars), so you could just change what the function returns to `get_samples(instance).sum().real`. Testing this combination of gradding/jitting is important because it can introduce a bit more complexity into what JAX is doing, and sometimes raises errors where `jit`/`grad` don't raise them on their own.
2. A test of a `Solver` that compiles/differentiates a simulation of a pulse schedule with symbolic pulses. In the `test/dynamics/solvers/test_solver_classes.py` file there is a test class called `TestPulseSimulationJAX`. Can you add a test to this class that defines a function building a schedule with symbolic pulses, simulating it with a `Solver`, and returning the final state, and attempts to compile/differentiate it? This may be a little redundant with the other tests, but it will be a good integration test.

One last point: the `TestJaxBase` class has two helper functions, `jit_wrap` and `jit_grad_wrap`, which can be used to `jit` a function `func`, or to `jit(grad(func))` it, where the latter takes the real part of the summed output. These may be helpful/convenient.
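A sketch of the `jit(grad(...))` pattern from the first suggested test, assuming a pure-JAX stand-in for `get_samples`. `jit_grad_wrap_sketch` is a hypothetical re-creation of what a helper like `jit_grad_wrap` might do based on the description above, not the actual `TestJaxBase` implementation, and `samples_from_amp` is a hypothetical substitute for `get_samples(Gaussian(...))`.

```python
import jax
import jax.numpy as jnp


def jit_grad_wrap_sketch(func):
    # jax.grad needs a real scalar output, so reduce the (complex) samples
    # to the real part of their sum before differentiating, then jit the result
    def real_sum(*args):
        return func(*args).sum().real

    return jax.jit(jax.grad(real_sum))


def samples_from_amp(amp):
    # Hypothetical stand-in for get_samples(Gaussian(duration=5, amp=amp, sigma=2))
    t = jnp.arange(5) + 0.5
    envelope = jnp.exp(-((t - 2.5) ** 2) / (2 * 2.0**2))
    return (amp * envelope).astype(jnp.complex64)


d_amp = jit_grad_wrap_sketch(samples_from_amp)(0.1)
```

Since the samples are linear in `amp`, the gradient should equal the sum of the envelope values, which gives an easy analytic check for such a test.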
I modified two points at https://github.com/Qiskit/qiskit-dynamics/pull/149/commits/769e342a06f95d8232bd6475da3e7b46c80717d5 for JAX-jitting.
I've removed the "on hold" label, as you pointed out that terra has been sufficiently updated. Once the errors are resolved and I re-review we can merge this!
My author name and commit name were not correct, so I'm sorry, but I had to modify them and force-push to pass the license/cla check.
Summary
This PR adds `get_samples` to `InstructionToSignals` for JAX-jitting when using qiskit-pulse, and removes the usage of the `get_waveform` method of `SymbolicPulse`.
Details and comments
The `get_samples` function gets the envelope expression from `SymbolicPulse` and calls `sympy.lambdify` with the numerical backend specified by the `Array` class. The lambdified function is `lru_cache`d for performance.
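The lambdify-and-cache approach described above can be sketched as follows, assuming only `sympy` and `numpy`. The envelope (an unlifted Gaussian), `lambdify_envelope`, and `get_samples_sketch` are hypothetical illustrations. The PR's `get_samples` instead reads the expression from a `SymbolicPulse` and takes the backend from the `Array` class, which is not reproduced here. Because SymPy expressions are hashable, they can serve directly as `lru_cache` keys, so the comparatively slow `lambdify` call runs once per (expression, backend) pair.

```python
from functools import lru_cache

import numpy as np
import sympy as sym

# Hypothetical symbols and envelope: an unlifted Gaussian, for illustration only
t, amp, sigma, duration = sym.symbols("t amp sigma duration")
GAUSSIAN_ENVELOPE = amp * sym.exp(-((t - duration / 2) ** 2) / (2 * sigma**2))


@lru_cache(maxsize=None)
def lambdify_envelope(expr, backend="numpy"):
    # Compile the symbolic expression into a numerical function once per
    # (expression, backend) pair; later calls return the cached function
    return sym.lambdify((t, amp, sigma, duration), expr, modules=backend)


def get_samples_sketch(duration_val, amp_val, sigma_val, backend="numpy"):
    func = lambdify_envelope(GAUSSIAN_ENVELOPE, backend)
    # Sample at the midpoint of each time step
    times = np.arange(duration_val) + 0.5
    return func(times, amp_val, sigma_val, duration_val)


samples = get_samples_sketch(5, 0.983, 2.0)
```

A second call with different parameter values reuses the cached function rather than re-running `lambdify`.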