microsoft / onnxruntime

ONNX Runtime: cross-platform, high performance ML inferencing and training accelerator
https://onnxruntime.ai
MIT License

The output for GPT is NaN when fp16=True #6328

Open leoozy opened 3 years ago

leoozy commented 3 years ago

Describe the bug

I converted a PyTorch model (the OpenAI GPT from the Hugging Face transformers package) to ONNX and ran it with ONNX Runtime. If I leave fp16 off, the ONNX Runtime output is correct, but when I optimize the ONNX model with fp16, the output is all NaN.

I optimize my ONNX model with the following command:

```bash
python -m onnxruntime_tools.optimizer_cli --input '/home/ubuntu/psy/utils/PsyDial/GPT/onnx/GPT1_1_opset11.onnx' --output ./optimized.onnx --model_type gpt2 --hidden_size 768 --num_heads 12 --use_gpu --float16 --opt_level 99
```
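
One common mitigation for fp16 NaNs, worth trying for triage, is to keep numerically sensitive ops in float32 during the conversion. A minimal sketch, assuming a recent onnxruntime where `OnnxModel.convert_float_to_float16` accepts `keep_io_types` and `op_block_list` (the older `convert_model_float32_to_float16()` used below converts the whole graph); the file names and the choice of blocked ops are illustrative, not prescriptive:

```python
from onnxruntime.transformers import optimizer

# Fuse/optimize first, then convert to fp16 while leaving reduction-heavy
# ops (common overflow sources) in fp32.
opt_model = optimizer.optimize_model(
    "GPT1_1_opset11.onnx",   # exported model path (assumed)
    model_type="gpt2",
    num_heads=12,
    hidden_size=768,
    use_gpu=True,
)
opt_model.convert_float_to_float16(
    keep_io_types=True,                               # fp32 graph inputs/outputs
    op_block_list=["Softmax", "LayerNormalization"],  # ops kept in fp32 (illustrative)
)
opt_model.save_model_to_file("optimized_fp16.onnx")
```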

To Reproduce

This is my code:

```python
import os
import pdb
import time

import numpy
import torch
from onnxruntime import (GraphOptimizationLevel, InferenceSession,
                         SessionOptions, get_all_providers)

# `model` (the Hugging Face GPT model), `input_ids`, `token_type_ids`, and
# `logits` (the PyTorch reference output) are assumed to be defined earlier.

enable_overwrite = True
total_samples = 100
opset_version = 11
output_dir = "./onnx"

if not os.path.exists(output_dir):
    os.makedirs(output_dir)

export_model_path = os.path.join(output_dir, 'GPT1_1_opset{}.onnx'.format(opset_version))
optimized_model_path = "./optimized.onnx"

inputs = {
    'input_ids': input_ids,
    'token_type_ids': token_type_ids
}

# Dimension 1 (the sequence length) is dynamic.
symbolic_names = {1: 'max_seq_len'}
if enable_overwrite or not os.path.exists(export_model_path):
    with torch.no_grad():
        torch.onnx.export(
            model,
            args=tuple(inputs.values()),
            opset_version=opset_version,
            example_outputs=logits,
            do_constant_folding=True,
            f=export_model_path,
            input_names=["input_ids", "token_type_ids"],
            output_names=["logits"],
            dynamic_axes={"input_ids": symbolic_names,
                          "token_type_ids": symbolic_names,
                          "logits": symbolic_names}
        )
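
# Optional sanity check (not in the original repro): onnx.checker validates
# the exported model proto only; it cannot catch fp16 overflow, but it rules
# out a malformed export.
import onnx
onnx.checker.check_model(onnx.load(export_model_path))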

def create_model_for_provider(model_path: str, provider: str) -> InferenceSession:
    assert provider in get_all_providers(), f"provider {provider} not found, {get_all_providers()}"

    options = SessionOptions()
    options.intra_op_num_threads = 1
    options.graph_optimization_level = GraphOptimizationLevel.ORT_ENABLE_ALL

    # Pass the requested provider explicitly; without `providers=[provider]`
    # the argument is never used and the session may silently run on CPU.
    session = InferenceSession(model_path, options, providers=[provider])
    session.disable_fallback()

    return session

from onnxruntime_tools import optimizer

# Apply the GPT-2 transformer fusions, then convert the entire graph from
# float32 to float16.
optimized_model = optimizer.optimize_model(export_model_path,
                                           model_type='gpt2',
                                           num_heads=12,
                                           hidden_size=768,
                                           use_gpu=True)

optimized_model.convert_model_float32_to_float16()
optimized_model.save_model_to_file(optimized_model_path)

session = create_model_for_provider(optimized_model_path, "CUDAExecutionProvider")
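
# Optional check (not in the original repro): confirm the CUDA provider is
# actually active, so the fp16 run is not silently on CPU.
assert "CUDAExecutionProvider" in session.get_providers(), session.get_providers()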
latency = []
io_binding = session.io_binding()
io_binding.bind_cpu_input('input_ids', input_ids.cpu().numpy())
io_binding.bind_cpu_input('token_type_ids', token_type_ids.cpu().numpy())
io_binding.bind_output('logits')
for i in range(total_samples):
    start = time.time()
    session.run_with_iobinding(io_binding)
    torch.cuda.synchronize()
    latency.append(time.time() - start)
print("average latency (s):", sum(latency) / len(latency))

outputs = io_binding.copy_outputs_to_cpu()[0]
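
# Optional check (not in the original repro): with --float16 the fetched
# logits are numpy.float16; flag NaN/Inf explicitly, since numpy.allclose
# below only prints False without saying why.
print("logits contain NaN:", numpy.isnan(outputs).any(),
      "Inf:", numpy.isinf(outputs).any())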
    print("***** Verifying correctness *****")
    outputs = outputs[0]
    for i in range(2):

        print('PyTorch and ONNX Runtime output {} are close:'.format(i),
              numpy.allclose(logits[i].cpu().numpy(), outputs[i], rtol=1e-02, atol=1e-02))
        diff = outputs[i] - logits[i].cpu().numpy()
        max_diff = numpy.max(numpy.abs(diff))
        avg_diff = numpy.average(numpy.abs(diff))
        print(f'maximum_diff={max_diff} average_diff={avg_diff}')

    pdb.set_trace()

```


tianleiwu commented 3 years ago

For GPT-2, you can refer to our notebook: https://github.com/microsoft/onnxruntime/blob/master/onnxruntime/python/tools/transformers/notebooks/Inference_GPT2_with_OnnxRuntime_on_CPU.ipynb
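
For an fp16 GPT-2 model end to end, the conversion script bundled with the transformers tool covers export, fusion, and float16 conversion in one step. A minimal sketch, assuming a recent onnxruntime where the script ships as `onnxruntime.transformers.models.gpt2.convert_to_onnx` (the module path has moved between releases) and that `-p fp16` selects the precision:

```bash
# Export GPT-2 from Hugging Face, apply transformer fusions (-o), and
# convert to float16 (-p fp16) for GPU inference.
python -m onnxruntime.transformers.models.gpt2.convert_to_onnx \
    -m gpt2 --output gpt2_fp16.onnx -o -p fp16 --use_gpu
```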