jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0

XLA Check failed: common_utilization <= producer_output_utilization #17730

Open pl-fuchs opened 1 year ago

pl-fuchs commented 1 year ago

Description

When running a longer algorithm, execution fails with an error message that gives no indication of where in the code the issue occurred:

F external/xla/xla/service/gpu/gpu_performance_model.cc:119] Check failed: common_utilization <= producer_output_utilization (500.867 vs. 500.867)

This error only occurs in some use cases of the algorithm, and slightly changing parameters such as the number of iterations or the batch size sometimes makes it disappear.

What jax/jaxlib version are you using?

jax v0.4.9, jaxlib v0.4.9+cuda12.cudnn88

Which accelerator(s) are you using?

GPU

Additional system info

python 3.8

NVIDIA GPU info

+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.105.17   Driver Version: 525.105.17   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+

I am running the code with disabled parallel compilation:

TF_USE_NVLINK_FOR_PARALLEL_COMPILATION=0

hawkinsp commented 1 year ago

Well, that sounds like an XLA bug.

First, can you try with the latest jaxlib release (0.4.16)? The bug may already be fixed, so this is the first thing to try. You will need to update your Python version to 3.9 or newer to do this.
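For reference, the CUDA 12 pip install command from the JAX documentation of that era looked roughly like this (the "cuda12_pip" extra and the wheel index URL are assumptions and may have changed since):

pip install --upgrade "jax[cuda12_pip]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html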

If that doesn't work, can you please provide instructions to reproduce? If that's hard to do, an alternative is to provide an HLO dump from XLA: set

XLA_FLAGS=--xla_dump_to=/somewhere and JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0,

then run your script, zip up the contents of /somewhere, and attach the archive to this issue.
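
A minimal Python-side sketch of the same recipe (the dump path is a placeholder; the variables must be set before jax is imported so XLA picks them up at backend initialization):

import os

# Placeholder dump path; both variables must be set before importing jax.
os.environ["XLA_FLAGS"] = "--xla_dump_to=/tmp/xla_dump"
os.environ["JAX_COMPILER_DETAILED_LOGGING_MIN_OPS"] = "0"

import jax  # import after setting the flags
# ... run the failing computation, then zip up /tmp/xla_dump ...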

hawkinsp commented 1 year ago

Any updates? Can you share instructions to reproduce?

pl-fuchs commented 1 year ago

I upgraded the jax and jaxlib versions, but the error persists.

Unfortunately, I could not track the error down to a specific part of the code. However, I followed the steps you described and attached the dump.

xla_dump_part_1.tar.gz xla_dump_part_2.tar.gz

S-Roecken commented 11 months ago

Hi Peter Hawkins,

Do you have any updates on this issue?

Best, Sebastien

ghcollin commented 10 months ago

I've run into the same error.

Description

(jax) -bash-4.2$ JAX_COMPILER_DETAILED_LOGGING_MIN_OPS=0 XLA_FLAGS=--xla_dump_to=/tmp/xladump PYTHONPATH="./" python tests/tests.py
.F1127 20:16:06.642083   26806 gpu_performance_model.cc:358] Check failed: common_utilization <= producer_output_utilization (2.2 vs. 2.2) 
*** Check failure stack trace: ***
    @     0x7f901feca23e  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7f901feca349  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7f90178ba3f4  xla::gpu::GpuPerformanceModel::ProducerInputAccessTime()
    @     0x7f9019f9f203  xla::gpu::GpuPerformanceModel::EstimateRunTimes()
    @     0x7f9019f8b521  xla::gpu::FusionInstructionMerger::ShouldFuse()
    @     0x7f9019f8cc76  xla::gpu::FusionInstructionMerger::Run()
    @     0x7f9019f8d4f4  xla::gpu::FusionMerger::Run()
    @     0x7f901da2daf5  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x7f901da2e8e7  xla::HloPassPipeline::Run()
    @     0x7f9018e54c71  xla::HloPassInterface::Run()
    @     0x7f9018e6a5dd  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x7f9018e6ee61  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x7f9018d8d3a9  xla::Service::BuildExecutable()
    @     0x7f9018b472ad  xla::LocalService::CompileExecutables()
    @     0x7f9018b41f82  xla::LocalClient::Compile()
    @     0x7f9018b00d7c  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f9018adcc9f  xla::StreamExecutorGpuClient::Compile()
    @     0x7f9018b1400a  xla::PjRtStreamExecutorClient::Compile()
    @     0x7f9018a2e66f  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7f9018a24d14  xla::ifrt::PjRtCompiler::Compile()
    @     0x7f9017fa9a52  xla::PyClient::Compile()
    @     0x7f9017cdb0b3  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7f9017cdb560  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7f9017c90768  pybind11::cpp_function::dispatcher()
    @           0x525d17  cfunction_call
Aborted (core dumped)

The code in question is quite heavy in integer arithmetic, which may be part of the problem. You can find it here.

xladump.tar.gz
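
Since the check fires inside XLA's fusion passes at compile time, one way to pin down which jit boundary triggers it is to compile each suspect function eagerly via JAX's AOT API. The function below is a hypothetical integer-heavy stand-in, not the actual code from the dump:

import jax
import jax.numpy as jnp

# Hypothetical integer-heavy function standing in for the real code.
def mix(x):
    a = (x * 2654435761) >> 16  # multiply-shift hashing, all integer ops
    return jnp.sum(jnp.bitwise_xor(a, x << 3) % 255)

x = jnp.arange(1 << 20, dtype=jnp.uint32)
# Compile eagerly: if the fusion-pass check fails, it fails on this line,
# pointing at this specific function rather than somewhere inside a test run.
compiled = jax.jit(mix).lower(x).compile()
print(compiled(x))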

Jax version

Jax/Jaxlib: jax-0.4.20 jaxlib-0.4.20+cuda12.cudnn89

System info

Using GPU on an RTX 4090. My code runs successfully on CPU.
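
Since the failure is in the GPU compiler's cost model and the same code compiles on CPU, a stopgap is to pin JAX to the CPU backend. A minimal sketch using the standard JAX_PLATFORMS environment variable (this only sidesteps the bug, it does not fix it):

import os
os.environ["JAX_PLATFORMS"] = "cpu"  # must be set before importing jax

import jax
print(jax.devices())  # [CpuDevice(id=0)]

Alternatively, the jax.default_device context manager with jax.devices("cpu")[0] can scope CPU execution to just the affected block.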

Python 3.11.4, installed through anaconda

+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 530.30.02              Driver Version: 530.30.02    CUDA Version: 12.1     |
|-----------------------------------------+----------------------+----------------------+

m1balcerak commented 8 months ago

Same problem here, running on GPU. Any solutions? Whether it appears depends on the hyperparameters.

sjrothfuss commented 7 months ago

EDIT: Originally thought I had the same problem but it looks like I'm failing a different XLA check so opened a separate issue. https://github.com/google/jax/issues/20024

dbezgin commented 7 months ago

Any updates on this?

Running into a similar error on A6000 GPU:

F external/xla/xla/service/gpu/model/gpu_performance_model.cc:540] Check failed: common_utilization <= producer_output_utilization (1.4 vs. 1.4)

jax 0.4.24, jaxlib 0.4.24+cuda12.cudnn89, CUDA installed via pip wheels

F0320 13:19:47.472121 32626 gpu_performance_model.cc:358] Check failed: common_utilization <= producer_output_utilization (1.4 vs. 1.4)
*** Check failure stack trace: ***
    @     0x7fd84c68423e  absl::lts_20230802::log_internal::LogMessage::Flush()
    @     0x7fd84c684349  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x7fd8440743f4  xla::gpu::GpuPerformanceModel::ProducerInputAccessTime()
    @     0x7fd846759203  xla::gpu::GpuPerformanceModel::EstimateRunTimes()
    @     0x7fd84674e678  xla::gpu::GpuMultiOutputFusion::DoMultiOutputFusion()
    @     0x7fd8467503cc  xla::gpu::GpuMultiOutputFusion::Run()
    @     0x7fd84a1e7af5  xla::HloPassPipeline::RunPassesInternal<>()
    @     0x7fd84a1e88e7  xla::HloPassPipeline::Run()
    @     0x7fd84560ec71  xla::HloPassInterface::Run()
    @     0x7fd8456245dd  xla::gpu::GpuCompiler::OptimizeHloModule()
    @     0x7fd845628e61  xla::gpu::GpuCompiler::RunHloPasses()
    @     0x7fd8455473a9  xla::Service::BuildExecutable()
    @     0x7fd8453012ad  xla::LocalService::CompileExecutables()
    @     0x7fd8452fbf82  xla::LocalClient::Compile()
    @     0x7fd8452bad7c  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fd845296c9f  xla::StreamExecutorGpuClient::Compile()
    @     0x7fd8452ce00a  xla::PjRtStreamExecutorClient::Compile()
    @     0x7fd8451e866f  xla::ifrt::PjRtLoadedExecutable::Create()
    @     0x7fd8451ded14  xla::ifrt::PjRtCompiler::Compile()
    @     0x7fd844763a52  xla::PyClient::Compile()
    @     0x7fd8444950b3  pybind11::detail::argument_loader<>::call_impl<>()
    @     0x7fd844495560  pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
    @     0x7fd84444a768  pybind11::cpp_function::dispatcher()
    @           0x525d17  cfunction_call
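
An untested sketch of a possible stopgap: XLA's --xla_disable_hlo_passes debug option can skip the passes implicated in the traces above. The pass names "fusion_merger" and "multi_output_fusion" are assumptions inferred from the stack traces, and skipping them will likely cost runtime performance even if it avoids the compile-time crash:

import os

# Assumed pass names; --xla_disable_hlo_passes is a standard XLA debug option.
os.environ["XLA_FLAGS"] = "--xla_disable_hlo_passes=fusion_merger,multi_output_fusion"
import jax  # must be imported after setting XLA_FLAGS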

pkarnakov commented 4 months ago

Same error with jax and jaxlib version 0.4.28. I found that replacing one line in my code

return 0.5**(epoch / period) if period else 1

with this

return jax.numpy.where(period, 0.5**(epoch / period), 1)

fixes the error. EDIT: Actually, the error is intermittent; it sometimes appears even with the "fix".
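
For reference, a self-contained version of that workaround; the function name and signature are assumptions based on the snippet above. Note that jnp.where evaluates both branches, so the division needs its own guard when period == 0:

import jax
import jax.numpy as jnp

@jax.jit
def decay_factor(epoch, period):
    # Guard the division: jnp.where selects after both branches are computed,
    # so period == 0 would otherwise feed a division by zero into the select.
    safe_period = jnp.where(period, period, 1)
    return jnp.where(period, 0.5 ** (epoch / safe_period), 1.0)

print(decay_factor(10.0, 5.0))  # 0.25
print(decay_factor(10.0, 0.0))  # 1.0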