jcmgray / quimb

A python library for quantum information and many-body calculations including tensor networks.
http://quimb.readthedocs.io
Other

Automatic JAX-registration of `TensorNetwork` subclasses #150

Closed: mofeing closed this 1 year ago

mofeing commented 1 year ago

Fixes #148.

This PR makes quimb compatible with JAX transformations (e.g. jax.grad, jax.jit, jax.vmap, jax.pmap, ...).
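For JAX transformations to accept a custom object, JAX has to know how to split it into array leaves and static structure. Assuming the registration in this PR uses JAX's pytree mechanism, the core of it is a flatten/unflatten pair. The sketch below uses a hypothetical `ToyTensorNetwork` class and hypothetical `tn_flatten`/`tn_unflatten` functions, not quimb's actual `TensorNetwork` internals, and runs with numpy alone; the JAX registration call is shown in a comment.

```python
import numpy as np


class ToyTensorNetwork:
    """Hypothetical stand-in for a tensor network: array data plus static tags."""

    def __init__(self, arrays, tags):
        self.arrays = list(arrays)
        self.tags = tuple(tags)


def tn_flatten(tn):
    # children: the array leaves that JAX traces, differentiates or batches over
    # aux_data: static, hashable metadata kept as part of the tree structure
    return tuple(tn.arrays), tn.tags


def tn_unflatten(aux_data, children):
    # rebuild an object of the same class from the (possibly transformed) leaves
    return ToyTensorNetwork(children, aux_data)


# With JAX installed, registration would look like:
#     import jax
#     jax.tree_util.register_pytree_node(ToyTensorNetwork, tn_flatten, tn_unflatten)
# after which JAX can traverse ToyTensorNetwork instances like any other pytree.

tn = ToyTensorNetwork([np.ones((2, 2)), np.zeros(3)], tags=("A", "B"))
leaves, aux = tn_flatten(tn)
rebuilt = tn_unflatten(aux, [2 * x for x in leaves])
print(rebuilt.tags, rebuilt.arrays[0].sum())  # ('A', 'B') 8.0
```

Keeping the tags in `aux_data` rather than in the leaves means JAX treats them as fixed structure, so only the arrays are traced.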

To avoid increasing quimb's import time, registration is done lazily: when get_jax is called, quimb recursively descends through TensorNetwork's class hierarchy and registers every subclass with JAX. If the user defines a class that inherits from TensorNetwork after quimb has called get_jax (an edge case), they can still call jax_update_register to update the registry.
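JAX's pytree registry is keyed on the exact type, so registering a base class does not automatically cover its subclasses; that is why the whole hierarchy has to be walked. A minimal pure-Python sketch of the walk-and-register idea follows. The classes and the names `_walk_subclasses` and `register_all` are hypothetical stand-ins, not quimb's code; in practice `register` would be something like `lambda cls: jax.tree_util.register_pytree_node(cls, flatten, unflatten)`, with a set of already-registered classes guarding against registering the same type twice.

```python
class TensorNetwork:  # hypothetical stand-ins, not quimb's real classes
    pass


class TensorNetwork1D(TensorNetwork):
    pass


class MatrixProductState(TensorNetwork1D):
    pass


_registered = set()


def _walk_subclasses(cls):
    """Yield cls, then every direct or indirect subclass, depth first."""
    yield cls
    for sub in cls.__subclasses__():
        yield from _walk_subclasses(sub)


def register_all(base, register):
    """Call register(cls) once for every not-yet-registered class under base.

    Calling it again later only registers classes defined since the last
    call, which is the role the PR gives to jax_update_register.
    """
    newly = []
    for cls in _walk_subclasses(base):
        if cls not in _registered:  # also guards against diamond inheritance
            register(cls)
            _registered.add(cls)
            newly.append(cls)
    return newly


log = []
first = register_all(TensorNetwork, log.append)
print([c.__name__ for c in first])  # ['TensorNetwork', 'TensorNetwork1D', 'MatrixProductState']


class MyTN(TensorNetwork):  # a user subclass defined after the first call
    pass


later = register_all(TensorNetwork, log.append)
print([c.__name__ for c in later])  # ['MyTN']
```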

mofeing commented 1 year ago

@jcmgray There is a conflict on this PR: jax.BatchTracer crashes at Vectorizer.unpack because it does not have a settable .shape attribute (you have to call reshape instead). Would it be possible to call autoray.do("reshape", ...) instead of reshaping in place?
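The difference under discussion is between mutating an array's shape in place (`arr.shape = ...`, which only works on mutable arrays such as numpy's) and calling a reshape function that returns a new array. The numpy-only sketch below writes an unpack step the functional way; the function `unpack` and its signature are hypothetical and are not quimb's actual `Vectorizer.unpack`. The `reshape` parameter is where a backend-dispatching function such as `lambda x, s: autoray.do("reshape", x, s)` could be passed in.

```python
import numpy as np


def unpack(vector, shapes, reshape=np.reshape):
    """Split a flat 1D vector into arrays with the given shapes.

    Each slice goes through a reshape *function* rather than an in-place
    ``arr.shape = shape`` assignment, so immutable array types (e.g. JAX
    arrays and tracers) can be handled by passing a suitable reshape.
    """
    arrays = []
    start = 0
    for shape in shapes:
        size = int(np.prod(shape))  # number of entries belonging to this array
        arrays.append(reshape(vector[start:start + size], shape))
        start += size
    return arrays


vec = np.arange(10.0)
a, b = unpack(vec, [(2, 3), (4,)])
print(a.shape, b.shape)  # (2, 3) (4,)
```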

jcmgray commented 1 year ago

Looks great, thanks!

Regarding reshape, my only question is: when is the Vectorizer being called with non-numpy (presumably jax?) arrays? I do think that class should explicitly not be backend-agnostic, and should instead be tailored to the scipy minimize interface etc.

If the general functionality is useful for some other purpose then it should probably be a separate (and simpler) object.

mofeing commented 1 year ago

It looks like the error I had (which I solved with autoray.do("reshape", ...) inside Vectorizer.unpack) is no longer present, I guess because I changed the function that I pass to jax.vmap.

Anyway, I've reverted those changes.

codecov[bot] commented 1 year ago

Codecov Report

Merging #150 (ddbccde) into develop (2af4592) will decrease coverage by 0.06%. The diff coverage is 17.39%.

@@             Coverage Diff             @@
##           develop     #150      +/-   ##
===========================================
- Coverage    68.89%   68.82%   -0.07%     
===========================================
  Files           43       43              
  Lines        17318    17341      +23     
===========================================
+ Hits         11931    11935       +4     
- Misses        5387     5406      +19     
Impacted Files               Coverage Δ
quimb/tensor/optimize.py     31.03% <17.39%> (-0.44%) ⬇


jcmgray commented 1 year ago

Nice, thanks!