jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

NotImplementedError: Call to scatter_(update/add/multiply/min/max) cannot be converted with enable_xla=False #15627

Status: Open. galah92 opened this issue 1 year ago.

galah92 commented 1 year ago

Description

I have a JAX function that I'm trying to fuse with a Keras model and convert, as a whole, to TFLite. As I understand it, the best way to do that is to convert my function to a TensorFlow concrete function, merge it with my Keras model into another concrete function, and finally convert the result to TFLite using `TFLiteConverter.from_concrete_functions`.

The problem is that getting a concrete function from my JAX function fails with the `NotImplementedError` quoted in the title. The odd thing is that converting this function to TFLite directly with `TFLiteConverter.experimental_from_jax` works! I assume this has something to do with converting directly to HLO (reference), but is it still possible to get the same behavior with `from_concrete_functions`?
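One way to confirm which primitive the `enable_xla=False` path rejects is to print the jaxpr. In this reduced, hypothetical version of the final `image.at[x, y, z].set(snr)` step, the printed program contains a `scatter` equation, which matches the `scatter_(update/add/...)` family named in the error.

```python
import jax
import jax.numpy as jnp

# Hypothetical reduced version of the last step of `preprocess`.
def write_image(x, y, z, snr):
    image = jnp.zeros((4, 4, 4), dtype=jnp.float32)
    return image.at[x, y, z].set(snr)

idx = jnp.array([0, 1, 3], dtype=jnp.int32)
vals = jnp.array([1.0, 2.0, 3.0], dtype=jnp.float32)

jaxpr = jax.make_jaxpr(write_image)(idx, idx, idx, vals)
print(jaxpr)  # look for a `scatter` equation in the printed program
```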

```python
import numpy as np
import jax.numpy as jnp
from jax.experimental import jax2tf
import numpy.typing as npt
import tensorflow as tf

grid_min = np.array([-1.5, 0, -0.2], dtype=np.float32)
grid_max = np.array([1.5, 3, 2], dtype=np.float32)
grid_res = np.array([0.15, 0.15, 0.1], dtype=np.float32)
image_shape = ((grid_max + grid_res - grid_min) / grid_res).astype(np.int32)

# start, end, step
x_range = np.array([-0.995, 0.995, 0.058500], dtype=np.float32)
y_range = np.array([-0.995, 0.995, 0.058500], dtype=np.float32)
z_range = np.array([0, 6.5169, 0.2327], dtype=np.float32)

def preprocess(
    dx_idx,
    dy_idx,
    r_idx,
    snr,
    sensor_loc,
    rotation_mat,
    arena_min,
    arena_max,
) -> npt.NDArray:
    # convert to spherical coordinates
    dx = x_range[0] + (dx_idx - 1) * x_range[2]
    dy = y_range[0] + (dy_idx - 1) * y_range[2]
    r = z_range[0] + (r_idx - 1) * z_range[2]

    # convert to cartesian coordinates
    peaks = jnp.stack([dx, dy, jnp.sqrt(1 - dx * dx - dy * dy)]) * r
    peaks = peaks.T

    # rotate and translate the peaks
    peaks = peaks @ rotation_mat + sensor_loc

    # mask out peaks that are outside of the arena
    arena_mask = ((arena_min <= peaks) & (peaks <= arena_max)).all(axis=1)

    # convert to image coordinates
    peaks = jnp.floor((peaks - grid_min) / grid_res).astype(np.int32)

    # mask out peaks that are outside of the image
    image_mask = ((0 <= peaks) & (peaks < image_shape)).all(axis=1)

    # clip to zero peaks that are outside of the arena and image
    mask = arena_mask & image_mask
    x = jnp.where(mask, peaks[:, 0], 0)
    y = jnp.where(mask, peaks[:, 1], 0)
    z = jnp.where(mask, peaks[:, 2], 0)
    snr = jnp.where(mask, snr, 0)

    # create the image (this scatter is what enable_xla=False cannot convert)
    image = jnp.zeros(image_shape, dtype=np.float32)
    # JAX arrays are immutable: .at[].set() returns a new array, so reassign it
    image = image.at[x, y, z].set(snr)

    return image

dx_idx = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32)
dy_idx = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32)
r_idx = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32)
snr = np.array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10], dtype=np.float32)
sensor_loc = np.array([0, 0, 1.5], dtype=np.float32)
rotation_mat = np.array([[-0.7071, 0, -0.7071], [-0.7071, 0, 0.7071], [0, 1, 0]], dtype=np.float32)
arena_min = np.array([-1, 0.3, 0], dtype=np.float32)
arena_max = np.array([1, 2, 1.8], dtype=np.float32)

# Path 1: direct JAX -> TFLite conversion
converter = tf.lite.TFLiteConverter.experimental_from_jax(
    [preprocess],
    [
        [
            ("dx_idx", dx_idx),
            ("dy_idx", dy_idx),
            ("r_idx", r_idx),
            ("snr", snr),
            ("sensor_loc", sensor_loc),
            ("rotation_mat", rotation_mat),
            ("arena_min", arena_min),
            ("arena_max", arena_max),
        ]
    ],
)
process_tflite = converter.convert()  # works!

# Path 2: JAX -> TF concrete function (no XLA ops) -> TFLite
preprocess_converted = jax2tf.convert(preprocess, enable_xla=False)
preprocess_tf = tf.function(preprocess_converted, autograph=False, jit_compile=True)
preprocess_tf_concrete = preprocess_tf.get_concrete_function(  # fails!
    tf.TensorSpec(shape=dx_idx.shape, name="dx_idx"),
    tf.TensorSpec(shape=dy_idx.shape, name="dy_idx"),
    tf.TensorSpec(shape=r_idx.shape, name="r_idx"),
    tf.TensorSpec(shape=snr.shape, name="snr"),
    tf.TensorSpec(shape=sensor_loc.shape, name="sensor_loc"),
    tf.TensorSpec(shape=rotation_mat.shape, name="rotation_mat"),
    tf.TensorSpec(shape=arena_min.shape, name="arena_min"),
    tf.TensorSpec(shape=arena_max.shape, name="arena_max"),
)
converter = tf.lite.TFLiteConverter.from_concrete_functions([preprocess_tf_concrete], preprocess_tf)
process_tflite = converter.convert()
```

Note: if I enable XLA (`enable_xla=True`), the conversion seemingly works, but the final model I get from `from_concrete_functions` has an empty graph.
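Independently of the `enable_xla` setting, JAX arrays are immutable: `.at[...].set(...)` returns an updated copy and leaves the original untouched. If that result is not assigned back, the scatter is dead code, which would plausibly leave a converted graph whose only output is a constant array of zeros. A minimal illustration:

```python
import jax.numpy as jnp

image = jnp.zeros((2, 2), dtype=jnp.float32)

# Functional update: returns a new array; `image` itself is unchanged.
image.at[0, 1].set(7.0)          # result discarded, so this line has no effect
unchanged = float(image[0, 1])   # still 0.0

# Reassign to keep the update.
image = image.at[0, 1].set(7.0)
updated = float(image[0, 1])     # now 7.0
```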

Thank you.

What jax/jaxlib version are you using?

jax==0.4.8 jaxlib==0.4.7

Which accelerator(s) are you using?

CPU

Additional system info

WSL2 over Windows 11

NVIDIA GPU info

No response

hawkinsp commented 1 year ago

@gnecula Is this bug still current?

galah92 commented 1 year ago

Yes, it's still current, with no apparent way of working around it. I'd be happy to contribute a fix if you can guide me through it.
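One possible workaround, sketched here as an assumption rather than an official jax2tf recommendation, is to write the image without a scatter: flatten the 3-D index, build a one-hot matrix with `jax.nn.one_hot` (which is built from comparisons), and contract it with the values using a plain matrix product. `scatter_add_dense` is a hypothetical helper. It has scatter-add semantics, so it equals `jnp.zeros(shape).at[x, y, z].set(values)` only when the indices are unique; entries masked out in `preprocess` carry a value of 0 and therefore add nothing. Memory grows as `len(values) * prod(shape)`, and it has not been checked against every jax2tf version.

```python
import jax
import jax.numpy as jnp

def scatter_add_dense(shape, x, y, z, values):
    """Hypothetical scatter-free stand-in for jnp.zeros(shape).at[x, y, z].add(values).

    `shape` must be a tuple of Python ints so the one-hot width is static.
    """
    nx, ny, nz = shape
    # Row-major flat index into the (nx, ny, nz) grid.
    flat_idx = (x * ny + y) * nz + z
    # (num_values, nx * ny * nz) matrix of 0/1 entries.
    one_hot = jax.nn.one_hot(flat_idx, nx * ny * nz, dtype=values.dtype)
    # Contracting over the values axis sums duplicates (scatter-add semantics).
    flat_image = values @ one_hot
    return flat_image.reshape(shape)
```

Inside `preprocess`, the last step would become `image = scatter_add_dense(tuple(int(s) for s in image_shape), x, y, z, snr)`.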

gnecula commented 1 year ago

@ferev, can you check whether there is a fix for this?

galah92 commented 11 months ago

@gnecula @Ferev Hi, I got back to this today and found that `experimental_from_jax` is being deprecated and `jax2tf.convert` is now the way to go, which means I don't have much time before I'll have to migrate. Can I help with this issue in some way?

gnecula commented 11 months ago

@ferev What is the status of TFLite support for StableHLO? I would like to deprecate the enable_xla=False path altogether.

zehuiw commented 11 months ago

Hi @gnecula, TFLite has migrated to use native serialization as the default. We can simply save the model as a TF SavedModel via `tf.saved_model.save`, containing StableHLO ops as the intermediate format between JAX and TFLite. The TFLite converter API will then pick that up: `converter = tf.lite.TFLiteConverter.from_saved_model('/some/directory')`, followed by `tflite_model = converter.convert()`. We will update our public documentation soon.