tzunghanjuang opened this issue 3 months ago (status: open).
Is this something that can only be fixed upstream in JAX?
The new lowerings assert that shape information is static, so passing arguments without static_argnums triggers the error. To allow dynamic shapes, the new lowerings need an extra case that falls back to the old _nary_lower_hlo function (which supports dynamic shapes). So an upstream fix is still required.
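To make the static vs. dynamic distinction concrete, here is a minimal sketch in plain JAX (not Catalyst, and with JAX's default setting of jax_dynamic_shapes turned off). Passing the size through static_argnums gives jnp.ones a concrete shape at trace time, while a traced size is rejected. The function names are hypothetical, and the sketch only illustrates why the static-shape path is the common case. It does not reproduce the _sin_lowering failure itself.

```python
from functools import partial

import jax
import jax.numpy as jnp


@partial(jax.jit, static_argnums=0)
def sin_of_ones_static(n):
    # n is a static (Python int) argument, so jnp.ones(n) has a concrete
    # shape when the function is traced and lowered.
    return jnp.sin(jnp.ones(n))


@jax.jit
def sin_of_ones_traced(n):
    # Without static_argnums, n is a tracer. With jax_dynamic_shapes off
    # (the default), a tracer cannot be used as an array shape.
    return jnp.sin(jnp.ones(n))


def traced_size_is_rejected():
    """Return True if using a traced value as a shape raises TypeError."""
    try:
        sin_of_ones_traced(3)
    except TypeError:
        return True
    return False


print(sin_of_ones_static(3).shape)  # (3,)
print(traced_size_is_rejected())  # True
```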
Just to add: I think I fixed this upstream with this patch: https://github.com/jax-ml/jax/pull/23886
Issue description
After updating the JAX and MLIR dependency chain to v0.4.28 (PR #931), JAX introduces new _sin_lowering and _cos_lowering rules, which fail with dynamic shapes. In the relevant code in jax._src.lax.lax, the call to mlir.lower_fun triggers the error. As a temporary workaround, we patch these lowerings to use the old function, _nary_lower_hlo.
Relevant JAX PR:
https://github.com/google/jax/commit/6d8b3e4cff97d966e56670e70957334885439b76
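The workaround described above can be sketched as re-registering the older lowering rules with mlir.register_lowering. This is a hedged sketch, not necessarily Catalyst's actual patch. It assumes the private helpers jax._src.lax.lax._nary_lower_hlo and jax._src.lib.mlir.dialects.hlo, whose names and signatures can change between JAX releases, and it replaces the sin/cos lowerings globally for the whole process.

```python
from functools import partial

import jax
import jax.numpy as jnp
import numpy as np
from jax import lax
from jax.interpreters import mlir

# Private JAX internals (assumption): these may move or change signature
# between releases, so pin the JAX version when relying on them.
from jax._src.lax.lax import _nary_lower_hlo
from jax._src.lib.mlir.dialects import hlo

# Re-register the older elementwise lowerings for sin/cos. _nary_lower_hlo
# takes the HLO op builder first, then the lowering context and operands,
# and receives the primitive's parameters as keyword arguments.
mlir.register_lowering(lax.sin_p, partial(_nary_lower_hlo, hlo.sine))
mlir.register_lowering(lax.cos_p, partial(_nary_lower_hlo, hlo.cosine))

# Quick check that jitted sin/cos still lower and run with the patched rules.
x = jnp.linspace(0.0, 1.0, 5)
print(jax.jit(jnp.sin)(x))
print(jax.jit(jnp.cos)(x))
```

Because the registration is global, a patch like this is best applied once at import time and dropped when an upstream fix is available.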
Source code and tracebacks
Example: https://github.com/PennyLaneAI/catalyst/blob/5fa4b21922ab1e7beb8f83cd1a8daf4b0c298c95/frontend/test/pytest/test_jax_dynamic_api.py#L140-L157
Trace: