dgasmith / opt_einsum

⚡️Optimizing einsum functions in NumPy, Tensorflow, Dask, and more with contraction order optimization.
https://dgasmith.github.io/opt_einsum/
MIT License

Raise error in `parse_einsum_input` when output subscript is specified multiple times #222

Closed: lgeiger closed this pull request 2 months ago.

lgeiger commented 7 months ago

Description

from opt_einsum import contract, contract_path

contract("ij->jij", [[0, 0], [0, 0]])
# ValueError: einstein sum subscripts string includes output subscript 'j' multiple times

This call currently relies on the backend to raise an error when an output subscript is specified multiple times. However,

contract_path("ij->jij", [[0, 0], [0, 0]])

would still return a contraction path despite the invalid einsum equation. A user could come to rely on this behaviour by accident, which can lead to subtle errors. For example, with the JAX backend no such error is raised; only an internal assertion fails.

This PR raises the error earlier, in `parse_einsum_input`. This ensures that `contract_path` matches the behaviour of `np.einsum`, and that all backends relying on opt_einsum for error handling raise the error as well.
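As a rough illustration of the kind of check this PR moves into `parse_einsum_input`, here is a simplified sketch; it is not opt_einsum's actual code. The helper name `check_output_subscripts` is hypothetical, the sketch only validates explicit `->` equations, and an ellipsis in the output is stripped rather than validated.

```python
from collections import Counter


def check_output_subscripts(subscripts: str) -> None:
    """Raise ValueError if an output subscript is repeated.

    Hypothetical helper: a simplified sketch of the check described in
    this PR, not opt_einsum's actual parse_einsum_input code.
    """
    if "->" not in subscripts:
        # Implicit mode: the output is derived from the inputs, so there is
        # no user-written output string to validate.
        return
    output = subscripts.split("->", 1)[1]
    # Drop whitespace and any ellipsis; broadcast dimensions are not checked here.
    output = output.replace(" ", "").replace("...", "")
    for char, count in Counter(output).items():
        if count > 1:
            raise ValueError(
                "einstein sum subscripts string includes output subscript "
                f"'{char}' multiple times"
            )


# The check runs on the equation string alone, before any path is computed.
check_output_subscripts("ij->ji")  # valid: passes silently
try:
    check_output_subscripts("ij->jij")
except ValueError as err:
    print(err)  # einstein sum subscripts string includes output subscript 'j' multiple times
```

Because the check needs only the equation string, it can run before any backend is involved, so `contract_path` and `contract` fail identically regardless of which backend is selected.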

This PR also extends the unit tests to check error handling for both `contract_path` and `contract`, independent of the backend.

For reference, I opened a similar PR (https://github.com/numpy/numpy/pull/25230) to fix the same issue in `np.einsum_path`.


codecov[bot] commented 7 months ago

Codecov Report

Merging #222 (2e51e2b) into master (1a984b7) will increase coverage by 3.15%. The diff coverage is 100.00%.
