to24toro closed this 1 year ago.
The current implementation fails the tutorial in optimizing_pulse_sequence.rst. The main reason is that numpy.array and jax.numpy.array sometimes get mixed in the code. This would not be a problem if we didn't use jax.jit. However, the tutorial calls jit(value_and_grad(objective)), and that is where the failure occurs.
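The failure mode can be reproduced in isolation. The sketch below uses two toy objective functions (not the tutorial's actual objective) to show that a NumPy conversion of a traced value fails inside jax.jit(jax.value_and_grad(...)), while the same computation written with jax.numpy traces fine. It assumes a recent JAX release that exposes jax.errors.TracerArrayConversionError.

```python
import jax
import jax.numpy as jnp
import numpy as np


def objective_numpy(x):
    # np.asarray needs a concrete value; under jit, x is an abstract tracer,
    # so this raises jax.errors.TracerArrayConversionError.
    samples = np.asarray(x) * np.array([1.0, 2.0])
    return jnp.sum(samples**2)


def objective_jax(x):
    # The same computation written with jnp stays traceable end to end.
    samples = jnp.asarray(x) * jnp.array([1.0, 2.0])
    return jnp.sum(samples**2)


# The objective is 5 * x**2, so at x = 1.0 the value is 5 and the gradient is 10.
value, grad = jax.jit(jax.value_and_grad(objective_jax))(1.0)

try:
    jax.jit(jax.value_and_grad(objective_numpy))(1.0)
    numpy_version_failed = False
except jax.errors.TracerArrayConversionError:
    numpy_version_failed = True
```

Swapping np for jnp inside the traced function is enough here; the hard part in the real code is that the NumPy calls are hidden behind dispatch.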
The existing tutorial code uses dynamics.Array, but even when the dynamics.Array backend is jax, unp.asarray() converts it to numpy.array. I think this happens at https://github.com/Qiskit-Extensions/qiskit-dynamics/blob/3a373cfb834cd2f2e15245ab924aa3568b8db326/qiskit_dynamics/array/array.py#L188 .
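The conversion to numpy.array can be seen with any object that implements NumPy's __array__ protocol. The sketch below uses a hypothetical WrappedArray class (not qiskit_dynamics.array.Array, whose internals are only assumed to follow the same pattern) to show that np.asarray always produces a plain numpy.ndarray, whatever backend tag the wrapper carries.

```python
import numpy as np


class WrappedArray:
    """Hypothetical stand-in for an Array-like wrapper that records a backend."""

    def __init__(self, data, backend="jax"):
        self._data = data
        self.backend = backend

    def __array__(self, dtype=None, copy=None):
        # NumPy calls this hook from np.asarray/np.array and expects a real
        # numpy.ndarray back, so the backend tag cannot survive the conversion.
        return np.asarray(self._data, dtype=dtype)


wrapped = WrappedArray([1.0, 2.0], backend="jax")
converted = np.asarray(wrapped)  # a plain numpy.ndarray; the "jax" tag is gone
```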
If we are to continue supporting dynamics.Array once the arraylias integration is complete, we will need to change or improve dynamics.Array.
Suppose we can change the samples used in DiscreteSignal to jax.numpy.array through unp.asarray. In that case, an error will occur if there is a numpy.array elsewhere. Moreover, take Convolution as an example, in the following code:
    if isinstance(signal, DiscreteSignal):
        # Perform a discrete time convolution.
        dt = signal.dt
        func_samples = np.asarray([self._func(dt * i) for i in range(signal.duration)])
        func_samples = func_samples / sum(func_samples)
        sig_samples = signal(dt * np.arange(signal.duration))
Even if np is changed to unp, unp.arange(signal.duration) returns a numpy.array because signal.duration is assumed to be an int. The point is that unp.arange() maps inputs of types like int or float to numpy. In other words, numpy takes precedence over jax, which leads to the issue.
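A minimal sketch of the arange point, plus one way to build the normalized kernel samples so they stay traceable. The Gaussian kernel, the scale parameter and the fixed dt and duration are illustrative assumptions standing in for self._func, signal.dt and signal.duration; the actual Convolution code may do this differently.

```python
import jax
import jax.numpy as jnp
import numpy as np

dt = 0.5
duration = 4  # a plain Python int, like signal.duration

# Given only an int, np.arange returns a NumPy array: nothing in the
# argument says "use JAX", so a NumPy-first dispatch picks NumPy.
numpy_times = np.arange(duration)


@jax.jit
def normalized_kernel_samples(scale):
    # duration stays a static Python int; only `scale` is traced.
    times = dt * jnp.arange(duration)
    func_samples = jnp.exp(-((times / scale) ** 2))
    # jnp.sum instead of the builtin sum keeps the reduction in JAX.
    return func_samples / jnp.sum(func_samples)


kernel = normalized_kernel_samples(1.0)
# Differentiating through the jitted kernel works because nothing leaves JAX.
kernel_grad = jax.grad(lambda s: normalized_kernel_samples(s)[0])(1.0)
```

Calling jnp.arange explicitly sidesteps the question of which library an int "belongs" to, at the cost of committing that code path to JAX.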
I'm trying to figure out how to get the tests to pass in other submodules. The issue behind many of the failures seems to be that unp.asarray(Array(x, backend='jax')) ends up calling numpy.array(x).
This issue comes up in tests like test.dynamics.models.test_generator_model.TestGeneratorModelSparseJax.test_jit_grad, in which a value a is being traced and the function contains the code:

    Signal(Array(a))

This test (and, I'm guessing, many others) can be fixed by simply changing the above to:

    Signal(a)
This is how we'll want to do things going forward anyway, but it doesn't change the fact that Signal(Array(x, backend='jax')) is broken when x is a tracer.
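The Signal(a) pattern can be illustrated with a hypothetical ToySignal class (not the real qiskit_dynamics Signal) whose constructor coerces its envelope with jnp.asarray. Because no NumPy conversion happens, a traced parameter passes straight through jax.jit(jax.grad(...)), as in the test_jit_grad case.

```python
import jax
import jax.numpy as jnp


class ToySignal:
    """Hypothetical constant-envelope signal: Re[envelope * exp(2*pi*i*f*t)]."""

    def __init__(self, envelope, carrier_freq=0.0):
        # jnp.asarray accepts tracers, so the traced value is kept as-is.
        self.envelope = jnp.asarray(envelope)
        self.carrier_freq = carrier_freq

    def __call__(self, t):
        phase = jnp.exp(2j * jnp.pi * self.carrier_freq * t)
        return jnp.real(self.envelope * phase)


def loss(a):
    # At t = 0 the carrier phase is 1, so the loss is simply a**2.
    return ToySignal(a, carrier_freq=1.0)(0.0) ** 2


jit_grad = jax.jit(jax.grad(loss))
```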
Fixing this by modifying Array appears to be nontrivial. I haven't been able to find a quick fix, as numpy gets upset when the __array__ method that np.array relies on returns something that is not a NumPy array. As such, it seems that asarray may not be able to work the way we'd want for Array with backend='jax'.
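The constraint NumPy enforces here is part of the __array__ protocol: if the hook returns anything other than a numpy.ndarray, NumPy raises a ValueError. The hypothetical class below returns a plain list to show this with NumPy alone; a JAX array or tracer returned from the hook is not an ndarray either, so it is rejected the same way.

```python
import numpy as np


class NotAnArrayHook:
    """Hypothetical wrapper whose __array__ hook breaks the protocol."""

    def __init__(self, data):
        self._data = data

    def __array__(self, dtype=None, copy=None):
        # Returning anything other than numpy.ndarray (here a list) is rejected.
        return list(self._data)


try:
    np.asarray(NotAnArrayHook([1.0, 2.0]))
    rejected = False
except ValueError:
    rejected = True
```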
I've figured out how to bypass the above issue and directly fix most tests here. The only thing I haven't fixed yet is the pulse -> signal conversion and its JAX compatibility, but I think it may be the same issue.
I'll try again to see if we can "properly" fix this problem (i.e. get unp.asarray to work with Array with backend="jax"). It would be nice to keep backwards compatibility with code that uses Array, but I'm also okay with it breaking if we can't find a clean solution.
The tests that still fail fall into two types:
1. A list is converted to numpy.array. The solution is to implement something like Array.default_backend() and, when it returns "jax", pass jnp.array([]) to samples.
2. np.append(ham_sig_vals, dis_sig_vals, axis=-1) in "qiskit_dynamics/models/operator_collections.py". Should this be solved when working on operator_collections?
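Both remaining failure types can be sketched in a few lines. DEFAULT_BACKEND and coerce_samples below are hypothetical, standing in for the proposed Array.default_backend() check; combined_signal_values shows that jnp.append accepts traced inputs under jax.jit, whereas np.append would try to convert them to NumPy arrays and raise.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical global standing in for something like Array.default_backend().
DEFAULT_BACKEND = "jax"


def coerce_samples(samples, backend=DEFAULT_BACKEND):
    # A plain Python list carries no backend information, so fall back to
    # the configured default instead of always choosing NumPy.
    if backend == "jax":
        return jnp.asarray(samples)
    return np.asarray(samples)


@jax.jit
def combined_signal_values(ham_sig_vals, dis_sig_vals):
    # jnp.append keeps traced inputs traceable; np.append would call
    # np.asarray on the tracers and fail under jit.
    return jnp.append(ham_sig_vals, dis_sig_vals, axis=-1)


out = combined_signal_values(jnp.ones((2, 3)), jnp.zeros((2, 1)))
```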
Summary
I started integrating arraylias into dynamics. This PR covers the integration for the Signal class.