jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

[jax2tf] Quantized model is blowing up the memory where JAX - Flax model was working perfectly #11291

Open deshwalmahesh opened 2 years ago

deshwalmahesh commented 2 years ago

I'm using a model for automatic image enhancement, google-research/maxim, and it works perfectly. I then worked on quantizing the model, found the official answer on how to convert a JAX model to TFLite, and it worked for MAXIM as well. Below is the code I used to quantize it; I have also answered my own question on Stack Overflow.

Code to quantize:

import tensorflow as tf
import flax
from jax.experimental import jax2tf

# `model` and `params` are the pretrained MAXIM Flax module and its weights,
# loaded as in the official MAXIM inference code.
def predict(input_img):
  '''
  Function to predict the output from the JAX model
  '''
  return model.apply({'params': flax.core.freeze(params)}, input_img)

tf_predict = tf.function(
    jax2tf.convert(predict, enable_xla=False),
    input_signature=[
        tf.TensorSpec(shape=[1, 704, 1024, 3], dtype=tf.float32, name='input_image')
    ],
    autograph=False)

converter = tf.lite.TFLiteConverter.from_concrete_functions(
    [tf_predict.get_concrete_function()], tf_predict)

converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # enable TensorFlow Lite ops.
    tf.lite.OpsSet.SELECT_TF_OPS  # enable TensorFlow ops.
]
tflite_float_model = converter.convert()

with open('float_model.tflite', 'wb') as f:
    f.write(tflite_float_model)

converter.optimizations = [tf.lite.Optimize.DEFAULT]
tflite_quantized_model = converter.convert()

with open('./quantized.tflite', 'wb') as f:
    f.write(tflite_quantized_model)
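For reference, converter.optimizations = [tf.lite.Optimize.DEFAULT] applies post-training (dynamic-range) quantization, so the quantized file should be noticeably smaller than the float one. A quick sanity check (a sketch, assuming the two files above were written) is:

import os

# Compare on-disk sizes of the float and quantized models written above.
print('float model:     %.1f MB' % (os.path.getsize('float_model.tflite') / 1e6))
print('quantized model: %.1f MB' % (os.path.getsize('./quantized.tflite') / 1e6))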

Problem: Everything seems fine; the quantized model loads and reports its input and output shapes:

tflite_interpreter_quant = tf.lite.Interpreter(model_path='./maxim/quantized.tflite')
input_details = tflite_interpreter_quant.get_input_details()
output_details = tflite_interpreter_quant.get_output_details()

print("== Input details ==")
print("name:", input_details[0]['name'])
print("shape:", input_details[0]['shape'])
print("type:", input_details[0]['dtype'])

print("\n== Output details ==")
print("name:", output_details[0]['name'])
print("shape:", output_details[0]['shape'])
print("type:", output_details[0]['dtype'])

but when I do:

tflite_interpreter_quant.allocate_tensors()

the memory runs out. This is very strange, because my original JAX model runs just fine, but as soon as I try to allocate tensors for the QUANTIZED version, memory blows up. Any idea why this might be happening?
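For context, once allocate_tensors() succeeds, the inference path I expect to run is roughly this (a sketch; `img` is just a placeholder float32 input of the declared shape, and the interpreter and detail dicts come from the snippet above):

import numpy as np

# Placeholder input matching the declared TensorSpec [1, 704, 1024, 3], float32.
img = np.random.rand(1, 704, 1024, 3).astype(np.float32)

tflite_interpreter_quant.set_tensor(input_details[0]['index'], img)
tflite_interpreter_quant.invoke()
enhanced = tflite_interpreter_quant.get_tensor(output_details[0]['index'])
print(enhanced.shape)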

marcvanzee commented 1 year ago

@Ferev this issue seems to be on the TFLite side, perhaps you could reroute it to someone from the TFLite team?

jaeyoo commented 1 year ago

Hi @deshwalmahesh, could you please share your pretrained MAXIM model? The Colab you pointed to (https://github.com/google-research/maxim/blob/main/colab_inference_demo.ipynb) seems broken.

jaeyoo commented 1 year ago

Never mind, I got the model. I could reproduce the OOM in a 12 GB RAM Colab runtime. Let me work on it more.

jaeyoo commented 1 year ago

I've tested with the new TF version 2.11.0 and the issue is gone. Could you please test it with the new version of TF in Colab?
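For example, something like the following in Colab should pick up the newer runtime and let you retry the failing step (a sketch, not a verified command sequence):

# In a Colab cell:
# !pip install -U "tensorflow==2.11.0"
# then restart the runtime before running the code below.

import tensorflow as tf
print(tf.__version__)  # expect 2.11.0

tflite_interpreter_quant = tf.lite.Interpreter(model_path='./maxim/quantized.tflite')
tflite_interpreter_quant.allocate_tensors()  # previously this was where memory blew up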