qutip / qutip-jax

JAX backend for QuTiP
BSD 3-Clause "New" or "Revised" License

Add binary operation specialization #2

Closed: Ericgig closed this pull request 1 year ago

Ericgig commented 1 year ago

Adds binary operations between two JaxArray objects (matmul, add, sub, kron, multiply) and between a JaxArray and a scalar (mul, pow).
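A minimal sketch of what these operations compute, written against plain jax.numpy. The wrapper class MiniJaxArray and the *_jaxarray function names are hypothetical stand-ins, not the actual qutip-jax code. The sketch assumes, as in QuTiP's data layer, that pow means integer matrix power rather than element-wise power. Registration of the functions with QuTiP's data-layer dispatch is omitted.

```python
# Hypothetical sketch: MiniJaxArray and the *_jaxarray functions are
# illustrative stand-ins, not the actual qutip-jax implementation.
import jax.numpy as jnp


class MiniJaxArray:
    """Minimal wrapper around a 2-D jax array (hypothetical)."""

    def __init__(self, array):
        self.array = jnp.asarray(array)

    @property
    def shape(self):
        return self.array.shape


def _check_same_shape(left, right, op):
    if left.shape != right.shape:
        raise ValueError(f"{op}: shapes {left.shape} and {right.shape} differ")


# JaxArray x JaxArray operations
def add_jaxarray(left, right):
    _check_same_shape(left, right, "add")
    return MiniJaxArray(left.array + right.array)


def sub_jaxarray(left, right):
    _check_same_shape(left, right, "sub")
    return MiniJaxArray(left.array - right.array)


def multiply_jaxarray(left, right):
    # Element-wise (Hadamard) product.
    _check_same_shape(left, right, "multiply")
    return MiniJaxArray(left.array * right.array)


def matmul_jaxarray(left, right):
    # Matrix product: inner dimensions must agree.
    if left.shape[1] != right.shape[0]:
        raise ValueError(f"matmul: shapes {left.shape} and {right.shape} do not align")
    return MiniJaxArray(left.array @ right.array)


def kron_jaxarray(left, right):
    # Kronecker (tensor) product; no shape constraint.
    return MiniJaxArray(jnp.kron(left.array, right.array))


# JaxArray x scalar operations
def mul_jaxarray(matrix, value):
    # Scale every entry by a scalar.
    return MiniJaxArray(value * matrix.array)


def pow_jaxarray(matrix, n):
    # Integer matrix power (assumed semantics); only defined for square matrices.
    if matrix.shape[0] != matrix.shape[1]:
        raise ValueError(f"pow: matrix of shape {matrix.shape} is not square")
    return MiniJaxArray(jnp.linalg.matrix_power(matrix.array, n))


# Usage: Pauli X squared is the identity.
sigma_x = MiniJaxArray([[0.0, 1.0], [1.0, 0.0]])
print(matmul_jaxarray(sigma_x, sigma_x).array)
```

The shape checks mirror what a dispatcher would normally validate before calling a specialization. Here they live in each function so that the sketch stays self-contained.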

coveralls commented 1 year ago

Pull Request Test Coverage Report for Build 3203863830

Changes Missing Coverage
File                       Covered Lines   Changed/Added Lines   %
src/qutip_jax/binops.py    44              47                    93.62%
Total                      51              54                    94.44%
Totals
Change from base Build 3160050711: 2.5%
Covered Lines: 108
Relevant Lines: 118

💛 - Coveralls
Ericgig commented 1 year ago

@quantshah Could you add your new tests in another PR instead of pushing them here? The usage tests and the specialization work don't have to be done together.

quantshah commented 1 year ago

Everything looks good here, except that we cannot pass QuTiP objects into and out of functions that are JIT-compiled (jax.jit). I wrote a simple use case to demonstrate this. We probably need to do something like this: https://github.com/google/jax/issues/4269#issuecomment-691402423

But I am not sure what the best way to go is here.
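The linked JAX issue concerns passing custom Python objects into and out of jax.jit. One standard JAX mechanism for this is to register the class as a pytree. The array data becomes traced leaves, while static metadata such as dims goes into hashable auxiliary data. This is offered as an assumption about what such a fix could look like, not as a summary of the linked comment. WrappedOperator and commutator are hypothetical names, not QuTiP's Qobj or qutip-jax's JaxArray.

```python
# Hypothetical sketch: registering a wrapper class as a pytree so instances
# can cross the jax.jit boundary. Not QuTiP's or qutip-jax's actual code.
import jax
import jax.numpy as jnp
from jax.tree_util import register_pytree_node_class


@register_pytree_node_class
class WrappedOperator:
    """Array data plus static metadata, loosely modelled on an operator object."""

    def __init__(self, array, dims):
        self.array = array
        self.dims = dims  # static, hashable metadata, e.g. ((2,), (2,))

    def tree_flatten(self):
        # Leaves (traced by jit) and auxiliary data (treated as static).
        return (self.array,), self.dims

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        (array,) = children
        return cls(array, aux_data)


@jax.jit
def commutator(a, b):
    # Inputs are flattened to their leaves on entry; the returned
    # WrappedOperator is rebuilt from its leaves and aux data on exit.
    return WrappedOperator(a.array @ b.array - b.array @ a.array, a.dims)


sigma_x = WrappedOperator(jnp.array([[0.0, 1.0], [1.0, 0.0]]), ((2,), (2,)))
sigma_z = WrappedOperator(jnp.array([[1.0, 0.0], [0.0, -1.0]]), ((2,), (2,)))
result = commutator(sigma_x, sigma_z)
print(type(result).__name__, result.dims)
```

Because the auxiliary data is part of jit's cache key, it must be hashable and comparable. In this sketch, list-of-lists dims would therefore need converting to tuples first.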

quantshah commented 1 year ago

OK, I will add the JIT tests in a separate PR. Everything looks good here. I will merge after the tests pass.