jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[jax2tf] Error when converting function wrapped in jax.grad to SavedModel format #10725

Status: Closed (marcvanzee closed this issue 2 years ago)

marcvanzee commented 2 years ago

When trying to convert models in #9659, an issue was uncovered with converting to SavedModel format. Copying @oliverdutton's remark from there:

"The user needs to be aware that their input signature has changed to something flat, and ideally they need to name it in a way that makes sense. They just define the model and a way to pass args to it, which can be tricky when you're using it on a number of platforms/languages.

There could be a utility function left around though, which people are advised to use which wraps f and inputs."

And from @josephrocca:

"would the most simple example code to add to the jax2tf readme be something like this?"

def your_fn(nested_data):
  # function that you want to convert to TensorFlow with jax2tf
  ...

nested_data = ...  # add an input data sample here
flat_data = jax.tree_map(tf.constant, jax.tree_flatten(nested_data)[0])
flat_names = list(flatten(nested_data))  # `flatten` maps nested keys to flat names
_, nested_data_structure = jax.tree_flatten(nested_data)

def your_fn_flat(flat_data):
  nested_data = jax.tree_unflatten(nested_data_structure, flat_data)
  return your_fn(nested_data)

your_fn_flat_tf = jax2tf.convert(your_fn_flat, enable_xla=False)

my_model = tf.Module()
my_model.f = tf.function(your_fn_flat_tf, autograph=False, jit_compile=True, input_signature=[
  jax.tree_map(lambda x, name: tf.TensorSpec(x.shape, x.dtype, name=name), flat_data, flat_names),
])

model_name = 'your_fn_flat'
tf.saved_model.save(my_model, model_name, options=tf.saved_model.SaveOptions(experimental_custom_gradients=False), signatures=my_model.f.get_concrete_function())

I think one problem with that is in cases like this: list(flatten({'foo':{'bar':1}, 'foo_bar':2})), which returns ['foo_bar']. It's also just kind of complicated, and I'm concerned that it would be hard for many people to use even if that were put in the jax2tf readme, especially if they have multiple inputs, and perhaps also need to use _variables. But maybe most users of jax2tf won't be as much of a noob as me? I'm not sure.
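(The `flatten` helper in the snippet above is not shown; the following is a hypothetical dict-based version that reproduces the name collision described, plus a variant that avoids it by using a separator that cannot appear in key names.)

```python
def flatten(tree, prefix=""):
    """Hypothetical helper: flattens a nested dict into {joined_name: leaf},
    joining nested keys with "_". Colliding names silently overwrite."""
    out = {}
    for key, value in tree.items():
        name = f"{prefix}_{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten(value, name))
        else:
            out[name] = value
    return out

def flatten_unique(tree, prefix=""):
    """Same idea, but joins with "/", which cannot appear in the keys here,
    so distinct leaves always get distinct flat names."""
    out = {}
    for key, value in tree.items():
        name = f"{prefix}/{key}" if prefix else str(key)
        if isinstance(value, dict):
            out.update(flatten_unique(value, name))
        else:
            out[name] = value
    return out

print(list(flatten({"foo": {"bar": 1}, "foo_bar": 2})))         # ['foo_bar'] -- two leaves collide
print(list(flatten_unique({"foo": {"bar": 1}, "foo_bar": 2})))  # ['foo/bar', 'foo_bar']
```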

I'm also wondering if the signatures=my_model.f.get_concrete_function() bit is necessary in tf.saved_model.save, since we've already defined input_signature during the tf.function creation? Strangely, if I add it to the CLIP notebook that I linked above, then I get this error during tf.saved_model.save:

ValueError: Got a non-Tensor value <tf.Operation 'PartitionedCall' type=PartitionedCall> for key 'output_0' in the output of the function __inference_score_tf_grad_95495 used to generate the SavedModel signature 'serving_default'. Outputs for functions used as signatures must be a single Tensor, a sequence of Tensors, or a dictionary from string to Tensor.
marcvanzee commented 2 years ago

@josephrocca I have tried isolating your issue from #9659, but could you help me by providing a somewhat minimal example of when you encounter the issue and what error message you got exactly?

josephrocca commented 2 years ago

I think this might be a bug/issue with tf.lite.TFLiteConverter.from_saved_model and tf.lite.TFLiteConverter.from_concrete_function rather than with jax2tf, but I'm not sure. The error occurs when trying to convert a SavedModel (that was created via jax2tf + tf.saved_model.save(...)) to tflite.

Originally it seemed like it had something to do with nesting of the inputs to the function that was saved, but here's a minimal replication of the error without any nesting of the function's input parameters:

https://colab.research.google.com/drive/1Te2uHCFOdZgRT-N4VsAeIGjASK80E4im

The exact error shown in the above notebook is "ValueError: Only support at least one signature key."

A potentially related problem is that if we use jax.grad(score) in the above example instead of GradientTape¹, then there are no errors, but the tflite file is only 19kb:

https://colab.research.google.com/drive/1H3PwVAgPxjv6ldJB9k4gc4pakCW8Scoy

The only difference between the former and the latter Colab notebooks is that I've used jax.grad in the latter instead of TF's GradientTape to get the gradient of the score function. So I'd really have expected the same "no input node" error in the latter too - not sure what's going on there, but it might be a clue.

It initially seemed like this could be worked around by using your suggestion of tf.lite.TFLiteConverter.from_concrete_function instead of tf.lite.TFLiteConverter.from_saved_model, but from_concrete_function produces a ~500kb .tflite file (see your notebook - note that it's converting LiT, whereas the previous two Colabs that I linked are converting CLIP - I'm guessing from_concrete_function, like from_saved_model, would produce a ~19kb file if used on CLIP).

Also, note that Oliver did manage to get CLIP gradients working in tflite with tf.lite.TFLiteConverter.experimental_from_jax:

https://colab.research.google.com/drive/1elQpDL1ysaOMncMnnuHEalnh4HAAeC53

A relevant comment from @oliverdutton:

  • jax2tf involves mapping the jaxpr of the program to TensorFlow ops, saving a SavedModel, then loading and passing it to the tflite converter. When enable_xla=False is used (which is not required for tflite), this involves mapping to high-level TensorFlow ops.

  • tflite.experimental_from_jax uses the HLO itself, and doesn't save any intermediate files.

While both are perfectly valid, the second seems to have fewer potential pain points.

So I'm not sure if this is a problem with jax2tf or TensorFlow, but if this issue is kept open, the title should probably be changed to something vaguely like "Small tflite file (only input and output nodes) or no input node when converting function wrapped in jax.grad to tflite via jax2tf".

¹ I previously had to use GradientTape because of the lack of scatter add support, which Oliver's PR (#10653) seems to have fixed.

marcvanzee commented 2 years ago

Hi @josephrocca, I believe the small file size is due to a bug in the jax2tf conversion of jax.grad with enable_xla=False, which causes all outputs to be zero, so TFLite optimizes this by removing the params (which are then simply unused constants); see #10819.

I would argue that when converting to TFLite you should really not have to create a SavedModel at all. If you define your apply function by closing over your model's parameters, they will be stored as constants in the TF graph, and TFLite will treat them as constants as well, which seems fine for your use case (it seems you don't need to do any training).
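As a minimal sketch of the closing-over approach: the model, parameter values, and function names below are hypothetical; in the real flow the closed-over `predict` would be passed to jax2tf.convert and then to the TFLite converter, with no SavedModel in between.

```python
import functools

# Hypothetical model: apply_fn has the usual JAX signature (params, inputs).
def apply_fn(params, x):
    return params["w"] * x + params["b"]

# Hypothetical trained parameters (in practice, JAX arrays).
trained_params = {"w": 2.0, "b": 1.0}

# Close over the parameters so the converted function takes only the inputs.
# jax2tf then bakes the params into the TF graph as constants, and TFLite
# treats them as constants too -- no SavedModel variables needed for inference.
predict = functools.partial(apply_fn, trained_params)

# In the real flow this would continue along the lines of:
#   predict_tf = jax2tf.convert(predict, enable_xla=False)
#   ... hand predict_tf to the TFLite converter ...
print(predict(3.0))  # 7.0
```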

@oliverdutton do you think there is anything else problematic with converting JAX to SavedModel? I was under the impression that you believed we need to do some tree unflattening, and I am trying to understand when this is the case, or whether this was a red herring.

Thanks in advance!

marcvanzee commented 2 years ago

Closing this since it isn't clear there is a problem with SavedModel, and as I said before, the recommended way of converting a JAX function to TFLite is either tf.lite.TFLiteConverter.experimental_from_jax or tf.lite.TFLiteConverter.from_concrete_function.

@josephrocca please re-open if you are still encountering problems here, or let me know if you think I misunderstood something!