tensorflow / java

Java bindings for TensorFlow

Very slow first query / XLA compilation with vit_b32_fe model #473

Closed sebastianlutter closed 2 years ago

sebastianlutter commented 2 years ago

After loading a Vision Transformer model (vit_b32_fe) from tfhub with version 0.4.1 of tensorflow-java, the first query takes about 30 seconds to complete. The reason is that the graph is JIT-compiled with XLA on the first call.

This is a performance issue I want to solve. Using 0.5.0-SNAPSHOT is not possible because it fails to load the SavedModel (see https://github.com/tensorflow/java/issues/472). Since I can load and run the model without issues using Python TensorFlow 2.7.1 or 2.9.0, I wonder why I see this issue in Java.

Are there any options (disabling the XLA JIT or similar) that would avoid the model blocking for 30 seconds? Thanks for any help.
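
One avenue worth trying is to pass a ConfigProto at load time that sets the global XLA JIT level to OFF. The following is a minimal, unverified sketch against the tensorflow-java 0.4.x API (the model path is a placeholder); note that this setting only controls XLA auto-clustering, so it may not help if the SavedModel contains functions that were saved with jit_compile=True.

import org.tensorflow.SavedModelBundle;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GraphOptions;
import org.tensorflow.proto.framework.OptimizerOptions;
import org.tensorflow.proto.framework.OptimizerOptions.GlobalJitLevel;

// Sketch: ask TensorFlow not to auto-cluster graphs for XLA.
// Whether this avoids the 30-second pause for this particular model
// is an assumption, not something verified in this thread.
ConfigProto config = ConfigProto.newBuilder()
        .setGraphOptions(GraphOptions.newBuilder()
                .setOptimizerOptions(OptimizerOptions.newBuilder()
                        .setGlobalJitLevel(GlobalJitLevel.OFF)))
        .build();

try (SavedModelBundle savedModel = SavedModelBundle.loader("/path/to/vit_b32_fe") // placeholder path
        .withTags("serve")
        .withConfigProto(config)
        .load()) {
    // run queries as usual; XLA auto-clustering should be disabled
}

If it works, this trades the one-off compilation pause for potentially slower steady-state inference, in case the model benefits from XLA.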

Code to reproduce the issue

import java.util.Collections;

import org.tensorflow.SavedModelBundle;
import org.tensorflow.ndarray.NdArrays;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

// graphFile points at the extracted SavedModel directory; log is the class logger
Thread warmUpThread = new Thread(() -> {
    try (SavedModelBundle savedModel = SavedModelBundle.loader(graphFile.toString())
            .withTags("serve")
            .load()) {
        // warm-up query: the first call triggers the XLA compilation
        long start = System.currentTimeMillis();
        log.info("Doing warm up query with tensorflow model");
        try (TFloat32 xTensor = TFloat32.tensorOf(NdArrays.ofFloats(Shape.of(1, 244, 244, 3)));
             TFloat32 zTensor = (TFloat32) savedModel
                     .call(Collections.singletonMap("inputs", xTensor))
                     .get("output_0")) {
            long end = System.currentTimeMillis();
            log.info("Successfully warmed up tensorflow model, took " + (end - start) + "ms");
        }
    }
});
warmUpThread.start();
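
For what it's worth, the XLA compilation cache is per-process (note the "logged at most once for the lifetime of the process" line in the logs below), so a second query on the same bundle should be fast. A hypothetical follow-up check, reusing savedModel, log, and the signature names from the snippet above:

// Hypothetical second query on the same bundle: once XLA has compiled the
// cluster, this should take milliseconds rather than ~30 seconds.
try (TFloat32 x2 = TFloat32.tensorOf(NdArrays.ofFloats(Shape.of(1, 244, 244, 3)))) {
    long start2 = System.currentTimeMillis();
    try (TFloat32 z2 = (TFloat32) savedModel
            .call(Collections.singletonMap("inputs", x2))
            .get("output_0")) {
        log.info("Second query took " + (System.currentTimeMillis() - start2) + "ms");
    }
}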

Other info / logs

2022-09-14 12:14:44.843292: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:43] Reading SavedModel from: /tmp/pxl_14542487053289680376
2022-09-14 12:14:44.888522: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:107] Reading meta graph with tags { serve }
2022-09-14 12:14:44.888626: I external/org_tensorflow/tensorflow/cc/saved_model/reader.cc:148] Reading SavedModel debug info (if present) from: /tmp/pxl_14542487053289680376
2022-09-14 12:14:44.888682: I external/org_tensorflow/tensorflow/core/platform/cpu_feature_guard.cc:151] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations:  AVX2 FMA
To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.
2022-09-14 12:14:45.050174: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:228] Restoring SavedModel bundle.
2022-09-14 12:14:45.593306: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:212] Running initialization op on SavedModel bundle at path: /tmp/pxl_14542487053289680376
2022-09-14 12:14:45.877830: I external/org_tensorflow/tensorflow/cc/saved_model/loader.cc:301] SavedModel load for tags { serve }; Status: success: OK. Took 1034540 microseconds.
Sep 14, 2022 7:44:46 PM de.pixolution.process.module.tf2.SavedModelEmbeddings$1 run
INFO: Doing warm up query with tensorflow model
2022-09-14 12:14:47.213119: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:171] XLA service 0x7f0de8014540 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2022-09-14 12:14:47.213143: I external/org_tensorflow/tensorflow/compiler/xla/service/service.cc:179]   StreamExecutor device (0): Host, Default Version
2022-09-14 12:14:47.255383: I external/org_tensorflow/tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.cc:237] disabling MLIR crash reproducer, set env var `MLIR_CRASH_REPRODUCER_DIRECTORY` to enable.
2022-09-14 12:15:18.095626: I external/org_tensorflow/tensorflow/compiler/jit/xla_compilation_cache.cc:351] Compiled cluster using XLA!  This line is logged at most once for the lifetime of the process.
Sep 14, 2022 7:45:18 PM de.pixolution.process.module.tf2.SavedModelEmbeddings$1 run
INFO: Successfully warmed up tensorflow model, took 31899ms
Craigacp commented 2 years ago

We might be compiling the CUDA ops for a different set of target architectures than the Python binaries, to reduce our build times given our limited resources. That would cause TensorFlow to recompile the ops for your GPU on startup.

What GPU, CUDA version & driver version are you using?

sebastianlutter commented 2 years ago

CUDA is not involved; I'm running on CPU only (12th-generation Intel i7). There are no NVIDIA/CUDA drivers or libraries installed at all.

How are the builds done? Using https://github.com/tensorflow/java/blob/master/deploy.sh (in Docker, as shown in release.sh)?

EDIT: Found https://github.com/tensorflow/java/blob/master/CONTRIBUTING.md#building in the meantime

Craigacp commented 2 years ago

Ok. Our XLA support is a bit iffy due to some open bugs in the upstream TF XLA support, but I'm not an expert on the consequences of those.

sebastianlutter commented 2 years ago

Can you please provide a settings.xml file with the basic Maven settings used to build the jar, like the one used for 0.5.0-SNAPSHOT? That would be a good starting point for building a customized jar.

Craigacp commented 2 years ago

It should build out of the box in Maven terms, but getting bazel configured to compile TF correctly is always a pain.

sebastianlutter commented 2 years ago

I built a minimal code example in Python and Java, and found out that I was wrong:

TensorFlow 2.7.1 spends about 30 seconds on XLA compilation in both Java and Python. With TensorFlow 2.9.1 (Python) the problem does not exist. Closing this issue. @Craigacp thanks for your help!