microsoft / onnxscript

ONNX Script enables developers to naturally author ONNX functions and models using a subset of Python.
https://onnxscript.ai/
MIT License
285 stars 54 forks

Optimizer fails for bfloat16 models #1893

Open justinchuby opened 1 month ago

justinchuby commented 1 month ago

When the model uses bfloat16 ops, the optimizer fails with the following error. We should handle the custom types from onnx in _constant_folding.

```text
Traceback (most recent call last):
  File "/workspace/ONNXConverter/llama.py", line 36, in <module>
    onnxscript.optimizer.optimize_ir(program.model)
  File "/workspace/onnxscript/onnxscript/optimizer/__init__.py", line 124, in optimize_ir
    _constant_folding.fold_constants(model, onnx_shape_inference=onnx_shape_inference)
  File "/workspace/onnxscript/onnxscript/optimizer/_constant_folding.py", line 719, in fold_constants
    folder.visit_model(model)
  File "/workspace/onnxscript/onnxscript/optimizer/_constant_folding.py", line 699, in visit_model
    self.visit_graph(model.graph)
  File "/workspace/onnxscript/onnxscript/optimizer/_constant_folding.py", line 690, in visit_graph
    self.visit_node(node, graph)
  File "/workspace/onnxscript/onnxscript/optimizer/_constant_folding.py", line 679, in visit_node
    replacement = self.process_node(node)
                  ^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/onnxscript/onnxscript/optimizer/_constant_folding.py", line 650, in process_node
    replacement = self.new_constant(node.outputs[0], outputs)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/onnxscript/onnxscript/optimizer/_constant_folding.py", line 580, in new_constant
    irvalue.const_value = _convenience.tensor(value)
                          ^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/onnxscript/onnxscript/ir/_convenience.py", line 357, in tensor
    tensor_ = _core.Tensor(value, dtype=dtype, name=name, doc_string=name)
              ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/onnxscript/onnxscript/ir/_core.py", line 355, in __init__
    self._dtype = _enums.DataType.from_numpy(value.dtype)
                  ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/workspace/onnxscript/onnxscript/ir/_enums.py", line 76, in from_numpy
    raise TypeError(f"Unsupported numpy data type: {dtype}")
TypeError: Unsupported numpy data type: (numpy.uint16, [('bfloat16', '<u2')])
```
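
The last line of the traceback shows what reaches _enums.DataType.from_numpy: a dtype with uint16 storage and a single field named bfloat16, which is how onnx represents its custom bfloat16 type. Below is a minimal sketch of one way such a dtype could be recognized, assuming the single field name identifies the type (only bfloat16 is covered). The names onnx_type_from_numpy and bfloat16_bits_to_float32 are hypothetical helpers for illustration; they are not onnxscript APIs and this is not the actual fix for this issue. The integer values are ONNX TensorProto.DataType codes (FLOAT=1, UINT16=4, INT64=7, FLOAT16=10, BFLOAT16=16).

```python
import numpy as np

# Reconstructed from the traceback: uint16 storage with one field "bfloat16".
ONNX_BFLOAT16 = np.dtype((np.uint16, {"bfloat16": (np.uint16, 0)}))

# Small subset of ONNX TensorProto.DataType values for plain numpy dtypes.
_PLAIN_TYPES = {
    np.dtype(np.float32): 1,   # FLOAT
    np.dtype(np.uint16): 4,    # UINT16
    np.dtype(np.int64): 7,     # INT64
    np.dtype(np.float16): 10,  # FLOAT16
}
# Custom element types keyed by their single field name (assumption: only bfloat16 here).
_CUSTOM_TYPES = {"bfloat16": 16}  # BFLOAT16


def onnx_type_from_numpy(dtype) -> int:
    """Hypothetical helper: map a numpy dtype to an ONNX data type code."""
    dtype = np.dtype(dtype)
    names = dtype.names
    if names is not None:
        # A custom type carries exactly one field whose name identifies it.
        if len(names) == 1 and names[0] in _CUSTOM_TYPES:
            return _CUSTOM_TYPES[names[0]]
    elif dtype in _PLAIN_TYPES:
        return _PLAIN_TYPES[dtype]
    raise TypeError(f"Unsupported numpy data type: {dtype}")


def bfloat16_bits_to_float32(array: np.ndarray) -> np.ndarray:
    """Hypothetical helper: decode raw bfloat16 bits into float32 values."""
    # bfloat16 is the upper 16 bits of an IEEE float32, so shift them back up.
    bits = array.view(np.uint16).astype(np.uint32)
    return (bits << 16).view(np.float32)


# Usage: build an array with the custom dtype from raw bfloat16 bit patterns.
raw = np.array([0x3F80, 0x4000, 0xC040], dtype=np.uint16).view(ONNX_BFLOAT16)
print(onnx_type_from_numpy(raw.dtype))            # 16
print(bfloat16_bits_to_float32(raw).tolist())     # [1.0, 2.0, -3.0]
```

Keying on the field name (rather than on the storage type) keeps the custom dtype from being confused with plain uint16, which shares the same 2-byte storage.
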
justinchuby commented 1 month ago

cc @gramalingam