huggingface / optimum

🚀 Accelerate training and inference of 🤗 Transformers and 🤗 Diffusers with easy to use hardware optimization tools
https://huggingface.co/docs/optimum/main/
Apache License 2.0

Decoder with cache ONNX export failed after mixed precision training #719

Open JingyaHuang opened 1 year ago

JingyaHuang commented 1 year ago

System Info

Optimum dev
onnxruntime 1.13.0

Who can help?

@JingyaHuang

Reproduction

Call evaluate() after mixed-precision training with ORTTrainer. I think the problem is more general, though, and affects any conversion of a PyTorch model with FP16 weights to ONNX.

Error message

Traceback (most recent call last):
  File "test_onnxruntime_train.py", line 279, in test_ort_trainer_decoder
    ort_eval_metrics = trainer.evaluate(inference_with_ort=inference_with_ort)
  File "/workspace/optimum/onnxruntime/trainer.py", line 813, in evaluate
    output = eval_loop(
  File "/workspace/optimum/onnxruntime/trainer.py", line 969, in evaluation_loop_ort
    self._export(onnx_model_path, with_loss=with_loss, device=export_device)
  File "/workspace/optimum/onnxruntime/trainer.py", line 1501, in _export
    _ = export(
  File "/workspace/optimum/exporters/onnx/convert.py", line 607, in export
    return export_pytorch(model, config, opset, output, device=device, input_shapes=input_shapes)
  File "/workspace/optimum/exporters/onnx/convert.py", line 370, in export_pytorch
    onnx_export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/__init__.py", line 350, in export
    return utils.export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 163, in export
    _export(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 1074, in _export
    graph, params_dict, torch_out = _model_to_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 727, in _model_to_graph
    graph, params, torch_out, module = _create_jit_graph(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 602, in _create_jit_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args)
  File "/usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py", line 517, in _trace_and_get_graph_from_model
    trace_graph, torch_out, inputs_states = torch.jit._get_trace_graph(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 1175, in _get_trace_graph
    outs = ONNXTracedModule(f, strict, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 127, in forward
    graph, out = torch._C._create_graph_by_tracing(
  File "/usr/local/lib/python3.8/dist-packages/torch/jit/_trace.py", line 118, in wrapper
    outs.append(self.inner(*trace_inputs))
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 1047, in forward
    transformer_outputs = self.transformer(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 891, in forward
    outputs = block(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 392, in forward
    attn_outputs = self.attn(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 1118, in _slow_forward
    result = self.forward(*input, **kwargs)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 333, in forward
    attn_output, attn_weights = self._attn(query, key, value, attention_mask, head_mask)
  File "/usr/local/lib/python3.8/dist-packages/transformers/models/gpt2/modeling_gpt2.py", line 185, in _attn
    attn_weights = torch.matmul(query, key.transpose(-1, -2))
RuntimeError: expected scalar type Float but found Half

Expected behavior

A decoder with past (key/value cache) and FP16 weights can be successfully exported to an ONNX model.

Contribution

I can take a closer look, but I don't have the bandwidth at the moment.

JingyaHuang commented 1 year ago

This is odd: I've verified that both query and key are float16 and on CUDA. Besides, exporting the decoder without past doesn't hit this issue. I'll investigate further when I have the bandwidth.
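For anyone who wants to repeat this kind of dtype check on the export inputs, here is a minimal diagnostic sketch. It is not part of Optimum: `collect_dtypes` is a hypothetical helper name, and it relies only on duck typing (any leaf object exposing a `.dtype` attribute, as torch tensors do). It walks a nested structure of dicts, lists, and tuples and reports the dtype found at each path.

```python
from collections.abc import Mapping, Sequence


def collect_dtypes(obj, prefix="inputs"):
    """Return {path: dtype} for every leaf that exposes a .dtype attribute.

    Hypothetical helper for debugging; works on any nested mix of
    mappings, lists and tuples (e.g. dummy inputs handed to an exporter).
    """
    found = {}
    if hasattr(obj, "dtype"):
        # Leaf: a tensor-like object.
        found[prefix] = obj.dtype
    elif isinstance(obj, Mapping):
        for key, value in obj.items():
            found.update(collect_dtypes(value, f"{prefix}.{key}"))
    elif isinstance(obj, Sequence) and not isinstance(obj, (str, bytes)):
        for index, value in enumerate(obj):
            found.update(collect_dtypes(value, f"{prefix}[{index}]"))
    return found


def print_dtypes(obj):
    """Print one line per tensor-like leaf, sorted by path."""
    for path, dtype in sorted(collect_dtypes(obj).items()):
        print(f"{path}: {dtype}")
```

Calling `print_dtypes(dummy_inputs)` on whatever inputs are fed to the export would then show at a glance whether any floating-point entry is not float16. Since the failure only appears with the cache, the past key/value entries seem a natural place to look, although nothing in this thread confirms that they are the cause.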

[Update] I won't fix this unless there is strong demand from the community, since proper inference with ORT should go through subclasses of ORTModel rather than ORTTrainer. The inference part of ORTTrainer is only meant for quick tests.
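As a rough illustration of the ORTModel route, the sketch below runs generation through ORTModelForCausalLM instead of ORTTrainer. It assumes a recent Optimum release where `from_pretrained` accepts `export=True` (older releases used `from_transformers=True`); the function name `generate_with_ort` and its parameters are made up for this example, and it has not been checked against the FP16 checkpoint from this issue.

```python
def generate_with_ort(model_id: str, prompt: str, max_new_tokens: int = 20) -> str:
    """Generate text with ONNX Runtime via an ORTModel subclass (illustrative sketch)."""
    # Imported lazily: both packages are heavy, optional dependencies.
    from optimum.onnxruntime import ORTModelForCausalLM
    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained(model_id)
    # export=True converts the PyTorch checkpoint to ONNX on load
    # (assumption: a recent Optimum version; older ones used from_transformers=True).
    # use_cache=True exports the decoder with past key/values.
    model = ORTModelForCausalLM.from_pretrained(model_id, export=True, use_cache=True)

    inputs = tokenizer(prompt, return_tensors="pt")
    output_ids = model.generate(**inputs, max_new_tokens=max_new_tokens)
    return tokenizer.decode(output_ids[0], skip_special_tokens=True)
```

For example, `generate_with_ort("gpt2", "Hello, my name is")` would export the public gpt2 checkpoint and return the decoded continuation. It needs network access and both libraries installed.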

JingyaHuang commented 1 year ago

I will test the export again once PR #749 is merged.