onnx / tensorflow-onnx

Convert TensorFlow, Keras, Tensorflow.js and Tflite models to ONNX
Apache License 2.0

Getting "exceeds maximum protobuf size of 2GB" error even though I'm using `large_model=True` #2062

Open josephrocca opened 2 years ago

josephrocca commented 2 years ago

Describe the bug
I'm getting the "exceeds maximum protobuf size of 2GB" error even though I'm using large_model=True:

ValueError: Message tensorflow.GraphDef exceeds maximum protobuf size of 2GB: 2155786874

(Colab reproduction link is below)

Urgency
No hard deadlines. For context, I'm trying to get Stable Diffusion working in the browser, and this problem is a blocker. It's mainly "for fun" (like this, which also uses ORT Web), but I think it'd be a good 'advertisement' for ONNX Runtime, even if it takes 5 mins per image until WebGPU comes along 😁

System information

To Reproduce
Minimal reproduction: https://colab.research.google.com/gist/josephrocca/efcdcf8b69b5705f9b44b057aeeb90c7/stable_diffusion_jax-to-onnx.ipynb
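For reference, the failing call is roughly the following (a sketch reconstructed from the traceback further down; unet_tf is a stand-in name for the jax2tf-converted UNet function, and the input signature may be incomplete):

import tensorflow as tf
import tf2onnx

# unet_tf: a tf.function wrapping the jax2tf-converted UNet (hypothetical name)
model_proto, external_tensor_storage = tf2onnx.convert.from_function(
    unet_tf,
    input_signature=[
        tf.TensorSpec([1, 4, 64, 64], tf.float32, name="latents"),
        tf.TensorSpec([2, 77, 768], tf.float32, name="context"),
    ],
    large_model=True,  # should externalize big constants, but the 2GB error is raised anyway
    opset=16,
    output_path="unet_test.onnx",
)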

Additional Context
Related pull requests/issues:


(As context, for anyone curious: the reason I'm converting via JAX instead of just using the already-existing ONNX ports of this model is that the JAX conversion process can bundle all the "admin"/scheduling code that sits around the model, so I don't have to reimplement it all in JavaScript to get it working with ORT Web in the browser.)

josephrocca commented 2 years ago

Weirdly, I solved this by switching from a bfloat16 model (which I now realise is unsupported by tf2onnx) to a float32 version of that same model.

I guess there's some sort of "silent fallback" to large_model=False that's triggered by the unsupported bfloat16.
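If anyone else runs into this, an equivalent in-place workaround might be to upcast the weights before conversion. A rough sketch, assuming the Flax params are a pytree of JAX arrays:

import jax
import jax.numpy as jnp

# upcast any bfloat16 leaves to float32 so the converter never sees bfloat16
params = jax.tree_util.tree_map(
    lambda x: x.astype(jnp.float32) if x.dtype == jnp.bfloat16 else x,
    params,
)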

Given that the error message doesn't seem to match what the problem actually is, I'll leave this open for a tf2onnx dev to look at and close when appropriate.

fatcat-z commented 2 years ago

This error was actually thrown by tensorflow.import_graph_def(). Before we call it, during the preparation of the frozen_graph, we don't see any difference between bfloat16 and float32. It looks like a TensorFlow issue.

In addition, tf2onnx has provided bfloat16 support since this PR, by mapping it to the FLOAT16 type. Could you please share more about why you said bfloat16 is not supported?
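One way to sanity-check that mapping on a converted model (assuming it's small enough to load as a single protobuf, i.e. not a large_model zip) is to inspect the initializer dtypes:

import onnx
from collections import Counter

model = onnx.load("model.onnx")  # hypothetical path
# bfloat16 weights should show up as FLOAT16 initializers after conversion
print(Counter(onnx.TensorProto.DataType.Name(t.data_type)
              for t in model.graph.initializer))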

josephrocca commented 2 years ago

@fatcat-z The reason I thought bfloat16 wasn't supported is explained here: https://github.com/onnx/tensorflow-onnx/issues/2064
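(The symptom underlying that issue is that plain NumPy has no native bfloat16 dtype at all:)

import numpy as np

np.dtype("bfloat16")  # raises TypeError: data type 'bfloat16' not understood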

In the example notebook linked in my post above, if you change this line:

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None)

to this:

pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.float32, safety_checker=None)

Then it downloads f32 weights instead of bf16. Given that this fixed the "2GB limit" issue, I assumed it had something to do with the "Numpy data type not understood yet" issue I linked above.

josephrocca commented 2 years ago

(Whoops, revision="flax", not revision="jax". I've fixed the second code block in the above comment.)

fatcat-z commented 2 years ago

Do we still have questions left for this issue?

josephrocca commented 2 years ago

I guess one possible loose end here is that the issue with tensorflow.import_graph_def() could be making the large_model=True option useless for some subset of models? I'm guessing that the reason f32 worked in this case is that its weights were successfully externalized, keeping the serialized GraphDef under the 2GB limit imposed by tensorflow.import_graph_def(), whereas the unsupported bf16 weights stayed inline (per my "silent fallback" guess above) and pushed it over?
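As a quick diagnostic, you can check how close a frozen graph is to the cap before handing it to TensorFlow (frozen_graph here stands for the tensorflow.GraphDef in question):

limit = 2**31 - 1  # protobuf's hard cap on a single serialized message (~2 GiB)
size = frozen_graph.ByteSize()
print(f"{size} bytes, {size / limit:.0%} of the import_graph_def limit")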

The offending code is this line within the _import_graph_def_internal function in tensorflow/python/framework/importer.py. Here's the full traceback for quick reference:

ValueError                                Traceback (most recent call last)
<ipython-input-6-7131e4dd730a> in <module>
     29     tf.TensorSpec([1, 4, 64, 64], tf.float32, name="latents"),
     30     tf.TensorSpec([2, 77, 768], tf.float32, name="context"),
---> 31 ], large_model=True, opset=16, output_path="unet_test.onnx")

/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py in from_function(function, input_signature, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
    577             tensors_to_rename=tensors_to_rename,
    578             initialized_tables=initialized_tables,
--> 579             output_path=output_path)
    580 
    581         return model_proto, external_tensor_storage

/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py in _convert_common(frozen_graph, name, large_model, output_path, output_frozen_graph, custom_ops, custom_op_handlers, optimizers, **kwargs)
    161             utils.save_protobuf(output_frozen_graph, frozen_graph)
    162         if not kwargs.get("tflite_path") and not kwargs.get("tfjs_path"):
--> 163             tf.import_graph_def(frozen_graph, name='')
    164         g = process_tf_graph(tf_graph, const_node_values=const_node_values,
    165                              custom_op_handlers=custom_op_handlers, **kwargs)

/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
    559                 'in a future version' if date is None else ('after %s' % date),
    560                 instructions)
--> 561       return func(*args, **kwargs)
    562 
    563     doc = _add_deprecated_arg_notice_to_docstring(

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/importer.py in import_graph_def(***failed resolving arguments***)
    406       return_elements=return_elements,
    407       name=name,
--> 408       producer_op_list=producer_op_list)
    409 
    410 

/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
    495   # _ProcessNewOps.
    496   with graph._mutation_lock():  # pylint: disable=protected-access
--> 497     with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
    498       try:
    499         results = c_api.TF_GraphImportGraphDefWithResults(

ValueError: Message tensorflow.GraphDef exceeds maximum protobuf size of 2GB: 2155786874

Please feel free to close this if it's not likely to be a general issue for some reason, or is tracked elsewhere.