qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License
18 stars 7 forks

Definition of qutip.data.sqrtm.add_specialisations( [(JaxArray, JaxArray, sqrtm_jaxarray),] ) raising error #67

Open ArturDomingues opened 4 weeks ago

ArturDomingues commented 4 weeks ago

I was trying to use qutip-jax and got an error while importing it (shown below). Looking in qutip/core/data, there is no sqrtm there, but it does exist in qutip/core/data/expm.py. With that in mind, I think the fix is just to change

qutip.data.sqrtm.add_specialisations(
    [(JaxArray, JaxArray, sqrtm_jaxarray),]
)

to

qutip.data.expm.sqrtm.add_specialisations(
    [(JaxArray, JaxArray, sqrtm_jaxarray),]
)

in unary.py. Here is the error I mentioned:

---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
Cell In[4], line 3
      1 import jax.numpy as jnp
      2 import qutip
----> 3 import qutip_jax
      5 with qutip.CoreOptions(default_dtype="jax"):
      6     excited = qutip.basis(dim, 4, dtype="jax"), qutip.basis(dim, 3, dtype="jax"), qutip.basis(dim, 5, dtype="jax")

File c:\Users\artur\anaconda3\envs\qutipjax\Lib\site-packages\qutip_jax\__init__.py:33
     30 del is_jax_array
     32 from .binops import *
---> 33 from .unary import *
     34 from .permute import *
     35 from .reshape import *

File c:\Users\artur\anaconda3\envs\qutipjax\Lib\site-packages\qutip_jax\unary.py:195
    181 qutip.data.expm.add_specialisations(
    182     [
    183         (JaxArray, JaxArray, expm_jaxarray),
    184     ]
    185 )
    188 qutip.data.inv.add_specialisations(
    189     [
    190         (JaxArray, JaxArray, inv_jaxarray),
    191     ]
    192 )
--> 195 qutip.data.sqrtm.add_specialisations(
    196     [(JaxArray, JaxArray, sqrtm_jaxarray),]
    197 )
    200 qutip.data.project.add_specialisations(
    201     [
    202         (JaxArray, JaxArray, project_jaxarray),
    203     ]
    204 )

AttributeError: module 'qutip.core.data' has no attribute 'sqrtm'
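To illustrate why the failure happens at import time rather than when sqrtm is first used, here is a toy model of the registration pattern visible in unary.py. `ToyDispatcher`, `ToyJaxArray`, `toy_expm_jaxarray` and the `toy_data` module are hypothetical stand-ins, not QuTiP's real classes; the sketch only assumes, as the traceback shows, that unary.py looks up each dispatcher as an attribute of `qutip.data` at module level.

```python
import types


class ToyDispatcher:
    """Hypothetical stand-in for a QuTiP data-layer dispatcher (not the real class)."""

    def __init__(self, name):
        self.name = name
        # Maps (input type, output type) -> implementation.
        self._specialisations = {}

    def add_specialisations(self, specs):
        # Each entry is (input type, output type, implementation), as in unary.py.
        for in_type, out_type, func in specs:
            self._specialisations[(in_type, out_type)] = func

    def __call__(self, matrix, out_type):
        # Pick the implementation registered for this input/output type pair.
        return self._specialisations[(type(matrix), out_type)](matrix)


class ToyJaxArray:
    """Hypothetical stand-in for qutip_jax's JaxArray type."""


def toy_expm_jaxarray(matrix):
    return "expm computed"


# A stand-in for qutip.core.data that, like the install in the traceback,
# defines an expm dispatcher but no sqrtm dispatcher.
data = types.ModuleType("toy_data")
data.expm = ToyDispatcher("expm")

# Registering with a dispatcher that exists succeeds.
data.expm.add_specialisations([(ToyJaxArray, ToyJaxArray, toy_expm_jaxarray)])

# Registering with a missing one fails before add_specialisations runs:
# the attribute lookup `data.sqrtm` itself raises AttributeError.
try:
    data.sqrtm.add_specialisations([(ToyJaxArray, ToyJaxArray, toy_expm_jaxarray)])
except AttributeError as err:
    print(err)  # module 'toy_data' has no attribute 'sqrtm'
```

Because the lookup runs as soon as the module body executes, any qutip build that lacks the name makes `import qutip_jax` itself fail.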

I installed qutip-jax following the instructions in this link, using:

pip install qutip --pre
pip install git+https://github.com/qutip/qutip-jax.git
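Since the traceback points to a name missing from the installed `qutip.core.data`, a useful first check is which versions pip actually installed. A minimal sketch using the standard-library `importlib.metadata`; `installed_version` is a hypothetical helper, and the distribution names are those from the pip commands above (plus `jax`). Note that `pip install qutip --pre` installs the newest pre-release published on PyPI, which is not necessarily the qutip code the qutip-jax main branch is developed against.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


# Distribution names as used in the pip commands above.
for name in ("qutip", "qutip-jax", "jax"):
    print(f"{name}: {installed_version(name)}")
```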

Ericgig commented 4 weeks ago

To use the development version of qutip-jax, you need to install both qutip and qutip-jax from source:

pip install git+https://github.com/qutip/qutip.git
pip install git+https://github.com/qutip/qutip-jax.git

If you don't have a working Cython setup to compile qutip, you can use the released versions instead:

pip install qutip qutip-jax
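Whichever route is taken, one can confirm that a qutip install provides the dispatchers qutip-jax registers with before importing `qutip_jax`. A small sketch; `missing_dispatchers` is a hypothetical helper, and the name list contains only the four dispatchers visible in the traceback (unary.py may register with others). With qutip installed, pass `qutip.data` as the first argument.

```python
import types


def missing_dispatchers(data_module, names):
    """Return the entries of `names` that are not attributes of `data_module`."""
    return [name for name in names if not hasattr(data_module, name)]


# Dispatchers that qutip_jax/unary.py registers with, as seen in the traceback.
# This list is illustrative, not exhaustive.
UNARY_DISPATCHERS = ("expm", "inv", "sqrtm", "project")

# Demonstration with a stand-in module that lacks sqrtm, mimicking the failing
# install. With qutip available you would instead call:
#     import qutip
#     missing_dispatchers(qutip.data, UNARY_DISPATCHERS)
stand_in = types.ModuleType("stand_in_data")
for _name in ("expm", "inv", "project"):
    setattr(stand_in, _name, object())

print(missing_dispatchers(stand_in, UNARY_DISPATCHERS))  # ['sqrtm']
```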

ArturDomingues commented 4 weeks ago

OK, got it. This should be made explicit in the installation instructions.