mofeing closed this 1 year ago.
@jcmgray There is a problem with this PR: `jax.BatchTracer` crashes in `Vectorizer.unpack` because it does not have a settable `.shape` attribute (you have to call `reshape`). Could we call `autoray.do("reshape", ...)` instead of reshaping in place?
Looks great, thanks!
Regarding `reshape`, my only question is: when is the `Vectorizer` being called with non-numpy (presumably JAX?) arrays? I do think that class should explicitly not be backend-agnostic, and should instead be tailored to the scipy `minimize` interface etc.

If the general functionality is useful for some other purpose, then it should probably be a separate (and simpler) object.
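To illustrate a NumPy-only object tailored to the scipy `minimize` interface, here is a minimal sketch. It packs a list of real arrays into the single 1-D float64 vector that `scipy.optimize.minimize` expects and unpacks it again. `FlatVectorizer` is a hypothetical name, not quimb's `Vectorizer`. Complex arrays and analytic gradients are deliberately left out.

```python
import numpy as np
from scipy.optimize import minimize


class FlatVectorizer:
    """Hypothetical NumPy-only packer: list of real arrays <-> flat float64 vector."""

    def __init__(self, arrays):
        self.shapes = [np.shape(x) for x in arrays]
        self.sizes = [int(np.prod(s)) for s in self.shapes]

    def pack(self, arrays):
        # scipy.optimize.minimize works on a single 1-D float vector
        return np.concatenate([np.ravel(x) for x in arrays]).astype(np.float64)

    def unpack(self, vector):
        # Plain NumPy, so reshaping a slice is always safe here
        out, i = [], 0
        for shape, size in zip(self.shapes, self.sizes):
            out.append(vector[i:i + size].reshape(shape))
            i += size
        return out


targets = [np.ones((2, 2)), np.full(3, 2.0)]
vec = FlatVectorizer(targets)


def loss(x):
    a, b = vec.unpack(x)
    return np.sum((a - targets[0]) ** 2) + np.sum((b - targets[1]) ** 2)


x0 = vec.pack([np.zeros((2, 2)), np.zeros(3)])
res = minimize(loss, x0, method="L-BFGS-B")
a_opt, b_opt = vec.unpack(res.x)
```

Keeping the object NumPy-only is what lets it stay this simple; backend-agnostic unpacking would be a separate concern.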
Looks like the error I had (and solved with `autoray.do("reshape", ...)` inside `Vectorizer.unpack`) is no longer present, I guess because I changed the function that I pass to `jax.vmap`.

Anyway, I've reverted those changes.
Merging #150 (ddbccde) into develop (2af4592) will decrease coverage by `0.06%`. The diff coverage is `17.39%`.
```diff
@@            Coverage Diff             @@
##           develop     #150      +/-   ##
===========================================
- Coverage    68.89%   68.82%   -0.07%
===========================================
  Files           43       43
  Lines        17318    17341      +23
===========================================
+ Hits         11931    11935       +4
- Misses        5387     5406      +19
```
| Impacted Files | Coverage Δ | |
|---|---|---|
| quimb/tensor/optimize.py | `31.03% <17.39%> (-0.44%)` | :arrow_down: |
Nice, thanks!
Fixes #148.
This PR makes quimb compatible with JAX transformations (e.g. `jax.grad`, `jax.jit`, `jax.vmap`, `jax.pmap`, ...).

In order not to increase quimb's import time, registration happens lazily: when `get_jax` is called, it recursively descends through `TensorNetwork`'s class hierarchy and registers each child class with JAX. If the user defines a class that inherits from `TensorNetwork` after quimb has called `get_jax` (a rare edge case), they can still call `jax_update_register` to update the registry.