dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

`contract` failed to find required methods from `jax.numpy` backend #224

Status: Closed (zazbone closed this issue 1 month ago)

zazbone commented 4 months ago

Hi!

I'm trying to use opt_einsum with jax, but it seems to fail to find the correct methods in the jax module. This results either in a conversion from a jax Array to a NumPy ndarray, or in a traceback if jax is specified as the backend.

The following code reproduces the behavior:

```python
import sys

import jax
import jax.numpy as jnp
import opt_einsum

print(sys.version)
print(jax.__version__)
print(opt_einsum.__version__)

# A small 1-D jax array; "i->i" is the identity contraction
x = jnp.linspace(0, 1, 32)
print(type(opt_einsum.contract("i->i", x)))  # backend left at its default
print(type(opt_einsum.contract("i->i", x, backend=jnp)))
```

```text
issue_report$ python main.py
3.10.13 (main, Dec 15 2023, 19:01:59) [GCC 11.4.0]
0.4.24
v3.3.0
<class 'numpy.ndarray'>
Traceback (most recent call last):
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 65, in get_func
    return _cached_funcs[func, backend]
KeyError: ('einsum', <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'>)

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 38, in _import_func
    lib = importlib.import_module(_aliases.get(backend, backend))
  File "/[...]/.pyenv/versions/3.10.13/lib/python3.10/importlib/__init__.py", line 117, in import_module
    if name.startswith('.'):
AttributeError: module 'jax.numpy' has no attribute 'startswith'

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/[...]/issue_report/main.py", line 14, in <module>
    print(type(opt_einsum.contract("i->i", x, backend=jnp)))
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 507, in contract
    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 591, in _core_contract
    new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/sharing.py", line 151, in cached_einsum
    return einsum(*args, **kwargs)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/contract.py", line 337, in _einsum
    fn = backends.get_func('einsum', kwargs.pop('backend', 'numpy'))
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 67, in get_func
    fn = _import_func(func, backend, default)
  File "/[...]/issue_report/venv/lib/python3.10/site-packages/opt_einsum/backends/dispatch.py", line 44, in _import_func
    raise AttributeError(error_msg.format(backend, func))
AttributeError: <module 'jax.numpy' from '/[...]/issue_report/venv/lib/python3.10/site-packages/jax/numpy/__init__.py'> doesn't seem to provide the function einsum - see https://optimized-einsum.readthedocs.io/en/latest/backends.html for details on which functions are required for which contractions.
```

I've only replaced personal folder information with [...].

Thanks and best regards.

jcmgray commented 4 months ago

Looks like the backend dispatch aliases need updating with `"jaxlib": "jax.numpy"` for this to work automatically. Note that passing the module directly is not supported; instead, calling with `backend="jax"` should work.
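A minimal, stdlib-only sketch of the dispatch behavior described above may help. It is not opt_einsum's code. `ALIASES`, `infer_backend`, `module_name`, and `parse_backend` are hypothetical stand-ins, and the "fall back to NumPy when the module lacks `tensordot`" rule is an assumption modeled loosely on the reported behavior. It also assumes, as the proposed `"jaxlib"` alias implies, that jax arrays report a class module under `jaxlib`. The sketch shows why the automatic path quietly returns a NumPy array and why a module object cannot be used as `backend`.

```python
import importlib

# Hypothetical alias table. The "jaxlib" entry is the fix proposed in this
# thread; opt_einsum's real table may hold different entries.
ALIASES = {
    "jax": "jax.numpy",
    "jaxlib": "jax.numpy",
}


def infer_backend(array):
    """Return the top-level package that defines type(array)."""
    return type(array).__module__.split(".")[0]


def module_name(backend, aliases=ALIASES):
    """Map a backend *string* to the module expected to provide einsum."""
    if not isinstance(backend, str):
        # A module object (e.g. jax.numpy) is not a valid backend name. Passed
        # to importlib.import_module, it fails on name.startswith(...), as in
        # the traceback above.
        raise TypeError(f"backend must be a string, not {type(backend).__name__}")
    return aliases.get(backend, backend)


def parse_backend(array, backend="auto", aliases=ALIASES):
    """Resolve backend="auto" to a module name, falling back to NumPy."""
    if backend != "auto":
        return module_name(backend, aliases)
    name = module_name(infer_backend(array), aliases)
    try:
        lib = importlib.import_module(name)
    except ImportError:
        return "numpy"
    # Without an alias, a package such as "jaxlib" may import fine yet lack
    # tensordot, so the operands quietly go through NumPy instead.
    return name if hasattr(lib, "tensordot") else "numpy"
```

Under these assumptions, the alias is what lets `backend="auto"` resolve to `jax.numpy`. With the reported v3.3.0, the explicit string form is the workaround: `opt_einsum.contract("i->i", x, backend="jax")`.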

zazbone commented 4 months ago

Hi, indeed, it works better that way. I should have been more attentive. Thanks a lot!

dgasmith commented 2 months ago

@jcmgray Could you make an associated PR if you have a minute?

dgasmith commented 1 month ago

Should be closed by https://github.com/dgasmith/opt_einsum/pull/228.