qiskit-community / qiskit-dynamics

Tools for building and solving models of quantum systems in Qiskit
https://qiskit-community.github.io/qiskit-dynamics/
Apache License 2.0
103 stars, 61 forks

Arraylias integration - Signal class - #269

Closed: to24toro closed this 10 months ago

to24toro commented 11 months ago

Summary

I have started integrating Arraylias into Qiskit Dynamics. This PR covers the integration of the Signal class.

Details and comments

to24toro commented 11 months ago

The current implementation breaks the tutorial in optimizing_pulse_sequence.rst. The main reason is that numpy.array and jax.numpy.array objects sometimes get mixed in the code. This would not be a problem if we didn't use jax.jit; however, the failure occurs within jit(value_and_grad(objective)) in this tutorial.

if isinstance(signal, DiscreteSignal):
  # Perform a discrete time convolution.
  dt = signal.dt
  func_samples = np.asarray([self._func(dt * i) for i in range(signal.duration)])
  func_samples = func_samples / sum(func_samples)
  # signal.duration is a Python int, so np.arange returns a NumPy array
  sig_samples = signal(dt * np.arange(signal.duration))

Even if np is replaced with unp, unp.arange(signal.duration) returns a numpy.array, because signal.duration is assumed to be an int. The point is that unp.arange() dispatches Python types like int and float to numpy. In other words, numpy takes precedence over jax here, which leads to the issue.
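To make the precedence issue concrete, here is a toy type-based dispatcher. It is a hypothetical sketch, not Arraylias's implementation or API (the names library_for and toy_arange are made up): it picks an array library from the type of the argument and, matching the behaviour reported above, sends Python int and float to NumPy.

```python
import numpy as np
import jax
import jax.numpy as jnp

# Hypothetical registry mapping types to an array library.
# Python scalars map to NumPy, mirroring the precedence described above.
_REGISTRY = [
    (jax.Array, jnp),  # JAX arrays, including tracers inside jit
    (np.ndarray, np),
    (int, np),
    (float, np),
]


def library_for(value):
    """Return the module this toy dispatcher would use for ``value``."""
    for typ, lib in _REGISTRY:
        if isinstance(value, typ):
            return lib
    raise TypeError(f"no array library registered for {type(value).__name__}")


def toy_arange(n):
    """Toy stand-in for unp.arange: dispatch on the type of ``n`` only."""
    return library_for(n).arange(n)
```

Calling toy_arange(4) inside a jax.jit-compiled function still returns a NumPy array, because the dispatcher only sees the Python int 4; this is the same situation as unp.arange(signal.duration).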

DanPuzzuoli commented 11 months ago

I'm trying to figure out how to get the tests in other submodules to pass. Many of the failures seem to stem from the fact that:

unp.asarray(Array(x, backend='jax'))

ends up calling numpy.array(x).

This issue comes up in tests like test.dynamics.models.test_generator_model.TestGeneratorModelSparseJax.test_jit_grad, in which the variable a is being traced and the function under test contains the code:

Signal(Array(a))

This test (and I'm guessing many others) can be fixed by simply changing the above to:

Signal(a)

This is how we'll want to do things going forward anyway, but this doesn't help with the fact that Signal(Array(x, backend='jax')) is broken when x is a tracer.
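The tracer problem can be reproduced without Qiskit Dynamics. In this minimal sketch, the hypothetical wrap_with_numpy stands in for any path that converts its input with np.asarray (as unp.asarray(Array(x, backend='jax')) ends up doing), while pass_through keeps the value as a JAX type, analogous to Signal(a). The function names and the toy objective are assumptions made for illustration.

```python
from functools import partial

import numpy as np
import jax
import jax.numpy as jnp

# Concrete sample times built from a Python int: [0.0, 0.5, 1.0, 1.5]
times = 0.5 * np.arange(4)


def wrap_with_numpy(a):
    # Hypothetical NumPy conversion path; fails when ``a`` is a JAX tracer.
    return np.asarray(a)


def pass_through(a):
    # Hypothetical direct path, analogous to Signal(a); keeps ``a`` traceable.
    return jnp.asarray(a)


def objective(a, convert):
    # Toy objective: a constant envelope ``a`` sampled at ``times``.
    envelope = convert(a)
    return jnp.sum(envelope * times)


# Works: the traced value stays a JAX type the whole way through.
value, grad = jax.jit(jax.value_and_grad(partial(objective, convert=pass_through)))(2.0)

# Fails under tracing: np.asarray cannot turn a tracer into a NumPy array.
failed = False
try:
    jax.jit(jax.value_and_grad(partial(objective, convert=wrap_with_numpy)))(2.0)
except jax.errors.TracerArrayConversionError:
    failed = True
```

Without jit or grad, objective(2.0, wrap_with_numpy) runs fine, which matches the observation that the mix of numpy and jax only breaks once the value is traced.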

Fixing this by modifying Array appears to be nontrivial. I haven't been able to find a quick fix, because NumPy raises an error when the conversion performed by np.array yields something that is not a NumPy array. As such, it seems that unp.asarray may not be able to work the way we'd want for Array with backend='jax'.
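The NumPy constraint can be seen with a small stand-in class. In this sketch, the hypothetical LooseWrapper has an __array__ method that returns something other than a numpy.ndarray (a plain list here; a JAX-backed Array would presumably hand back a JAX value), and NumPy rejects it with an error instead of passing it through. This only illustrates the constraint; it is not how Qiskit Dynamics' Array is implemented.

```python
import numpy as np


class LooseWrapper:
    """Toy wrapper whose __array__ hands back a non-ndarray object."""

    def __init__(self, data):
        self.data = data

    def __array__(self, dtype=None, copy=None):
        # NumPy requires an ndarray here; returning a list is rejected.
        return self.data


class StrictWrapper(LooseWrapper):
    """Same wrapper, but __array__ converts to an ndarray before returning."""

    def __array__(self, dtype=None, copy=None):
        return np.asarray(self.data, dtype=dtype)


def try_convert(obj):
    """Return np.asarray(obj), or None if NumPy rejects the __array__ result."""
    try:
        return np.asarray(obj)
    except (ValueError, TypeError):
        return None
```

Since __array__ must produce a real ndarray, a JAX tracer appears to have no valid value to return there, which is consistent with the conclusion that unp.asarray may not be fixable for Array with backend='jax'.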

DanPuzzuoli commented 11 months ago

I've figured out how to bypass the above issue and directly fix most of the tests here. The only remaining failures relate to pulse-to-signal conversion and JAX compatibility, but I think they may come from the same issue.

I'll try again to see if we can "properly" fix this problem (i.e. get unp.asarray to work with Array with backend='jax'). It would be nice to keep backwards compatibility with code that uses Array, but I'm also okay with breaking it if we can't find a clean solution.

to24toro commented 10 months ago

The tests that are still failing fall into two categories: