josephrocca opened 2 years ago
Weirdly, I solved this by switching from a bfloat16 model (which I now realise is unsupported by tf2onnx) to a float32 version of that same model. I guess there's some sort of "silent fallback" to large_model=False that's triggered by the unsupported bfloat16.
Given that the error message doesn't seem to match what the problem actually is, I'll leave this open for a tf2onnx dev to look at and close when appropriate.
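For reference, here's a minimal, self-contained sketch of the large_model=True code path I'm talking about. A toy tf.function stands in for the real UNet, and toy_fn / toy_large.onnx are placeholder names; my understanding (from the traceback further down) is that from_function returns (model_proto, external_tensor_storage) and, with large_model=True, is supposed to move the weights into external tensor storage instead of embedding them in the graph.

import numpy as np
import tensorflow as tf
import tf2onnx

@tf.function
def toy_fn(x):
    # stand-in for the real UNet: just a matmul against a baked-in constant
    w = tf.constant(np.ones((64, 64), dtype=np.float32))
    return tf.matmul(x, w)

# from_function returns (model_proto, external_tensor_storage); with large_model=True
# the constants are meant to end up in external_tensor_storage, and output_path
# (a placeholder name here) should be written as a zip rather than a plain .onnx,
# as far as I understand the API.
model_proto, external_tensor_storage = tf2onnx.convert.from_function(
    toy_fn,
    input_signature=[tf.TensorSpec([1, 64], tf.float32, name="x")],
    opset=16,
    large_model=True,
    output_path="toy_large.onnx",
)
print(type(model_proto), external_tensor_storage is not None)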
This error was actually thrown by tensorflow.import_graph_def(). Before we call it, during the preparation of the frozen_graph, we don't see any difference between bfloat16 and float32. It looks like a tensorflow issue.
In addition, TF2ONNX has provided bfloat16 support since this PR, by mapping it to the FLOAT16 type. Could you please share more about why you mentioned bfloat16 is not supported?
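To make "thrown by tensorflow.import_graph_def()" concrete, here is a tiny self-contained sketch (not tf2onnx code) that measures a GraphDef's serialized size against protobuf's ~2GB cap before importing it; a graph over that cap is what produces the ValueError shown later in this thread.

import tensorflow as tf

@tf.function
def add_one(x):
    return x + 1.0

# Grab a GraphDef for a trivial function, just to have something to measure.
graph_def = add_one.get_concrete_function(tf.TensorSpec([1], tf.float32)).graph.as_graph_def()

PROTOBUF_CAP = 2**31 - 1  # protobuf's maximum serialized message size, the "2GB" in the error
size = graph_def.ByteSize()
print(size, "bytes serialized;", "OVER" if size > PROTOBUF_CAP else "under", "the cap")
if size <= PROTOBUF_CAP:
    tf.import_graph_def(graph_def, name="")  # the call that raises once the cap is exceeded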
@fatcat-z The reason I thought bfloat16 wasn't supported is explained here: https://github.com/onnx/tensorflow-onnx/issues/2064
In the example notebook linked in my post above, if you change this line:
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="bf16", dtype=jnp.bfloat16, safety_checker=None)
to this:
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained("CompVis/stable-diffusion-v1-4", revision="flax", dtype=jnp.float32, safety_checker=None)
Then that downloads f32 weights instead of b16. Given that this fixed the "2GB limit" issue, I assumed that it was something to do with that "Numpy data type not understood yet" issue that I linked above.
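For what it's worth, a quick way to sanity-check which dtype actually got downloaded is to peek at the leaf dtypes of the params pytree. This is just a sketch reusing the from_pretrained call above, so it pulls the full multi-GB weights; jax.tree_util is only used for inspection here.

import jax
import jax.numpy as jnp
from diffusers import FlaxStableDiffusionPipeline

# same call as above (downloads the full weights)
pipeline, params = FlaxStableDiffusionPipeline.from_pretrained(
    "CompVis/stable-diffusion-v1-4", revision="flax",
    dtype=jnp.float32, safety_checker=None)

# set of dtypes across all weight arrays -- expect {'float32'} for the "flax" revision
print({str(leaf.dtype) for leaf in jax.tree_util.tree_leaves(params)})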
(Woops, revision="flax"
, not revision="jax"
. I've fixed the second code block in the above comment.)
Do we still have questions left for this issue?
I guess one possible loose end here is that the issue with tensorflow.import_graph_def() could be making the large_model=True option useless for some subset of models? I'm guessing that the reason b16 worked in this case was because it halved the model size, thereby putting it under the 2GB limit imposed by tensorflow.import_graph_def()?
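Rough arithmetic behind the "halved the model size" guess (the parameter count here is a made-up round number, roughly UNet-sized, not a measured count for this model):

n_params = 860_000_000  # hypothetical round number
print("float32:          %.2f GiB" % (n_params * 4 / 2**30))
print("bfloat16/float16: %.2f GiB" % (n_params * 2 / 2**30))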
The offending code is this line within the _import_graph_def_internal function in tensorflow/python/framework/importer.py. Here are the full logs for quick reference:
ValueError Traceback (most recent call last)
<ipython-input-6-7131e4dd730a> in <module>
29 tf.TensorSpec([1, 4, 64, 64], tf.float32, name="latents"),
30 tf.TensorSpec([2, 77, 768], tf.float32, name="context"),
---> 31 ], large_model=True, opset=16, output_path="unet_test.onnx")
/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py in from_function(function, input_signature, opset, custom_ops, custom_op_handlers, custom_rewriter, inputs_as_nchw, outputs_as_nchw, extra_opset, shape_override, target, large_model, output_path)
577 tensors_to_rename=tensors_to_rename,
578 initialized_tables=initialized_tables,
--> 579 output_path=output_path)
580
581 return model_proto, external_tensor_storage
/usr/local/lib/python3.7/dist-packages/tf2onnx/convert.py in _convert_common(frozen_graph, name, large_model, output_path, output_frozen_graph, custom_ops, custom_op_handlers, optimizers, **kwargs)
161 utils.save_protobuf(output_frozen_graph, frozen_graph)
162 if not kwargs.get("tflite_path") and not kwargs.get("tfjs_path"):
--> 163 tf.import_graph_def(frozen_graph, name='')
164 g = process_tf_graph(tf_graph, const_node_values=const_node_values,
165 custom_op_handlers=custom_op_handlers, **kwargs)
/usr/local/lib/python3.7/dist-packages/tensorflow/python/util/deprecation.py in new_func(*args, **kwargs)
559 'in a future version' if date is None else ('after %s' % date),
560 instructions)
--> 561 return func(*args, **kwargs)
562
563 doc = _add_deprecated_arg_notice_to_docstring(
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/importer.py in import_graph_def(***failed resolving arguments***)
406 return_elements=return_elements,
407 name=name,
--> 408 producer_op_list=producer_op_list)
409
410
/usr/local/lib/python3.7/dist-packages/tensorflow/python/framework/importer.py in _import_graph_def_internal(graph_def, input_map, return_elements, validate_colocation_constraints, name, producer_op_list)
495 # _ProcessNewOps.
496 with graph._mutation_lock(): # pylint: disable=protected-access
--> 497 with c_api_util.tf_buffer(graph_def.SerializeToString()) as serialized:
498 try:
499 results = c_api.TF_GraphImportGraphDefWithResults(
ValueError: Message tensorflow.GraphDef exceeds maximum protobuf size of 2GB: 2155786874
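(Side note: relating the number in that ValueError to protobuf's hard cap of 2**31 - 1 bytes, the graph is only about 8 MB over the limit:)

cap = 2**31 - 1            # protobuf's maximum serialized message size
reported = 2_155_786_874   # the size from the error message above
print(reported - cap, "bytes over the cap")  # 8,303,227 bytes, i.e. roughly 8 MB over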
Please feel free to close this if it's not likely to be a general issue for some reason, or is tracked elsewhere.
Describe the bug I'm getting an "exceeds maximum protobuf size of 2GB" error even though I'm using large_model=True :( (Colab reproduction link is below)
Urgency No hard deadlines. For context, I'm trying to get Stable Diffusion working in the browser, and this problem is a blocker. It's mainly "for fun" (like this, which also uses ORT Web), but I think it'd be a good 'advertisement' for ONNX Runtime, even if it takes 5 mins per image until WebGPU comes along 😁
System information
To Reproduce Minimal reproduction: https://colab.research.google.com/gist/josephrocca/efcdcf8b69b5705f9b44b057aeeb90c7/stable_diffusion_jax-to-onnx.ipynb
Additional Context Related pull requests/issues:
(As context, for anyone curious: The reason I'm converting via JAX instead of just using the already-existing ONNX ports of this model is that the JAX code conversion process can bundle all the "admin"/scheduling code that sits around the model, so I don't have to implement it all in JavaScript to get it working with ORT Web in the browser.)