openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

[PJRT:GPU] Propagate arg and result info from MLIR to XLA Compile method #14900

Closed: jaro-sevcik closed this pull request 1 month ago

jaro-sevcik commented 1 month ago

In the MLIR flavor of the PjRtStreamExecutorClient::Compile method, we now transfer the argument layouts and the result layout from the MLIR module to the compile options.

If the compile options already specify argument layouts, we ignore the layouts from MLIR.

We also ensure that the argument/result layouts are preserved when SPMD needs to canonicalize layouts after resharding parameters and/or layouts.
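The precedence rule described above (explicitly supplied layouts win over layouts extracted from MLIR) can be sketched as follows. This is a minimal illustration: `Layout` is a stand-in type, and `ChooseArgumentLayouts` is a hypothetical helper, not XLA's actual API, which operates on `xla::CompileOptions` and the compiled MLIR module.

```cpp
#include <cassert>
#include <optional>
#include <string>
#include <vector>

// Stand-in for a real layout type; the actual code deals in xla::Shape
// layouts carried by CompileOptions and the MLIR entry computation.
using Layout = std::string;

// Sketch of the precedence rule: layouts already present in the compile
// options take priority; otherwise fall back to layouts found in MLIR.
std::optional<std::vector<Layout>> ChooseArgumentLayouts(
    const std::optional<std::vector<Layout>>& from_options,
    const std::optional<std::vector<Layout>>& from_mlir) {
  if (from_options.has_value()) {
    return from_options;  // Explicit compile options win.
  }
  return from_mlir;  // Otherwise use layouts propagated from MLIR.
}
```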

jaro-sevcik commented 1 month ago

This is a simpler version of https://github.com/openxla/xla/pull/14807.

The main functional difference from PR #14807 is that this version also preserves argument layouts through the layout canonicalization callback (even though parameter layout canonicalization appears to be disabled for GPU in OSS XLA).
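The idea of preserving user-pinned layouts through canonicalization can be sketched as a wrapper around the callback: run the inner canonicalization, then re-apply any layout that was explicitly pinned. All names and types here are illustrative stand-ins, not XLA's actual callback signature.

```cpp
#include <cassert>
#include <cstddef>
#include <functional>
#include <optional>
#include <string>
#include <vector>

using Layout = std::string;
// Hypothetical callback type: maps current layouts to canonicalized ones.
using LayoutCanonicalizationCallback =
    std::function<std::vector<Layout>(std::vector<Layout>)>;

// Wrap a canonicalization callback so that layouts pinned by the user
// (or propagated from MLIR) survive SPMD layout canonicalization.
LayoutCanonicalizationCallback PreservePinnedLayouts(
    LayoutCanonicalizationCallback inner,
    std::vector<std::optional<Layout>> pinned) {
  return [inner = std::move(inner), pinned = std::move(pinned)](
             std::vector<Layout> layouts) {
    std::vector<Layout> canonical = inner(std::move(layouts));
    // Re-apply each pinned layout on top of the canonicalized result.
    for (std::size_t i = 0; i < canonical.size() && i < pinned.size(); ++i) {
      if (pinned[i].has_value()) canonical[i] = *pinned[i];
    }
    return canonical;
  };
}
```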

[This also fixes OSS build of se_gpu_pjrt_client_test.cc, see https://github.com/openxla/xla/pull/14899]

derdrdirk commented 1 month ago

third_party/tensorflow/compiler/xla/pjrt/gpu:se_gpu_pjrt_client_test is failing, specifically the newly added MlirResultHostMemorySpaceIsSetInHloWithShardingPropagation test. Could you please fix the test?

[ RUN      ] StreamExecutorGpuClientTest.MlirResultHostMemorySpaceIsSetInHloWithShardingPropagation
third_party/tensorflow/compiler/xla/pjrt/gpu/se_gpu_pjrt_client_test.cc:1392: Failure
Value of: _status_or_value98.status().ok()
  Actual: false
Expected: true
INVALID_ARGUMENT: No matching device found for device_id 2
=== Source Location Trace: === 
third_party/tensorflow/compiler/xla/util.h:281
third_party/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.h:285
third_party/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:3441
third_party/tensorflow/compiler/xla/pjrt/pjrt_stream_executor_client.cc:3481

Stack trace:
0x7ffa9883f335: xla::(anonymous namespace)::StreamExecutorGpuClientTest_MlirResultHostMemorySpaceIsSetInHloWithShardingPropagation_Test::TestBody() @ ??:??
0x7ff88f738e62: testing::Test::Run() @ ??:??
0x7ff88f739ef9: testing::TestInfo::Run() @ ??:??
... Google Test internal frames ...

[  FAILED  ] StreamExecutorGpuClientTest.MlirResultHostMemorySpaceIsSetInHloWithShardingPropagation (1178 ms)
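The INVALID_ARGUMENT above ("No matching device found for device_id 2") suggests the test's sharding references a device id beyond what the test machine exposes. One common pattern in multi-device tests (a sketch of the general technique, not the actual fix adopted here) is to skip when too few devices are available; the helper below is illustrative, with `available_devices` and `required_devices` as assumed inputs.

```cpp
#include <cassert>
#include <cstddef>

// Sketch of a device-count guard for multi-device tests: if the test
// needs more devices than the client exposes, it should be skipped
// (e.g. via GTEST_SKIP() in a GoogleTest body) rather than fail with
// INVALID_ARGUMENT. This helper is illustrative, not XLA API.
bool ShouldSkipTest(std::size_t available_devices,
                    std::size_t required_devices) {
  return available_devices < required_devices;
}
```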