jcmgray / quimb

A python library for quantum information and many-body calculations including tensor networks.
http://quimb.readthedocs.io

`qtn.Circuit.amplitude()` does not work with `jax.jit` #220

Closed · king-p3nguin closed this issue 4 months ago

king-p3nguin commented 4 months ago

What happened?

When I used `qtn.Circuit.amplitude()` inside a loss function decorated with `jax.jit`, it raised a `NonConcreteBooleanIndexError`.

What did you expect to happen?

`qtn.Circuit.amplitude()` should be compatible with `jax.jit`.

Minimal Complete Verifiable Example

import quimb.tensor as qtn
import jax
import jax.numpy as jnp
from jax.example_libraries import optimizers

def two_local(n, depth, parameters):
    # 'two-local' ansatz: layers of single-qubit RZ rotations
    # interleaved with a ladder of CNOTs
    assert parameters.shape == (depth, n)
    circ = qtn.Circuit(n)

    for r in range(depth - 1):
        # parametrized rotation layer
        for i in range(n):
            circ.rz(parameters[r, i], i)
        # entangling layer
        for i in range(n - 1):
            circ.cx(i, i + 1)

    # final rotation layer, using the last row of parameters
    for i in range(n):
        circ.rz(parameters[depth - 1, i], i)

    return circ

n = 6
depth = 9
key = jax.random.PRNGKey(42)
parameters = jax.random.normal(key, (depth, n))

def loss_fn(p):
    circ = two_local(n, depth, p)
    return 1.0 - jnp.abs(circ.amplitude("1" * n)) ** 2

loss_grad_fn = jax.value_and_grad(loss_fn)

(
    initial_opt_func,
    opt_update_func,
    get_new_params_from_state_func,
) = optimizers.adam(step_size=0.001)

@jax.jit
def update(step, opt_state):
    params = get_new_params_from_state_func(opt_state)
    value, grads = loss_grad_fn(params)
    opt_state = opt_update_func(step, grads, opt_state)
    return params, value, opt_state

opt_state = initial_opt_func(parameters)

for step in range(1000):
    # feed the updated optimizer state back into the next iteration
    params, value, opt_state = update(step, opt_state)
    if step % 100 == 0:
        print(f"Step: {step}, Probability: {1 - value}")

Relevant log output

{
    "name": "NonConcreteBooleanIndexError",
    "message": "Array boolean indices must be concrete; got ShapedArray(bool[2,2])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError",
    "stack": "---------------------------------------------------------------------------
NonConcreteBooleanIndexError              Traceback (most recent call last)
Cell In[1], line 54
     51 opt_state = initial_opt_func(parameters)
     53 for step in range(1000):
---> 54     param, value, optimized_state = update(step, opt_state)
     55     if step % 100 == 0:
     56         print(f\"Step: {step}, Probability: {1-value}\")

    [... skipping hidden 12 frame]

Cell In[1], line 46, in update(step, opt_state)
     43 @jax.jit
     44 def update(step, opt_state):
     45     params = get_new_params_from_state_func(opt_state)
---> 46     value, grads = loss_grad_fn(params)
     47     opt_state = opt_update_func(step, grads, opt_state)
     48     return params, value, opt_state

    [... skipping hidden 8 frame]

Cell In[1], line 31, in loss_fn(p)
     29 def loss_fn(p):
     30     circ = two_local(n, depth, p)
---> 31     return 1.0 - jnp.abs(circ.amplitude(\"1\" * n)) ** 2

File ~/quimb/quimb/tensor/circuit.py:2329, in Circuit.amplitude(self, b, optimize, simplify_sequence, simplify_atol, simplify_equalize_norms, backend, dtype, rehearse)
   2322 fs_opts = {
   2323     \"seq\": simplify_sequence,
   2324     \"atol\": simplify_atol,
   2325     \"equalize_norms\": simplify_equalize_norms,
   2326 }
   2328 # get the full wavefunction simplified
-> 2329 psi_b = self.get_psi_simplified(**fs_opts)
   2331 # fix the output indices to the correct bitstring
   2332 for i, x in zip(range(self.N), b):

File ~/quimb/quimb/tensor/circuit.py:2198, in Circuit.get_psi_simplified(self, seq, atol, equalize_norms)
   2195 output_inds = tuple(map(psi.site_ind, range(self.N)))
   2197 # simplify the state and cache it
-> 2198 psi.full_simplify_(
   2199     seq=seq,
   2200     atol=atol,
   2201     output_inds=output_inds,
   2202     equalize_norms=equalize_norms,
   2203 )
   2204 self._storage[key] = psi
   2206 # return a copy so we can modify it inplace

File ~/quimb/quimb/tensor/tensor_core.py:10498, in TensorNetwork.full_simplify(self, seq, output_inds, atol, equalize_norms, cache, inplace, progbar, rank_simplify_opts, loop_simplify_opts, split_simplify_opts, custom_methods, split_method)
  10491     tn.rank_simplify_(
  10492         output_inds=ix_o,
  10493         cache=cache,
  10494         equalize_norms=equalize_norms,
  10495         **rank_simplify_opts,
  10496     )
  10497 elif meth == \"A\":
> 10498     tn.antidiag_gauge_(
  10499         output_inds=ix_o, atol=atol, cache=cache
  10500     )
  10501 elif meth == \"C\":
  10502     tn.column_reduce_(output_inds=ix_o, atol=atol, cache=cache)

File ~/quimb/quimb/tensor/tensor_core.py:9720, in TensorNetwork.antidiag_gauge(self, output_inds, atol, cache, inplace)
   9717 if cache_key in cache:
   9718     continue
-> 9720 ij = find_antidiag_axes(t.data, atol=atol)
   9722 # tensor not anti-diagonal
   9723 if ij is None:

File ~/quimb/quimb/tensor/array_ops.py:380, in find_antidiag_axes(x, atol)
    378     if di != dj:
    379         continue
--> 380     if do('allclose', x[indxrs[i] != dj - 1 - indxrs[j]], 0.0,
    381           atol=atol, like=backend):
    382         return (i, j)
    383 return None

File ~/.local/share/virtualenvs/quimb-au0GDUdI/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:736, in _forward_operator_to_aval.<locals>.op(self, *args)
    735 def op(self, *args):
--> 736   return getattr(self.aval, f\"_{name}\")(self, *args)

File ~/.local/share/virtualenvs/quimb-au0GDUdI/lib/python3.11/site-packages/jax/_src/numpy/array_methods.py:349, in _getitem(self, item)
    348 def _getitem(self, item):
--> 349   return lax_numpy._rewriting_take(self, item)

File ~/.local/share/virtualenvs/quimb-au0GDUdI/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4589, in _rewriting_take(arr, idx, indices_are_sorted, unique_indices, mode, fill_value)
   4583     if (isinstance(aval, core.DShapedArray) and aval.shape == () and
   4584         dtypes.issubdtype(aval.dtype, np.integer) and
   4585         not dtypes.issubdtype(aval.dtype, dtypes.bool_) and
   4586         isinstance(arr.shape[0], int)):
   4587       return lax.dynamic_index_in_dim(arr, idx, keepdims=False)
-> 4589 treedef, static_idx, dynamic_idx = _split_index_for_jit(idx, arr.shape)
   4590 return _gather(arr, treedef, static_idx, dynamic_idx, indices_are_sorted,
   4591                unique_indices, mode, fill_value)

File ~/.local/share/virtualenvs/quimb-au0GDUdI/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4674, in _split_index_for_jit(idx, shape)
   4670   raise TypeError(f\"JAX does not support string indexing; got {idx=}\")
   4672 # Expand any (concrete) boolean indices. We can then use advanced integer
   4673 # indexing logic to handle them.
-> 4674 idx = _expand_bool_indices(idx, shape)
   4676 leaves, treedef = tree_flatten(idx)
   4677 dynamic = [None] * len(leaves)

File ~/.local/share/virtualenvs/quimb-au0GDUdI/lib/python3.11/site-packages/jax/_src/numpy/lax_numpy.py:4972, in _expand_bool_indices(idx, shape)
   4968   abstract_i = core.get_aval(i)
   4970 if not type(abstract_i) is ConcreteArray:
   4971   # TODO(mattjj): improve this error by tracking _why_ the indices are not concrete
-> 4972   raise errors.NonConcreteBooleanIndexError(abstract_i)
   4973 elif _ndim(i) == 0:
   4974   out.append(bool(i))

NonConcreteBooleanIndexError: Array boolean indices must be concrete; got ShapedArray(bool[2,2])

See https://jax.readthedocs.io/en/latest/errors.html#jax.errors.NonConcreteBooleanIndexError"
}

Anything else we need to know?

Changing

https://github.com/jcmgray/quimb/blob/6e522e6bd83f1e65bbee9ca256162c26b2833ae5/quimb/tensor/array_ops.py#L380-L382

to

        # mask the anti-diagonal entries with `where` instead of boolean
        # indexing, so the index array no longer needs to be concrete
        if do('allclose', do('where', indxrs[i] != dj - 1 - indxrs[j], x, 0.0), 0.0,
              atol=atol, like=backend):
            return (i, j)

avoids the `NonConcreteBooleanIndexError`, but a `TracerBoolConversionError` is then raised instead, since the `if` now branches on a traced boolean. More fundamentally, `find_antidiag_axes()` is not compatible with `jax.jit` because its return type depends on the input data (`None` or a tuple of axes), so further changes would likely be necessary.
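
This is not specific to quimb: any Python `if` taken on a value that is traced under `jax.jit` fails in the same way. A minimal sketch of the underlying problem (a standalone toy function, not quimb code):

import jax
import jax.numpy as jnp

@jax.jit
def is_zero(x):
    # under jit, jnp.allclose returns a traced boolean; using it in a
    # Python `if` (and returning None vs. an array depending on it)
    # forces concretization and raises TracerBoolConversionError
    if jnp.allclose(x, 0.0):
        return None
    return x

# is_zero(jnp.zeros((2, 2)))  # raises jax.errors.TracerBoolConversionError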

Environment

os: Windows WSL (Ubuntu 22.04.4 LTS)
Python: 3.11.7
jax: 0.4.25
quimb: https://github.com/jcmgray/quimb/tree/6e522e6bd83f1e65bbee9ca256162c26b2833ae5/

jcmgray commented 4 months ago

Could you try with `simplify_sequence='R'`? That will turn off all the dynamic shape simplifications that are inherently not compatible with `jax.jit`.
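
Applied to the loss function from the example above, the workaround would look something like this (a sketch; `simplify_sequence` is simply forwarded to `Circuit.amplitude`):

def loss_fn(p):
    circ = two_local(n, depth, p)
    # 'R' keeps only the rank simplification step, skipping the
    # data-dependent simplifications (e.g. antidiag_gauge) whose
    # outcomes cannot be known at trace time
    amp = circ.amplitude("1" * n, simplify_sequence="R")
    return 1.0 - jnp.abs(amp) ** 2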

king-p3nguin commented 4 months ago

Changing `circ.amplitude("1" * n)` to `circ.amplitude("1" * n, simplify_sequence='R')` worked! Thank you.