jax-ml / jax

Composable transformations of Python+NumPy programs: differentiate, vectorize, JIT to GPU/TPU, and more
http://jax.readthedocs.io/
Apache License 2.0
30.66k stars 2.83k forks

No invoker was registered for convolution forward on ROCm devices #18906

Open KegangWangCCNU opened 12 months ago

KegangWangCCNU commented 12 months ago

Description

When I use jax.pmap to execute a network containing convolutions, an error is reported:

import jax
import jax.random as jrd
import flax.linen as nn

model = nn.Conv(64, (3,))
# model = nn.Dense(64)  # swapping in Dense avoids the convolution path
key1, key2 = jrd.split(jrd.PRNGKey(0), 2)
x = jrd.uniform(key1, (32, 128, 3))
params = model.init(key2, x)

def apply_model(ipt):
    return model.apply(params, ipt)

devices = 4
jax.pmap(apply_model)(x.reshape(devices, x.shape[0] // devices, *x.shape[1:]))
MIOpen Error: /long_pathname_so_that_rpms_can_package_the_debug_info/src/extlibs/MLOpen/src/ocl/convolutionocl.cpp:456: No invoker was registered for convolution forward. Was find executed?
2023-12-10 16:39:58.274379: W external/xla/xla/service/gpu/runtime/support.cc:58] Intercepted XLA runtime error:
INTERNAL: Failed to enqueue convolution on stream: miopenStatusUnknownError
2023-12-10 16:39:58.274442: E external/xla/xla/pjrt/pjrt_stream_executor_client.cc:2716] Execution of replica 1 failed: INTERNAL: Failed to execute XLA Runtime executable: run time error: custom call 'xla.gpu.conv.forward' failed: Failed to enqueue convolution on stream: miopenStatusUnknownError; current tracing scope: cudnn-conv.1; current profiling annotation: XlaModule:#hlo_module=pmap_apply_model,program_id=462#.

Running the model on a single GPU poses no issues.
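For context, the reshape in the repro just splits the batch dimension into a leading per-device axis, which is what jax.pmap maps over. A NumPy-only sketch of the shape arithmetic (array contents are placeholders):

```python
import numpy as np

devices = 4
x = np.zeros((32, 128, 3))  # (batch, length, features), as in the repro

# pmap expects a leading axis of size n_devices; split the batch accordingly.
x_sharded = x.reshape(devices, x.shape[0] // devices, *x.shape[1:])
print(x_sharded.shape)  # (4, 8, 128, 3)
```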

What jax/jaxlib version are you using?

0.4.22.dev20231209+ccc8b3f7a, 0.4.22.dev20231210

Which accelerator(s) are you using?

AMD GPU

Additional system info?

NumPy: 1.26.2
Python: 3.9.18 (main, Sep 11 2023, 13:41:44) [GCC 11.2.0]
uname: system='Linux', node='ww-server', release='6.2.0-34-generic', version='#34~22.04.1-Ubuntu SMP PREEMPT_DYNAMIC Thu Sep 7 13:12:03 UTC 2', machine='x86_64'

GPU info (AMD)

GPU[0] : Card series: Vega 20 [Radeon Pro VII/Radeon Instinct MI50 32GB]
GPU[0] : Card model: 0x081e
GPU[0] : Card vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[0] : Card SKU: D1640600
GPU[1] : Card series: Vega 20 [Radeon Pro VII/Radeon Instinct MI50 32GB]
GPU[1] : Card model: 0x081e
GPU[1] : Card vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[1] : Card SKU: D1640600
GPU[2] : Card series: Vega 20 [Radeon Pro VII/Radeon Instinct MI50 32GB]
GPU[2] : Card model: 0x081e
GPU[2] : Card vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[2] : Card SKU: D1640600
GPU[3] : Card series: Vega 20 [Radeon Pro VII/Radeon Instinct MI50 32GB]
GPU[3] : Card model: 0x081e
GPU[3] : Card vendor: Advanced Micro Devices, Inc. [AMD/ATI]
GPU[3] : Card SKU: D1640600

hawkinsp commented 11 months ago

@rahulbatra85

rahulbatra85 commented 11 months ago

@KegangWangCCNU Can you try running with the env var TF_ROCM_USE_IMMEDIATE_MODE=1 set?

Since you are using jaxlib 0.4.22, I assume you built JAX yourself. Can you share where you pulled the XLA code from?

KegangWangCCNU commented 11 months ago

@rahulbatra85

Thank you, the issue has been temporarily resolved.

I have tried both the official and AMD branches of XLA, and they have the same problem. To my knowledge, this also occurred in earlier versions: https://github.com/google/jax/issues/14582

zhangyu0722 commented 11 months ago

@KegangWangCCNU How did you solve this? I am currently running into the same problem.

brettkoonce commented 8 months ago

See also https://github.com/google/jax/issues/14582.

rahulbatra85 commented 8 months ago

This should be fixed in the latest release, 0.4.25: https://github.com/ROCm/jax/releases/tag/jaxlib-v0.4.25. You shouldn't need to set the environment variable manually anymore.

rahulbatra85 commented 8 months ago

Confirmed fixed in 0.4.25: https://github.com/google/jax/issues/14582.

Please close this issue as well.