microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License

tiiuae/falcon-rw-1b: Dynamo export error during optimization step #1543

Open · asfiyab-nvidia opened this issue 6 months ago

asfiyab-nvidia commented 6 months ago

Hi, I'm attempting to export the tiiuae/falcon-rw-1b model using Dynamo. I'm using the script below, but I run into an error related to the bfloat16 type. Is this a known issue?

import torch
from transformers import AutoConfig, AutoModelForCausalLM
from optimum.exporters.onnx.model_configs import FalconOnnxConfig

device = 'cuda'
model_name = "tiiuae/falcon-rw-1b" 

# load model
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float32).eval().to(device)

# get data
model_config = AutoConfig.from_pretrained(model_name, torch_dtype=torch.float32)
onnx_config = FalconOnnxConfig(model_config)
data = onnx_config.generate_dummy_inputs()
for k, v in data.items():
    data[k] = v.to(device)

export_output = torch.onnx.dynamo_export(
    model,
    **data
)
export_output.save('falcon-rw-1b_dynamo.onnx')

The ONNX export itself seems to succeed. The failure occurs during the optimization step that dynamo_export runs afterwards. Below is the error stack:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/onnx/numpy_helper.py", line 388, in from_array
    dtype = helper.np_dtype_to_tensor_dtype(arr.dtype)
  File "/usr/local/lib/python3.10/dist-packages/onnx/helper.py", line 1599, in np_dtype_to_tensor_dtype
    mapping._NP_TYPE_TO_TENSOR_TYPE[np_dtype],
KeyError: dtype((numpy.uint16, [('bfloat16', '<u2')]))

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1503, in dynamo_export
    ).export()
  File "/usr/local/lib/python3.10/dist-packages/torch/onnx/_internal/exporter.py", line 1274, in export
    onnx_model = optimizer.optimize(onnx_model)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/__init__.py", line 70, in optimize
    modified = fold_constants(
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/constant_folding.py", line 272, in fold_constants
    folder.visit_model(model)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/constant_folding.py", line 257, in visit_model
    super().visit_model(model)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/_legacy_ir/visitor.py", line 777, in visit_model
    self.visit_graph(model.graph)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/_legacy_ir/visitor.py", line 643, in visit_graph
    replacement = self.visit_node(node)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/_legacy_ir/visitor.py", line 790, in visit_node
    replacement, _ = self.process_function_node(node)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/constant_folding.py", line 229, in process_function_node
    _, new_function = super().process_function_node(node)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/_legacy_ir/visitor.py", line 877, in process_function_node
    replacement = self.visit_node(inner_node)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/_legacy_ir/visitor.py", line 792, in visit_node
    replacement = self.process_node(node)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/constant_folding.py", line 215, in process_node
    replacement = self.new_constant(node.output[0], outputs)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/constant_folding.py", line 107, in new_constant
    tensor = self.foldable_value(name, value)
  File "/usr/local/lib/python3.10/dist-packages/onnxscript/optimizer/constant_folding.py", line 98, in foldable_value
    return onnx.numpy_helper.from_array(value, name)
  File "/usr/local/lib/python3.10/dist-packages/onnx/numpy_helper.py", line 390, in from_array
    raise RuntimeError(f"Numpy data type not understood yet: {arr.dtype!r}") from e
RuntimeError: Numpy data type not understood yet: dtype((numpy.uint16, [('bfloat16', '<u2')]))

justinchuby commented 6 months ago

Thanks for reporting. I think https://github.com/microsoft/onnxscript/pull/1484 is related. Do you know if the onnxscript version you have is the latest?

asfiyab-nvidia commented 6 months ago

Thanks for linking the related PR. I tried the latest nightly, onnxscript==0.1.0.dev20240515, and confirmed that it includes the changes from the linked PR. However, I'm still seeing the same error.

justinchuby commented 6 months ago

cc @gramalingam

gramalingam commented 5 months ago

Hi @asfiyab-nvidia: can you attach the (unoptimized) ONNX model here? That would be helpful. I believe that if the optimizer fails, it still builds an unoptimized ONNX model. Thanks!

gramalingam commented 5 months ago

@justinchuby: while waiting for the model to repro, I wonder where "dtype((numpy.uint16, [('bfloat16', '<u2')]))" comes from. ml_dtypes seems like a possible source? Even so, it doesn't add up: I think ml_dtypes is used in the new IR, right? But the constant-folding optimizer doesn't use the new IR yet. Oh, well, I guess I should try it out with the actual model.

justinchuby commented 5 months ago

You are right that ml_dtypes doesn't kick in at this stage yet. The value looks like it was produced by the reference evaluator (most likely from a Cast node). I suggest we use ml_dtypes in the reference evaluator (and across ONNX) as well.
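To make the Cast hypothesis concrete: a bfloat16 value is the upper 16 bits of a float32. A Cast to bfloat16 therefore yields a 16-bit pattern that can only be held in a uint16 buffer, and that buffer is what the custom dtype in the traceback wraps. The sketch below computes that pattern in pure Python. It assumes round-to-nearest-even, which is the usual choice for float32-to-bfloat16 conversion; the reference evaluator's exact rounding is not confirmed here. `float32_to_bfloat16_bits` is a hypothetical helper, not an ONNX or onnxscript API.

```python
import math
import struct


def float32_to_bfloat16_bits(x: float) -> int:
    """Hypothetical helper: return the 16-bit bfloat16 pattern for x.

    Assumes round-to-nearest-even. x is first rounded to float32, so it must be
    representable as float32 (struct raises OverflowError for larger finite values).
    """
    # Reinterpret the float32 encoding of x as an unsigned 32-bit integer.
    bits = struct.unpack("<I", struct.pack("<f", x))[0]
    if math.isnan(x):
        # Keep the sign and upper bits, and force a quiet-NaN mantissa bit so
        # the result cannot collapse into an infinity pattern.
        return (bits >> 16) | 0x0040
    # Round to nearest even: add 0x7FFF plus the lowest bit that will be kept,
    # then drop the lower 16 bits.
    rounding_bias = 0x7FFF + ((bits >> 16) & 1)
    return ((bits + rounding_bias) >> 16) & 0xFFFF


print(hex(float32_to_bfloat16_bits(1.0)))         # bfloat16 pattern for 1.0
print(hex(float32_to_bfloat16_bits(1.00390625)))  # exact tie, rounds to even
```

Because bfloat16 keeps float32's 8-bit exponent and only shortens the mantissa to 7 bits, the conversion reduces to a 16-bit shift plus rounding. For example, 1.00390625 lies exactly halfway between two bfloat16 values and rounds to the even pattern 0x3F80 (1.0).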

gramalingam commented 5 months ago

You are right. The reference implementation does introduce this. That raises another question (which, I guess, is what motivates the second part of your answer): what bfloat16 encoding does the reference evaluator use? Is that a custom one that is conceptually a duplicate of the ml_dtypes one? I agree that it would be good to use a uniform encoding across all onnx tools/implementations.

justinchuby commented 5 months ago

The custom types for the reference evaluator are defined here: https://github.com/onnx/onnx/blob/88f8ef15cfaa3138d336f3502aed5018d802bf43/onnx/reference/custom_element_types.py#L8. They are simply byte representations that do not support any arithmetic operations.

With ml_dtypes, arithmetic is supported in addition to the correct byte representation.
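The difference between a bare byte representation and a computable type can be illustrated with a small pure-Python sketch. The helpers below are hypothetical and are not part of ONNX or ml_dtypes. Because the custom reference-evaluator dtype only stores the 16-bit pattern (little-endian, per the '<u2' in the error message), any arithmetic must first widen that pattern back to float32. Adding the stored uint16 payloads directly only adds bit patterns. ml_dtypes provides a real `bfloat16` NumPy dtype, so this widening happens as part of the arithmetic itself.

```python
import struct


def bfloat16_bits_to_float(bits: int) -> float:
    """Hypothetical helper: widen a 16-bit bfloat16 pattern to a Python float."""
    # bfloat16 is the upper half of a float32: shift into place, reinterpret.
    return struct.unpack("<f", struct.pack("<I", (bits & 0xFFFF) << 16))[0]


def bfloat16_raw_le_bytes(bits: int) -> bytes:
    """Hypothetical helper: 2-byte little-endian storage, like the '<u2' field."""
    return struct.pack("<H", bits & 0xFFFF)


one, two = 0x3F80, 0x4000  # bfloat16 bit patterns for 1.0 and 2.0

# Treating the stored payloads as uint16 adds bit patterns, not numbers.
naive_sum = (one + two) & 0xFFFF
# Real arithmetic has to happen on the decoded values.
real_sum = bfloat16_bits_to_float(one) + bfloat16_bits_to_float(two)

print(hex(naive_sum), bfloat16_bits_to_float(naive_sum), real_sum)
```

Here the naive sum of the patterns for 1.0 and 2.0 happens to be 0x7F80, which decodes as +inf rather than 3.0.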

xadupre commented 5 months ago

This should be addressed by https://github.com/onnx/onnx/pull/6170.