iree-org / iree-jax

Apache License 2.0

Dynamic_api_test in jax failing when returning dynamic shape #57

Closed. oliverdutton closed this issue 1 year ago.

oliverdutton commented 1 year ago

I am trying to use the (very convenient) option of IREE as a JAX backend. When I run the tests, they fail whenever the output shape is dynamic. I'm guessing this test actually works and I'm missing something; presumably the issue is JAX forcibly turning the result into a JAX array.

Below runs just one of the tests, what am I doing wrong?

!cd ${JAX_REPO_PATH} && \
JAX_ARRAY=0 JAX_PLATFORMS=iree \
pytest -r a --verbosity 1 -s tests/dynamic_api_test.py -k transpose
============================= test session starts ==============================
platform linux -- Python 3.8.13, pytest-6.2.5, py-1.11.0, pluggy-1.0.0 -- /opt/conda/bin/python3.8
cachedir: .pytest_cache
hypothesis profile 'default' -> database=DirectoryBasedExampleDatabase('/raid/app/oliver/repos/jax/.hypothesis/examples')
rootdir: /raid/app/oliver/repos/jax, configfile: pytest.ini
plugins: anyio-3.6.2, pythonpath-0.7.4, cov-3.0.0, hypothesis-4.50.8
collected 76 items / 75 deselected / 1 selected                                

tests/dynamic_api_test.py::DynamicShapeTest::test_transpose FAILED

=================================== FAILURES ===================================
_______________________ DynamicShapeTest.test_transpose ________________________

self = <dynamic_api_test.DynamicShapeTest testMethod=test_transpose>

    @unittest.skipIf(jtu.device_under_test() != 'iree', 'iree test')
    def test_transpose(self):
      @partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},))
      def f(x):  # f32[h, w] -> f32[w, h]
        return x.T

>     f(np.ones((3, 5), dtype=np.float32))

tests/dynamic_api_test.py:673: 
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 
jax/_src/api.py:532: in f_jitted
    out_flat = xla.xla_call(
jax/_src/core.py:2132: in bind
    outs = top_trace.process_call(self, fun_, tracers, params)
jax/_src/core.py:794: in process_call
    return primitive.impl(f, *tracers, **params)
jax/_src/dispatch.py:258: in _xla_call_impl
    return compiled_fun(*args)
jax/_src/dispatch.py:916: in _execute_compiled
    return result_handler(env, out_bufs)
jax/_src/dispatch.py:787: in result_handler
    results.append(handler((input_env, results), *bufs))
_ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ 

sticky_device = None, aval = f32[InDBIdx(val=1),InDBIdx(val=0)]
env = ((3, 5, None), [])
buf = IreeBuffer([[1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]
 [1. 1. 1.]])

    def _dynamic_array_result_handler(sticky_device, aval, env, buf):
      in_env, out_env = env or (None, None)
      shape = [in_env[d.val] if type(d) is core.InDBIdx else
               out_env[d.val] if type(d) is core.OutDBIdx else d
               for d in aval.shape]
      if all(type(d) is int for d in shape) and type(aval.dtype) is not core.bint:
>       aval = core.ShapedArray(tuple(shape), buf.dtype)
E       AttributeError: 'IreeBuffer' object has no attribute 'dtype'

jax/_src/dispatch.py:846: AttributeError
=========================== short test summary info ============================
FAILED tests/dynamic_api_test.py::DynamicShapeTest::test_transpose - Attribut...
======================= 1 failed, 75 deselected in 1.15s =======================
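The failing handler first resolves the result's symbolic shape against the call's dimension environment and only then reads buf.dtype. A minimal pure-Python sketch of that resolution step, using hypothetical stand-ins for JAX's internal `core.InDBIdx`/`core.OutDBIdx` (not the real classes), shows that the transposed output `f32[InDBIdx(val=1),InDBIdx(val=0)]` with `env = ((3, 5, None), [])` resolves to `(5, 3)`. That matches the 5x3 `IreeBuffer` in the traceback, so the shape logic is fine; the crash comes from the next line, which assumes the buffer has a `dtype` attribute.

```python
from typing import NamedTuple, Optional, Sequence, Union

# Hypothetical stand-ins for the JAX-internal index types (the real ones live in jax._src.core).
class InDBIdx(NamedTuple):
    val: int  # index into the input dimension environment

class OutDBIdx(NamedTuple):
    val: int  # index into the output dimension environment

Dim = Union[int, InDBIdx, OutDBIdx]

def resolve_shape(shape: Sequence[Dim],
                  in_env: Optional[Sequence] = None,
                  out_env: Optional[Sequence] = None) -> tuple:
    """Mirror the list comprehension in _dynamic_array_result_handler."""
    resolved = []
    for d in shape:
        if type(d) is InDBIdx:
            resolved.append(in_env[d.val])   # size taken from the inputs
        elif type(d) is OutDBIdx:
            resolved.append(out_env[d.val])  # size taken from earlier outputs
        else:
            resolved.append(d)               # already a static int
    return tuple(resolved)

# Values copied from the traceback: aval = f32[InDBIdx(val=1),InDBIdx(val=0)]
in_env, out_env = (3, 5, None), []
print(resolve_shape((InDBIdx(1), InDBIdx(0)), in_env, out_env))  # (5, 3)
```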

The following reproduces the failure in Colab (with jax v0.4.4):

!pip install git+https://github.com/iree-org/iree-jax

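import os
# Must be set before importing jax: dynamic shapes don't yet support the merged jit/pjit API.
os.environ['JAX_JIT_PJIT_API_MERGE'] = '0'

import jax
jax.config.update("jax_dynamic_shapes", True)  # enable experimental dynamic shapes
jax.config.update("jax_array", False)          # dynamic shapes don't yet support jax.Array
import jax.numpy as jnp
from functools import partial

# Abstract both input axes so one compiled kernel serves any (h, w).
@partial(jax.jit, abstracted_axes=({0: 'h', 1: 'w'},), backend='iree')
def f(x):  # f32[h, w] -> f32[w, h]
  return x.T

f(jnp.ones((3, 5), dtype=jnp.float32))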
jpienaar commented 1 year ago

There seems to be an expectation in JAX's Python code about result buffer types (here, that they have a dtype attribute) that we don't match. I'll also ask Matt what is expected.
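Concretely, the traceback shows JAX reading `buf.dtype` on whatever the backend returns, and `IreeBuffer` doesn't provide it. The sketch below illustrates that duck-typing mismatch with hypothetical stand-in classes; `FakeIreeBuffer`, `FakeBufferWithDtype`, and the fallback in `result_dtype` are illustrative assumptions, not the fix that actually landed upstream.

```python
class FakeBufferWithDtype:
    """Hypothetical buffer exposing .dtype, as the JAX result handler expects."""
    def __init__(self, dtype: str):
        self.dtype = dtype

class FakeIreeBuffer:
    """Hypothetical buffer without .dtype, like the IreeBuffer in the traceback."""
    def __init__(self, data):
        self.data = data

def result_dtype(buf, aval_dtype: str) -> str:
    """Prefer the buffer's own dtype; otherwise fall back to the abstract value's dtype.

    The fallback is an illustrative assumption, not the upstream fix.
    """
    dtype = getattr(buf, "dtype", None)
    return aval_dtype if dtype is None else dtype

# Plain attribute access reproduces the kind of error seen in the traceback.
try:
    FakeIreeBuffer([[1.0]]).dtype
except AttributeError as e:
    print(e)  # 'FakeIreeBuffer' object has no attribute 'dtype'

print(result_dtype(FakeIreeBuffer([[1.0]]), "float32"))       # float32
print(result_dtype(FakeBufferWithDtype("int32"), "float32"))  # int32
```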

mattjj commented 1 year ago

Thanks for raising this. I think there should be an easy fix; I can take it.

(But the JAX bits being used here are super unpolished and feature coverage is still minimal!)

mattjj commented 1 year ago

I think https://github.com/google/jax/pull/14986 should fix it. But to be honest, I've paged out a lot of context on this! So if something else breaks, just let us know.

oliverdutton commented 1 year ago

Perfect, thank you

mattjj commented 1 year ago

By the way, it's near the top of my to-do list to update dynamic shapes to be compatible with both JAX_JIT_PJIT_API_MERGE=1 and JAX_ARRAY=1. It shouldn't be "hard", but it's nontrivial because it'll require a big context switch and some time. It never rises to "urgent" like other things because we don't have any dynamic shapes users (or at least I thought we didn't until this issue was opened!).

oliverdutton commented 1 year ago

I completely understand; dynamic shapes are very experimental. I've just been poking around.

I was looking at use cases like running foldcomp on GPU by combining it with nerfax, but compilation times keep it from being competitive because there are arrays of slightly different lengths everywhere.

Maybe jax-md could also benefit from it in the neighbor list update.

But these are mainly personal-interest projects; hopefully I'll find a meaty business application eventually.

Thanks for all the magic of JAX!
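To illustrate why arrays of slightly different lengths hurt compile times: with static shapes, `jax.jit` compiles a new executable for each distinct input shape. The pure-Python sketch below simulates a shape-keyed compile cache and compares it with padding lengths up to the next power of two, a common workaround (not one proposed in this thread) that trades some wasted compute for far fewer compilations. The lengths and the bucketing policy are made-up assumptions.

```python
from typing import Iterable, List, Tuple

def next_pow2(n: int) -> int:
    """Smallest power of two >= n (treats n < 1 as 1)."""
    return 1 << (max(n, 1) - 1).bit_length()

def count_compilations(lengths: Iterable[int], bucket: bool = False) -> int:
    """Simulate a shape-keyed compile cache: one compilation per distinct length."""
    cache = set()
    for n in lengths:
        cache.add(next_pow2(n) if bucket else n)
    return len(cache)

def pad_to_bucket(xs: List[float], fill: float = 0.0) -> Tuple[List[float], int]:
    """Pad xs to its bucket length; return the true length so padding can be masked later."""
    target = next_pow2(len(xs))
    return xs + [fill] * (target - len(xs)), len(xs)

# Hypothetical array lengths, all slightly different.
lengths = [97, 100, 103, 120, 130]
print(count_compilations(lengths))               # 5
print(count_compilations(lengths, bucket=True))  # 2 (buckets 128 and 256)
padded, n_valid = pad_to_bucket([1.0, 2.0, 3.0])
print(len(padded), n_valid)                      # 4 3
```

Dynamic shapes aim to remove the need for this kind of manual bucketing by compiling once for symbolic sizes.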

mattjj commented 1 year ago

Thanks for the explanation, and the kind words!

> I was looking at use cases like running foldcomp on GPU by combining it with nerfax, but compilation times keep it from being competitive because there are arrays of slightly different lengths everywhere.

Wow, this is very interesting. Any chance you could share some representative programs or toy examples, showing what you want to do, or where the compile times are killing you? Maybe we can help!

jpienaar commented 1 year ago

(This would also be interesting on the IREE side: we have a WIP pass that dedupes some kernels into dynamic-dim variants to reduce compilation times. We've been focused a bit more on the AOT case, but still.)
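As a toy illustration of the general idea only (not IREE's actual pass, whose design isn't described here), erasing static dimension sizes lets specializations that differ only in sizes share one key, so a single dynamic-dim kernel could serve all of them. The `(dtype, shape)` signature format and the `?` marker for a dynamic dimension are assumptions made for this sketch.

```python
from collections import defaultdict
from typing import Dict, List, Sequence, Tuple

Signature = Tuple[str, Tuple[int, ...]]          # hypothetical (dtype, static shape)
DynKey = Tuple[str, Tuple[str, ...]]             # (dtype, one '?' per dimension)

def dynamic_key(sig: Signature) -> DynKey:
    """Erase static sizes: keep dtype and rank, mark every dim as dynamic ('?')."""
    dtype, shape = sig
    return dtype, tuple("?" for _ in shape)

def dedupe(signatures: Sequence[Signature]) -> Dict[DynKey, List[Signature]]:
    """Group static specializations under the dynamic-dim variant that could replace them."""
    groups: Dict[DynKey, List[Signature]] = defaultdict(list)
    for sig in signatures:
        groups[dynamic_key(sig)].append(sig)
    return dict(groups)

specializations = [("f32", (3, 5)), ("f32", (4, 5)), ("f32", (7, 2)),
                   ("f32", (8,)), ("i32", (3, 5))]
for key, members in dedupe(specializations).items():
    print(key, "<-", members)
```

Here five static specializations collapse into three dynamic-dim variants; a real compiler would also have to check that the kernels are otherwise identical before merging them.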

oliverdutton commented 1 year ago

Cool, in the next few days I'll open a separate discussion with a clear set of the code I'm working on compiling, and tag you in it.

It looks like the merged PR fixes those tests, so I'm closing the issue.