qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

`Qobj.__mul__` not compatible with `jit` #37

Open Ericgig opened 3 months ago

Ericgig commented 3 months ago

Input:

import jax
import qutip
import qutip_jax  # registers the "jax" data layer with QuTiP

op = qutip.num(3, dtype="jax")

@jax.jit
def f(op, A):
    return op * A

f(op, 2.)

Output:

---------------------------------------------------------------------------

ValueError                                Traceback (most recent call last)

[<ipython-input-69-1460d0a528c8>](https://localhost:8080/#) in <cell line: 7>()
      5     return op * A
      6 
----> 7 f(op, 2.)

    [... skipping hidden 12 frame]

6 frames

    [... skipping hidden 2 frame]

    [... skipping hidden 5 frame]

[/usr/local/lib/python3.10/dist-packages/qutip/core/qobj.py](https://localhost:8080/#) in _initialize_data(self, arg, dims, copy)
    298             )
    299         if self._dims.shape != self._data.shape:
--> 300             raise ValueError('Provided dimensions do not match the data: ' +
    301                              f"{self._dims.shape} vs {self._data.shape}")
    302 

ValueError: Provided dimensions do not match the data: (3, 3) vs (1, 1)

rochisha0 commented 2 weeks ago

Hi @Ericgig! Have you been able to find the cause of this issue?

Ericgig commented 2 weeks ago

No. My guess is that `JaxArray(mul_jax(self.data, other))` is seen as a 1x1 matrix when tracers are used instead of real arrays, but I don't know why.
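
One way to narrow this down is to check what plain JAX does with the same product, leaving qutip-jax out entirely. The sketch below is only an illustration: `describe`, `scale` and `seen` are made-up names, and the dense array stands in for the data of `qutip.num(3)`. It records, at trace time, whether each value is a tracer and what shape it reports.

```python
import jax
import jax.numpy as jnp


def describe(x):
    """Return (is_tracer, shape) for a value, inside or outside jit."""
    return isinstance(x, jax.core.Tracer), tuple(jnp.shape(x))


# Filled in while `scale` is being traced (hypothetical debugging aid).
seen = {}


@jax.jit
def scale(mat, a):
    # This body runs at trace time, so `mat` and `a` are tracers here,
    # not concrete arrays. Only plain Python tuples are stored in `seen`.
    seen["mat"] = describe(mat)
    seen["a"] = describe(a)
    out = mat * a
    seen["out"] = describe(out)
    return out


# Dense stand-in for the data of qutip.num(3): diag(0, 1, 2).
mat = jnp.diag(jnp.arange(3.0))
result = scale(mat, 2.0)

for name, (is_tracer, shape) in seen.items():
    print(name, "tracer:", is_tracer, "shape:", shape)
```

In plain JAX the scalar is traced with shape `()` and the product keeps the `(3, 3)` shape of the matrix. If that matches what you see, the `(1, 1)` shape would more likely come from how the result is wrapped than from the multiplication itself, though that is an assumption to check rather than a confirmed diagnosis.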