qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0

Add get samples function to InstructionToSignals for JAX-jit usage #149

Closed: to24toro closed this 1 year ago

to24toro commented 2 years ago

Summary

This PR adds get_samples to InstructionToSignals for JAX-jitting when using qiskit-pulse, and removes the use of the get_waveform method of SymbolicPulse.

Details and comments

The get_samples function gets the envelope expression from SymbolicPulse and calls sympy.lambdify with the numerical backend specified by the Array class. The lambdified function is LRU-cached for performance.

CLAassistant commented 2 years ago

CLA assistant check
All committers have signed the CLA.

DanPuzzuoli commented 2 years ago

Thanks for the PR @to24toro !

I think the code you've added to Dynamics looks good; however, there is a core problem right now that prevents this from being useful. Your test:

```python
def jit_func(amp, sigma):
    return get_samples(Gaussian(duration=5, amp=amp, sigma=sigma))

jit_samples = jax.jit(jit_func, static_argnums=(0, 1))(0.983, 2)
self.assertAllClose(jit_samples, self.gauss_get_waveform_samples, atol=1e-7, rtol=1e-7)
```

sets static_argnums=(0, 1). However, for this to be useful, the test should work without using static_argnums, i.e. you should be able to just do:

```python
jit_samples = jax.jit(jit_func)(0.983, 2)
```

If you try to do this now with the current main branch of terra, you get a JAX error, coming from the line:

```
...
File /opt/anaconda3/envs/devEnv310/lib/python3.10/site-packages/qiskit/pulse/library/symbolic_pulses.py:436, in SymbolicPulse.__init__(self, pulse_type, duration, parameters, name, limit_amplitude, envelope, constraints, valid_amp_conditions)
    431 # TODO remove this.
    432 #  This is due to convention in IBM Quantum backends where "amp" is treated as a
    433 #  special parameter that must be defined in the form [real, imaginary].
    434 #  this check must be removed because Qiskit pulse should be backend agnostic.
    435 if "amp" in parameters and not isinstance(parameters["amp"], ParameterExpression):
--> 436     parameters["amp"] = complex(parameters["amp"])
...
ConcretizationTypeError: Abstract tracer value encountered where concrete value is expected: Traced<ShapedArray(float64[], weak_type=True)>with<DynamicJaxprTrace(level=0/1)>
The problem arose with the `complex` function. If trying to convert the data type of a value, try using `x.astype(complex)` or `jnp.array(x, complex)` instead.
The error occurred while tracing the function jit_func at /var/folders/v5/5t1xdchn2ws5l1nj9h02vm0m0000gn/T/ipykernel_3597/1345434633.py:1 for jit. This concrete value was not available in Python because it depends on the value of the argument 'amp'.
```

This happens because, in this case, parameters["amp"] is a JAX tracer, and calling complex(x) on a JAX tracer x raises this ConcretizationTypeError.
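The failure can be reproduced without Qiskit. The sketch below (function names are illustrative) contrasts Python's builtin `complex()`, which needs a concrete number, with the `jnp.array(x, complex)`-style cast that the error message itself recommends.

```python
import jax
import jax.numpy as jnp


def builtin_cast(amp):
    # Python's complex() asks the tracer for a concrete value, which fails under jit.
    return complex(amp)


def traceable_cast(amp):
    # Casting with a jnp operation keeps the value inside the JAX trace.
    return jnp.asarray(amp, dtype=complex)


def try_jit(func, value):
    """Run jax.jit(func)(value), returning the ConcretizationTypeError if raised."""
    try:
        return jax.jit(func)(value)
    except jax.errors.ConcretizationTypeError as err:
        return err


failed = try_jit(builtin_cast, 0.983)
worked = try_jit(traceable_cast, 0.983)
```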

With this error being raised, there is unfortunately no benefit to having get_samples execute using JAX: the only benefit of JAX is being able to jit/grad/etc. If we can't truly jit a computation, then there is no reason to use JAX over NumPy. Since this error is raised by terra, changes to terra will be necessary before this can work.
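To make the goal concrete, here is a pure-JAX sketch of what a fully traceable sample function permits. The function `gaussian_samples` is a hypothetical, unlifted Gaussian and not Qiskit's `Gaussian`. Because `amp` and `sigma` are traced rather than static, one compiled function serves every parameter value, and gradients with respect to the pulse parameters are available; with `static_argnums=(0, 1)`, JAX would recompile for each new value and could not differentiate with respect to those arguments.

```python
import jax
import jax.numpy as jnp


def gaussian_samples(amp, sigma, duration=5):
    # duration fixes the output shape, so it stays a plain Python int.
    t = jnp.arange(duration)
    center = duration / 2
    return amp * jnp.exp(-((t - center) ** 2) / (2 * sigma**2))


# No static_argnums: amp and sigma are traced, so this compiles once.
jit_samples = jax.jit(gaussian_samples)


def total_signal(amp, sigma):
    # grad needs a scalar output, so reduce the samples with a sum.
    return jnp.sum(gaussian_samples(amp, sigma))


# jit(grad(...)) with respect to both pulse parameters.
grad_fn = jax.jit(jax.grad(total_signal, argnums=(0, 1)))

samples = jit_samples(0.983, 2.0)
d_amp, d_sigma = grad_fn(0.983, 2.0)
```

Note that `jax.grad` requires floating-point inputs, so `sigma` is passed as `2.0` rather than `2`.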

To move forward with this PR, I think it makes sense to figure out whatever changes are necessary in terra to make the above test case work with the static_argnums keyword argument removed, make those changes, and then merge this PR into Dynamics once they are released in terra. @nkanazawa1989 any comments on this thought process?

nkanazawa1989 commented 2 years ago

The special casing of amp must be removed from terra (duration and amp are the only non-float parameters, and unfortunately Python is not statically typed). https://github.com/Qiskit/qiskit-terra/pull/9002 is currently trying to introduce an (amp, angle) representation to symbolic pulses, which makes all parameters float-type except for duration. This approach allows us to remove the builtin typecasting, and we can eventually remove static_argnums.
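A sketch of why an all-float representation avoids the cast, assuming the complex amplitude is encoded as a real magnitude `amp` and a real phase `angle` with value `amp * exp(1j * angle)` (the function name below is hypothetical). The complex number is then built with `jnp` operations inside the trace, so no builtin `complex()` call is needed and the result is differentiable in both parameters.

```python
import jax
import jax.numpy as jnp


@jax.jit
def complex_amplitude(amp, angle):
    # Both inputs are real floats; the complex value is formed by jnp ops,
    # so it traces cleanly under jit (no Python complex() conversion).
    return amp * jnp.exp(1j * angle)


# The real part amp * cos(angle) can be differentiated with respect to angle.
d_real_d_angle = jax.grad(
    lambda a, th: jnp.real(complex_amplitude(a, th)), argnums=1
)
```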

Feel free to put this on hold until the above PR is merged, or you can merge this PR without the JIT test and add the test later in a separate (follow-up) PR.

DanPuzzuoli commented 2 years ago

Okay cool thanks.

> Feel free to put this on hold until the above PR is merged, or you can merge this PR without the JIT test and add the test later in a separate (follow-up) PR.

I think it makes sense to hold off on this PR, as I don't think the required functionality can truly be verified until terra is at a point where static_argnums can be removed.

nkanazawa1989 commented 2 years ago

Fair enough. Let's merge https://github.com/Qiskit/qiskit-terra/pull/9002 first.

DanPuzzuoli commented 1 year ago

Hey @to24toro, there are two more tests I think it'd be good to add:

One last point: the TestJaxBase class has two helper functions, jit_wrap and jit_grad_wrap, which can be used to jit a function func or to jit(grad(func)) it, where the latter takes the sum and the real part of the output (since grad requires a real scalar output). These may be helpful/convenient.
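For readers without the test suite at hand, the following is a hypothetical re-implementation of the two helpers as standalone functions. It follows only the behaviour described above and is not the actual TestJaxBase code, which may also convert outputs through Dynamics' Array class.

```python
import jax
import jax.numpy as jnp


def jit_wrap(func):
    """Hypothetical sketch: jit func without changing its output."""
    return jax.jit(func)


def jit_grad_wrap(func):
    """Hypothetical sketch: jit(grad(...)) of the real part of sum(func(...)).

    grad needs a real scalar output, so a possibly complex, array-valued
    output is reduced with sum and real before differentiating.
    """
    return jax.jit(jax.grad(lambda *args: jnp.real(jnp.sum(func(*args)))))


def phased_samples(amp):
    # Example complex, array-valued function of a real parameter.
    return amp * jnp.exp(1j * jnp.arange(3.0))


values = jit_wrap(phased_samples)(2.0)
gradient = jit_grad_wrap(phased_samples)(1.0)  # equals sum(cos(t)) for t = 0, 1, 2
```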

to24toro commented 1 year ago

I modified two points at https://github.com/Qiskit/qiskit-dynamics/pull/149/commits/769e342a06f95d8232bd6475da3e7b46c80717d5 for JAX-jitting.

DanPuzzuoli commented 1 year ago

I've removed the "on hold" label, as you pointed out that terra has been sufficiently updated. Once the errors are resolved and I re-review, we can merge this!

to24toro commented 1 year ago

My author and committer names were not correct, so I'm sorry, but I had to modify them and force-push to pass the license/cla check.