@vanbasten23 can you take a look when you have some cycles?
@JackCaoG I created a CUDA version of the unit test from pjrt_stream_executor_client_test.cc that reproduces the problem, and I fixed this bug by adding a local execute method: https://github.com/openxla/xla/pull/8008. Could you check it for me?
pjrt_stream_executor_client_cuda_test.cc
/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "xla/pjrt/pjrt_stream_executor_client.h"
#include <memory>
#include <string>
#include <utility>
#include <vector>
#include <iostream>
#include <gmock/gmock.h>
#include "absl/functional/any_invocable.h"
#include "absl/synchronization/mutex.h"
#include "xla/client/client_library.h"
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
#include "xla/literal_comparison.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/xla_data.pb.h"
#include "tsl/concurrency/async_value_ref.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"
namespace xla {
namespace {
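// Builds a PjRtStreamExecutorClient backed by a single CUDA device for the test.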
xla::StatusOr<std::unique_ptr<PjRtStreamExecutorClient>> GetClient() {
LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie();
TF_ASSIGN_OR_RETURN(se::Platform * platform,
PlatformUtil::GetPlatform("CUDA"));
se::StreamExecutorConfig config;
config.ordinal = 0;
TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
platform->GetExecutor(config));
auto device_state = std::make_unique<LocalDeviceState>(
executor, local_client, LocalDeviceState::kSynchronous,
/*max_inflight_computations=*/32,
/*allow_event_reuse=*/false, /*use_callback_stream=*/false);
auto device = std::make_unique<PjRtStreamExecutorDevice>(
0, std::move(device_state), "cuda");
// auto device2 = std::make_unique<PjRtStreamExecutorDevice>(
// 1, std::move(device_state), "cuda");
std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
devices.emplace_back(std::move(device));
// devices.emplace_back(std::move(device2));
return std::make_unique<PjRtStreamExecutorClient>(
"cuda", local_client, std::move(devices), /*process_index=*/0,
/*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
/*should_stage_host_to_device_transfers=*/false,
/*gpu_run_options=*/nullptr);
}
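// Compiles a small computation that returns the tuple (a + b, (a + b) + (a + b)),
// letting the caller set up any input/output aliasing via set_up_aliases.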
StatusOr<std::unique_ptr<PjRtLoadedExecutable>> ToyExecutable(
PjRtStreamExecutorClient& client, Shape shape,
absl::AnyInvocable<void(XlaBuilder&)> set_up_aliases) {
CompileOptions compile_options;
XlaBuilder builder("Add");
auto a = Parameter(&builder, 0, shape, "a");
auto b = Parameter(&builder, 1, shape, "b");
auto c = Add(a, b);
auto d = Add(c, c);
Tuple(&builder, {c, d});
set_up_aliases(builder);
TF_ASSIGN_OR_RETURN(auto computation,
builder.Build(/*remove_dynamic_dimensions=*/true));
TF_ASSIGN_OR_RETURN(auto executable,
client.Compile(computation, compile_options));
return executable;
}
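// Runs the executable through PjRtLoadedExecutable::Execute, passing the same
// uninitialized buffer for both parameters and requesting returned futures.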
Status ExecuteWithSameInputBuffer(
absl::AnyInvocable<void(XlaBuilder&)> set_up_aliases) {
auto shape = xla::ShapeUtil::MakeScalarShape(xla::F32);
TF_ASSIGN_OR_RETURN(auto client, GetClient());
TF_ASSIGN_OR_RETURN(auto* device0, client->LookupDevice(0));
TF_ASSIGN_OR_RETURN(auto buffer,
client->CreateUninitializedBuffer(shape, device0));
TF_ASSIGN_OR_RETURN(auto executable,
ToyExecutable(*client, shape, std::move(set_up_aliases)));
std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures(
1);
auto ans = executable->Execute({{buffer.get(), buffer.get()}}, /*options=*/{}, returned_futures);
(*returned_futures)[0].OnReady(
std::move([](xla::Status unused) mutable {
VLOG(0) << "ExecuteReplicated returned_future->OnReady finished";
}));
return ans.status();
}
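// Same setup, but goes through the local execute entry point proposed in
// https://github.com/openxla/xla/pull/8008, with a single returned future.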
Status LocalExecute(
absl::AnyInvocable<void(XlaBuilder&)> set_up_aliases) {
auto shape = xla::ShapeUtil::MakeScalarShape(xla::F32);
TF_ASSIGN_OR_RETURN(auto client, GetClient());
TF_ASSIGN_OR_RETURN(auto* device0, client->LookupDevice(0));
TF_ASSIGN_OR_RETURN(auto buffer,
client->CreateUninitializedBuffer(shape, device0));
std::optional<xla::PjRtFuture<xla::Status>> returned_future;
TF_ASSIGN_OR_RETURN(auto executable,
ToyExecutable(*client, shape, std::move(set_up_aliases)));
auto ans = executable->ExecuteLocal({{buffer.get(), buffer.get()}}, /*options=*/{},
returned_future);
returned_future->OnReady(
std::move([](xla::Status unused) mutable {
VLOG(0) << "ExecuteReplicated returned_future->OnReady finished";
}));
return ans.status();
}
TEST(PjRtStreamExecutorClientTest, LocalExecute) {
// f(a, a)
auto status = LocalExecute([](XlaBuilder& builder) {});
ASSERT_TRUE(status.ok());
}
TEST(PjRtStreamExecutorClientTest, DonateSameBufferTwice) {
// f(a, a)
auto status = ExecuteWithSameInputBuffer([](XlaBuilder& builder) {});
ASSERT_TRUE(status.ok());
}
} // namespace
} // namespace xla
Yeah, we observed the same and agreed that this needs to be fixed. A side note though: if you comment out https://github.com/pytorch/xla/blob/0cd6f1052db2421b3a77b602aba351c607f33480/torch_xla/csrc/runtime/pjrt_computation_client.cc#L692-L698, SPMD can run correctly.
@vanbasten23 If you fix it in this way, will it affect other modules? And should we combine this with a change in OpenXLA to solve the problem completely?
Hi @mars1248 , you're right. https://github.com/pytorch/xla/issues/6225#issuecomment-1875769577 just shows the execution is correct. We do need a permanent fix.
We created a POC, similar to how users would use it:
import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
import numpy as np
# Enable XLA SPMD execution mode.
xr.use_spmd()
# Assuming 4 GPU devices; PJRT_DEVICE=GPU, NUM_GPU_DEVICES=4
num_devices = xr.global_runtime_device_count()
# mesh shape will be (2,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))
t1 = torch.randn(1, 128, device='cpu')
t2 = torch.randn(1, 128, device='cpu')
expected = t1 + t2
xt1 = t1.to(xm.xla_device())
xt2 = t2.to(xm.xla_device())
xs.mark_sharding(xt1, mesh, (0, 1))
# check to see if the spmd computation used all devices
print(torch_xla._XLAC._get_xla_sharding_spec(xt1))
actual = (xt1 + xt2).cpu()
assert torch.allclose(expected, actual)
To make SPMD work, we need a fix similar to the one in https://github.com/pytorch/xla/pull/6266 so that xr.global_runtime_device_count() returns the count of all GPU devices across the hosts. That would resolve https://github.com/pytorch/xla/issues/5910, which you linked.
Also, thanks for the fix https://github.com/openxla/xla/pull/8008. I'm thinking that after we fix xr.global_runtime_device_count() so that it returns the total device count, https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2830 will return a number greater than 1 (all GPU devices on the host), and then it's the else branch at https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2853 that will be executed. So shouldn't your PR fix the else branch, or fix why the future becomes invalid?
@vanbasten23
Thank you for your reply. I think what you said is very reasonable. But instead of running on a single machine, we are running distributed training with torchrun and SPMD, like so:
XLA_USE_SPMD=1 GPU_NUM_DEVICES=2 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node 2 test/spmd/test_train_spmd_imagenet.py --pjrt_distributed --fake_data --batch_size 16 --model=resnet50 --sharding=batch
In this case, xr.global_runtime_device_count() would return a value other than 1, and each process would only have access to the local partition, so I proposed a fix like openxla/xla#8008. I don't know whether the way we use it is reasonable?
Regarding "In this case, xr.global_runtime_device_count() would return a value other than 1": right, xr.global_runtime_device_count() is expected to return the count of all GPU devices across the hosts. You'd first need a fix such as https://github.com/pytorch/xla/pull/6266.
Regarding "each process would only have access to the local partition": the partition count may not be 1. With the fix, if you check the vlog at https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2840-L2843, you can see num_replicas=1 num_partitions=4 num_addressable_devices=4. IIUC, the fix https://github.com/openxla/xla/pull/8008 seems to focus on https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2845, while I'm wondering whether we should instead fix the other branch, https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2852C11-L2852C11, or fix why the future becomes invalid.
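For context, here is a simplified standalone sketch of the branching being discussed (illustrative only, not the actual XLA source): with a single addressable device the computation runs inline, while with several addressable devices the else branch dispatches one execution, and fills one returned future, per addressable device.
#include <cstddef>
#include <iostream>
// Illustrative-only sketch of the branching in the Execute path discussed above.
void ExecuteSketch(std::size_t num_addressable_devices) {
  if (num_addressable_devices == 1) {
    // Single-device path (around pjrt_stream_executor_client.cc#L2845).
    std::cout << "run inline on the single addressable device\n";
  } else {
    // Multi-device path (around pjrt_stream_executor_client.cc#L2853):
    // one dispatch and one returned future per addressable device.
    for (std::size_t i = 0; i < num_addressable_devices; ++i) {
      std::cout << "dispatch on addressable device " << i << "\n";
    }
  }
}
int main() {
  // With the device-count fix, the vlog shows num_addressable_devices=4,
  // so the else branch is the one that runs.
  ExecuteSketch(4);
  return 0;
}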
@vanbasten23
Hello, I have found why there is a problem with the future and fixed it in this PR:
https://github.com/pytorch/xla/pull/6275
The statement std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures(devices.size());
initializes the elements of returned_futures as null (default-constructed, invalid) futures, which causes an error on subsequent attempts to access returned_futures[0]. So we need to create an empty vector with pre-allocated capacity instead, which fixes the core dump while keeping performance in check.
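To illustrate the difference, here is a minimal, self-contained C++ sketch. It uses a stand-in Future struct instead of the real xla::PjRtFuture, so the type and method names are illustrative only: sizing the vector at construction leaves default-constructed, invalid futures in it, while starting from an empty vector with reserved capacity only ever holds the futures that the execute call actually appends.
#include <cassert>
#include <cstddef>
#include <optional>
#include <vector>
// Stand-in for xla::PjRtFuture<Status>: default-constructed instances are
// invalid, and using them is an error (modelled here with an assert).
struct Future {
  bool valid = false;
  void OnReady() { assert(valid); }
};
int main() {
  const std::size_t num_devices = 4;  // stands in for devices.size()
  // Problematic pattern: the optional holds num_devices default-constructed
  // (invalid) futures; touching them before they are overwritten fails.
  std::optional<std::vector<Future>> sized(num_devices);
  // (*sized)[0].OnReady();  // would trip the assert: element never populated
  (void)sized;  // silence unused-variable warning
  // Pattern used by the fix (in spirit): start from an empty vector, reserve
  // capacity up front, and let the execute call append one valid future per device.
  std::optional<std::vector<Future>> empty{std::in_place};
  empty->reserve(num_devices);
  empty->push_back(Future{/*valid=*/true});  // stands in for Execute filling it
  (*empty)[0].OnReady();  // fine: only populated futures are present
  return 0;
}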
@vanbasten23 Hi, if you have time, could you help me review this PR: https://github.com/pytorch/xla/pull/6275
🐛 Bug
During multi-GPU training with SPMD optimization, a core dump is triggered.
To Reproduce
Steps to reproduce the behavior:
1. Run the multi-GPU SPMD training.
2. A core dump error message is produced.
Expected behavior
Environment
Additional context