openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

TF official object detection model fails when running with XLA on GPU #14165

Open · othakkar opened this issue 3 days ago

othakkar commented 3 days ago

I'm running the faster_rcnn_inception_resnet_v2_atrous_coco model from the TF official object detection model zoo, and I see the following error when running it with XLA enabled via the TF-XLA flags (TF_XLA_FLAGS='--tf_xla_auto_jit=1 --tf_xla_cpu_global_jit'):

2024-06-25 21:32:56.781722: I tensorflow/core/util/port.cc:153] oneDNN custom operations are on. You may see slightly different numerical results due to floating-point round-off errors from different computation orders. To turn them off, set the environment variable `TF_ENABLE_ONEDNN_OPTS=0`.
2024-06-25 21:32:56.821894: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: SSE3 SSE4.1 SSE4.2 AVX AVX2 AVX512F AVX512_VNNI FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
/usr/local/lib/python3.10/dist-packages/scipy/__init__.py:146: UserWarning: A NumPy version >=1.16.5 and <1.23.0 is required for this version of SciPy (detected version 1.26.4
  warnings.warn(f"A NumPy version >={np_minversion} and <{np_maxversion}"
input shape:  [1, 300, 300, 3]
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1719351181.580014  104329 gpu_device.cc:2022] Created device /job:localhost/replica:0/task:0/device:GPU:0 with 15513 MB memory:  -> device: 0, name: Tesla P100-PCIE-16GB, pci bus id: 0000:af:00.0, compute capability: 6.0
2024-06-25 21:33:02.209152: I tensorflow/compiler/mlir/mlir_graph_optimization_pass.cc:390] MLIR V1 optimization pass is not enabled
WARNING: All log messages before absl::InitializeLog() is called are written to STDERR
I0000 00:00:1719351191.283694  104612 service.cc:148] XLA service 0x7f8a24002e10 initialized for platform CUDA (this does not guarantee that XLA will be used). Devices:
I0000 00:00:1719351191.283758  104612 service.cc:156]   StreamExecutor device (0): Tesla P100-PCIE-16GB, Compute Capability 6.0
I0000 00:00:1719351191.307492  104612 cuda_dnn.cc:530] Loaded cuDNN version 8906
I0000 00:00:1719351191.345117  104612 device_compiler.h:188] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
2024-06-25 21:33:11.898256: I tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:268] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2024-06-25 21:33:22.596397: W external/local_xla/xla/tsl/framework/bfc_allocator.cc:291] Allocator (GPU_0_bfc) ran out of memory trying to allocate 17.70GiB with freed_by_count=0. The caller indicates that this is not a failure, but this may mean that there could be performance gains if more memory were available.
2024-06-25 21:33:31.644586: I external/local_xla/xla/stream_executor/cuda/cuda_asm_compiler.cc:396] ptxas warning : Registers are spilled to local memory in function 'loop_add_fusion_52', 16 bytes spill stores, 16 bytes spill loads
ptxas warning : Registers are spilled to local memory in function 'loop_add_fusion_50', 8 bytes spill stores, 8 bytes spill loads

2024-06-25 21:33:31.678591: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INTERNAL: CustomCall failed: Buffers have different size at runtime
     [[{{node cluster_9_1/xla_run}}]]
2024-06-25 21:33:31.678627: I tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: INTERNAL: CustomCall failed: Buffers have different size at runtime
     [[{{node cluster_9_1/xla_run}}]]
     [[prefix/SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/TensorArrayUnstack/strided_slice/declustered/_119]]
2024-06-25 21:33:31.678639: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 13029241050824245345
2024-06-25 21:33:31.678646: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 15200758161930789551
2024-06-25 21:33:31.678653: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 7001848036408623967
2024-06-25 21:33:31.678660: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17175913279372690319
2024-06-25 21:33:31.678670: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4012875769220498887
2024-06-25 21:33:31.678678: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 10202096503040189339
2024-06-25 21:33:31.678690: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17017288333246007419
2024-06-25 21:33:31.678698: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 12465475364722502465
2024-06-25 21:33:31.678705: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 16947449311890829889
2024-06-25 21:33:31.678729: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 7835899072589329893
2024-06-25 21:33:31.678738: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 14044228672765892233
2024-06-25 21:33:31.678746: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 14724706182550214325
2024-06-25 21:33:31.678754: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 1303585264248567919
2024-06-25 21:33:31.678760: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 8245635139670940535
2024-06-25 21:33:31.678767: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 16092000884064566713
2024-06-25 21:33:31.678775: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 12212152217892593085
2024-06-25 21:33:31.678799: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 3172997325053386311
2024-06-25 21:33:31.678813: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 17561650878640874064
2024-06-25 21:33:31.678820: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 12809026326892602528
2024-06-25 21:33:31.678828: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 15718700352859022928
2024-06-25 21:33:31.678836: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 2954094825752635354
2024-06-25 21:33:31.678844: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 240346711817677520
2024-06-25 21:33:31.678852: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 15813373048961100444
2024-06-25 21:33:31.678860: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 9589657468674698308
2024-06-25 21:33:31.678868: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 12122418290205691344
2024-06-25 21:33:31.678876: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 16597897312858091426
2024-06-25 21:33:31.678883: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 16783744011106448676
2024-06-25 21:33:31.678891: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 5470852174316402086
2024-06-25 21:33:31.678899: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 7390005538945226880
2024-06-25 21:33:31.678906: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 15560607185900304684
2024-06-25 21:33:31.678935: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 4829869614833859416
2024-06-25 21:33:31.678943: I tensorflow/core/framework/local_rendezvous.cc:423] Local rendezvous recv item cancelled. Key hash: 5225834741859597468
Traceback (most recent call last):
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1401, in _do_call
    return fn(*args)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1384, in _run_fn
    return self._call_tf_sessionrun(options, feed_dict, fetch_list,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1477, in _call_tf_sessionrun
    return tf_session.TF_SessionRun_wrapper(self._session, options, feed_dict,
tensorflow.python.framework.errors_impl.InternalError: 2 root error(s) found.
  (0) INTERNAL: CustomCall failed: Buffers have different size at runtime
     [[{{node cluster_9_1/xla_run}}]]
     [[prefix/SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/TensorArrayUnstack/strided_slice/declustered/_119]]
  (1) INTERNAL: CustomCall failed: Buffers have different size at runtime
     [[{{node cluster_9_1/xla_run}}]]
0 successful operations.
0 derived errors ignored.

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
  File "/host/frameworks.ai.benchmarking.extended-broad-product/TF/workload/PB/inference.py", line 45, in <module>
    run_inference(frozen_graph_filename, input_node_name, output_node_name)
  File "/host/frameworks.ai.benchmarking.extended-broad-product/TF/workload/PB/inference.py", line 36, in run_inference
    output_vals = sess.run(output_tensors, feed_dict=feed_dict)
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 971, in run
    result = self._run(None, fetches, feed_dict, options_ptr,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1214, in _run
    results = self._do_run(handle, final_targets, final_fetches,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1394, in _do_run
    return self._do_call(_run_fn, feeds, fetches, targets, options,
  File "/usr/local/lib/python3.10/dist-packages/tensorflow/python/client/session.py", line 1420, in _do_call
    raise type(e)(node_def, op, message)  # pylint: disable=no-value-for-parameter
tensorflow.python.framework.errors_impl.InternalError: Graph execution error:

2 root error(s) found.
  (0) INTERNAL: CustomCall failed: Buffers have different size at runtime
     [[{{node cluster_9_1/xla_run}}]]
     [[prefix/SecondStagePostprocessor/BatchMultiClassNonMaxSuppression/map/TensorArrayUnstack/strided_slice/declustered/_119]]
  (1) INTERNAL: CustomCall failed: Buffers have different size at runtime
     [[{{node cluster_9_1/xla_run}}]]
0 successful operations.
0 derived errors ignored.

Note that the model runs fine without enabling XLA.

Steps to reproduce:

  1. Use tf-nightly (06/25/24) with GPU support. I built TF from source using Bazel.
  2. Download the faster_rcnn_inception_resnet_v2_atrous_coco model from TF official models - http://download.tensorflow.org/models/object_detection/faster_rcnn_inception_resnet_v2_atrous_coco_2018_01_28.tar.gz
  3. Untar the model; the extracted directory should contain a frozen_inference_graph.pb file.
  4. Enable XLA: export TF_XLA_FLAGS='--tf_xla_auto_jit=1 --tf_xla_cpu_global_jit'
  5. Run the script below, which loads the frozen .pb graph and performs inference on dummy data.
import tensorflow as tf
import numpy as np

def load_graph(frozen_graph_filename):
    # Load the protobuf file from the disk and parse it to retrieve the unserialized graph_def
    with tf.io.gfile.GFile(frozen_graph_filename, "rb") as f:
        graph_def = tf.compat.v1.GraphDef()
        graph_def.ParseFromString(f.read())

    # Import the graph_def into a new Graph and return it
    with tf.Graph().as_default() as graph:
        tf.import_graph_def(graph_def, name="prefix")
    return graph

def run_inference(frozen_graph_filename, input_node_names, output_node_names):
    # Load the frozen graph
    graph = load_graph(frozen_graph_filename)

    # Create dummy inputs based on the input shapes from the graph
    input_tensors = []
    feed_dict = {}
    with graph.as_default():
        for input_node_name in input_node_names:
            input_tensor = graph.get_tensor_by_name("prefix/" + input_node_name + ":0")
            # hardcoding input shape for simplicity
            input_shape = [1, 300, 300, 3]
            print("input shape: ", input_shape)
            dummy_input = np.random.random(input_shape).astype(np.float32)
            input_tensors.append(input_tensor)
            feed_dict[input_tensor] = dummy_input

        # We launch a Session
        with tf.compat.v1.Session(graph=graph) as sess:
            output_tensors = [graph.get_tensor_by_name("prefix/" + name + ":0") for name in output_node_names]
            output_vals = sess.run(output_tensors, feed_dict=feed_dict)
            output_dict = dict(zip(output_node_names, output_vals))
            print("Output Values:", output_dict)
            return output_dict

# Example usage
frozen_graph_filename = '/host/faster_rcnn_inception_resnet_v2_atrous_coco_2018_01_28/frozen_inference_graph.pb'
input_node_name = ['image_tensor']
output_node_name = ['detection_boxes', 'detection_scores', 'num_detections']
run_inference(frozen_graph_filename, input_node_name, output_node_name)
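As an alternative to the `export` in step 4, the same flags can be set from Python. This is a minimal sketch, assuming TensorFlow reads `TF_XLA_FLAGS` from the environment once, when it first parses its flags, so the variable has to be in place before `import tensorflow`. The helper name `set_tf_xla_flags` is hypothetical.

```python
import os


def set_tf_xla_flags(flags, override=False):
    """Hypothetical helper: store TF-XLA auto-clustering flags in TF_XLA_FLAGS.

    Assumption: TensorFlow reads TF_XLA_FLAGS once at startup, so this must
    run before `import tensorflow`. Without override=True, an existing value
    (e.g. one exported in the shell) is left untouched.
    """
    if override or "TF_XLA_FLAGS" not in os.environ:
        os.environ["TF_XLA_FLAGS"] = " ".join(flags)
    return os.environ["TF_XLA_FLAGS"]


# Same flags as step 4 of the reproduction steps.
set_tf_xla_flags(["--tf_xla_auto_jit=1", "--tf_xla_cpu_global_jit"], override=True)

# import tensorflow as tf  # import only after the flags are in place
```

Keeping the shell value unless `override=True` is passed means an explicit `export` still wins, which makes it easy to flip XLA on and off per run without editing the script.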
othakkar commented 3 days ago

When running the same model with the same steps on a CPU, I found that the program takes an unreasonable amount of time (more than 1 day) to finish with XLA enabled, whereas it finishes in under 1 minute without XLA.
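With XLA enabled, a single wall-clock number mixes compilation time with execution time. The early calls, not necessarily only the first one, can include clustering and compilation. A minimal sketch for separating the two is to time each call individually and report the warm-up calls apart from the steady-state median. `run_once` is a hypothetical zero-argument callable, for example `lambda: sess.run(output_tensors, feed_dict=feed_dict)` inside the script above, and the number of warm-up calls is an assumption to adjust.

```python
import statistics
import time


def time_calls(run_once, n_calls=5):
    """Call `run_once` n_calls times and return each wall-clock duration (s)."""
    durations = []
    for _ in range(n_calls):
        start = time.perf_counter()
        run_once()
        durations.append(time.perf_counter() - start)
    return durations


def summarize(durations, warmup=1):
    """Report warm-up calls separately from the median of the remaining calls."""
    warm, steady = durations[:warmup], durations[warmup:]
    return {
        "warmup_s": warm,
        "steady_median_s": statistics.median(steady) if steady else None,
    }
```

If the warm-up entries dominate while the steady-state median is close to the non-XLA run, the slowdown is in compilation rather than in the compiled code.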

cheshire commented 2 days ago
  1. XLA:CPU is notoriously slow to compile; yes, this is being fixed. Overall, XLA:CPU is very much a work in progress.
  2. On GPU, it's crashing with the "dynamic padder" assertion: the dynamic padder compiles dynamic shapes coming from TF and then verifies their equality at runtime, and here it reports that they are not equal. You could try running with XLA_FLAGS=--xla_gpu_shape_checks=none and checking the result, but the underlying issue is in the TF2XLA lowering logic, not in XLA itself.
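To make the failure mode in point 2 concrete, here is a toy illustration, not XLA's implementation, of a runtime equality check. Buffers that were assumed to have the same size at compile time are compared again at run time. With the check enabled, a mismatch becomes an error. With the check disabled, the mismatch goes unreported and execution would carry on. The function `check_runtime_sizes` and its `enabled` parameter are hypothetical.

```python
def check_runtime_sizes(sizes, enabled=True):
    """Toy runtime check on buffers assumed equal-sized at compile time.

    `sizes` maps a buffer name to the element count observed at run time.
    Returns True when all sizes agree. On a mismatch it raises if the check
    is enabled; otherwise it returns False and nothing stops execution.
    """
    consistent = len(set(sizes.values())) <= 1
    if not consistent and enabled:
        raise RuntimeError("Buffers have different size at runtime: %r" % (sizes,))
    return consistent
```

The disabled path is the reason the numerics still need to be checked after turning the checks off: the mismatch has not gone away, it is just no longer reported.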
othakkar commented 1 day ago

@cheshire thanks for your response. FYI, just a minor correction to the flag: Setting XLA_FLAGS='--xla_gpu_shape_checks="IGNORE"' worked on GPU.
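When other XLA options are already in use, overwriting `XLA_FLAGS` would drop them. Below is a minimal sketch that appends the shape-check flag instead, using the exact value reported to work on GPU. It assumes XLA reads `XLA_FLAGS` once at startup, so it should run before TensorFlow is imported. The helper name `add_xla_flag` is hypothetical.

```python
import os


def add_xla_flag(flag):
    """Hypothetical helper: append `flag` to XLA_FLAGS, keeping existing flags.

    Assumption: XLA parses XLA_FLAGS once at startup, so call this before
    importing TensorFlow. Adding the same flag twice is a no-op.
    """
    current = os.environ.get("XLA_FLAGS", "").split()
    if flag not in current:
        current.append(flag)
    os.environ["XLA_FLAGS"] = " ".join(current)
    return os.environ["XLA_FLAGS"]


# The value reported above as working on GPU.
add_xla_flag('--xla_gpu_shape_checks="IGNORE"')
```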

cheshire commented 19 hours ago

It could be hiding bugs; I'd double-check the numerics against the non-XLA case.
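One way to do that check is to run the reproduction script twice, in separate processes, once with and once without the XLA flags. Both runs need identical inputs, for example by calling `np.random.seed(0)` before the dummy input is generated. The two output dicts returned by `run_inference` can then be compared, for instance after saving them with `np.savez`. This is a minimal sketch of the comparison step only. `compare_outputs` and its default tolerances are assumptions. Detection outputs may also need looser tolerances, or sorting, if boxes come back in a different order.

```python
import numpy as np


def compare_outputs(reference, candidate, rtol=1e-4, atol=1e-5):
    """Compare two mappings of output name -> array (non-XLA vs. XLA run).

    Returns {name: max absolute difference}. Raises AssertionError if the
    output names or shapes differ, or if values fall outside rtol/atol.
    """
    if set(reference) != set(candidate):
        raise AssertionError(f"output names differ: {sorted(reference)} vs {sorted(candidate)}")
    max_diffs = {}
    for name in reference:
        ref = np.asarray(reference[name], dtype=np.float64)
        got = np.asarray(candidate[name], dtype=np.float64)
        if ref.shape != got.shape:
            raise AssertionError(f"{name}: shape {ref.shape} vs {got.shape}")
        max_diffs[name] = float(np.max(np.abs(ref - got))) if ref.size else 0.0
        np.testing.assert_allclose(got, ref, rtol=rtol, atol=atol, err_msg=name)
    return max_diffs
```

Returning the per-output maximum difference alongside the pass/fail check shows how close the XLA run is even when it stays within tolerance.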