pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Multi-GPU training using SPMD causes a core dump #6225

Closed: mars1248 closed this issue 8 months ago

mars1248 commented 9 months ago

🐛 Bug

During multi-GPU training with SPMD enabled, a core dump is triggered.

To Reproduce

Steps to reproduce the behavior:

1. Run:

XLA_USE_SPMD=1 GPU_NUM_DEVICES=2 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node 2 test/spmd/test_train_spmd_imagenet.py --pjrt_distributed --fake_data --batch_size 16 --model=resnet50 --sharding=batch

2. The run aborts with this core dump:

F0000 00:00:1703077828.811772  117876 pjrt_future.h:229] Check failed: IsValid() 
*** Check failure stack trace: ***
    @     0x7f3bb1913669  absl::lts_20230802::log_internal::LogMessageFatal::~LogMessageFatal()
    @     0x56269642e7b2  xla::PjRtFuture<>::OnReady()
    @     0x56269642ed4d  xla::(anonymous namespace)::ExecuteWithSameInputBuffer()
    @     0x56269643b99a  xla::(anonymous namespace)::PjRtStreamExecutorClientTest_DonateSameBufferTwice_Test::TestBody()
    @     0x7f3bb242dfff  testing::internal::HandleExceptionsInMethodIfSupported<>()
    @     0x7f3bb242e246  testing::Test::Run()
    @     0x7f3bb242e5bd  testing::TestInfo::Run()
    @     0x7f3bb242ee01  testing::TestSuite::Run()
    @     0x7f3bb24348da  testing::internal::UnitTestImpl::RunAllTests()
    @     0x7f3bb242e857  testing::internal::HandleExceptionsInMethodIfSupported<>()
    @     0x7f3bb242e9f8  testing::UnitTest::Run()
    @     0x7f3bb24670e1  main
    @     0x7f3bb06d7401  __libc_start_main

Expected behavior

Environment

Additional context

JackCaoG commented 9 months ago

@vanbasten23 can you take a look when you have some cycles?

mars1248 commented 9 months ago

@JackCaoG I created a CUDA version of the unit test in pjrt_stream_executor_client_test.cc that reproduces the problem, and I fixed the bug by adding a local execute method: https://github.com/openxla/xla/pull/8008. Could you review it for me?

pjrt_stream_executor_client_cuda_test.cc

/* Copyright 2022 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "xla/pjrt/pjrt_stream_executor_client.h"

#include <memory>
#include <optional>
#include <string>
#include <utility>
#include <vector>

#include <gmock/gmock.h>
#include "absl/functional/any_invocable.h"
#include "absl/synchronization/mutex.h"
#include "xla/client/client_library.h"
#include "xla/client/xla_builder.h"
#include "xla/literal.h"
#include "xla/literal_comparison.h"
#include "xla/literal_util.h"
#include "xla/pjrt/pjrt_client.h"
#include "xla/pjrt/pjrt_future.h"
#include "xla/service/platform_util.h"
#include "xla/shape_util.h"
#include "xla/test.h"
#include "xla/xla_data.pb.h"
#include "tsl/concurrency/async_value_ref.h"
#include "tsl/lib/core/status_test_util.h"
#include "tsl/platform/status.h"
#include "tsl/platform/statusor.h"

namespace xla {
namespace {

xla::StatusOr<std::unique_ptr<PjRtStreamExecutorClient>> GetClient() {
  LocalClient* local_client = xla::ClientLibrary::LocalClientOrDie();
  TF_ASSIGN_OR_RETURN(se::Platform * platform,
                      PlatformUtil::GetPlatform("CUDA"));
  se::StreamExecutorConfig config;
  config.ordinal = 0;
  TF_ASSIGN_OR_RETURN(se::StreamExecutor * executor,
                      platform->GetExecutor(config));
  auto device_state = std::make_unique<LocalDeviceState>(
      executor, local_client, LocalDeviceState::kSynchronous,
      /*max_inflight_computations=*/32,
      /*allow_event_reuse=*/false, /*use_callback_stream=*/false);
  auto device = std::make_unique<PjRtStreamExecutorDevice>(
      0, std::move(device_state), "cuda");
  // auto device2 = std::make_unique<PjRtStreamExecutorDevice>(
  //   1, std::move(device_state), "cuda");
  std::vector<std::unique_ptr<PjRtStreamExecutorDevice>> devices;
  devices.emplace_back(std::move(device));
  // devices.emplace_back(std::move(device2));
  return std::make_unique<PjRtStreamExecutorClient>(
      "cuda", local_client, std::move(devices), /*process_index=*/0,
      /*allocator=*/nullptr, /*host_memory_allocator=*/nullptr,
      /*should_stage_host_to_device_transfers=*/false,
      /*gpu_run_options=*/nullptr);
}

StatusOr<std::unique_ptr<PjRtLoadedExecutable>> ToyExecutable(
    PjRtStreamExecutorClient& client, Shape shape,
    absl::AnyInvocable<void(XlaBuilder&)> set_up_aliases) {
  CompileOptions compile_options;
  XlaBuilder builder("Add");
  auto a = Parameter(&builder, 0, shape, "a");
  auto b = Parameter(&builder, 1, shape, "b");
  auto c = Add(a, b);
  auto d = Add(c, c);
  Tuple(&builder, {c, d});
  set_up_aliases(builder);
  TF_ASSIGN_OR_RETURN(auto computation,
                      builder.Build(/*remove_dynamic_dimensions=*/true));
  TF_ASSIGN_OR_RETURN(auto executable,
                      client.Compile(computation, compile_options));
  return executable;
}

Status ExecuteWithSameInputBuffer(
    absl::AnyInvocable<void(XlaBuilder&)> set_up_aliases) {
  auto shape = xla::ShapeUtil::MakeScalarShape(xla::F32);
  TF_ASSIGN_OR_RETURN(auto client, GetClient());
  TF_ASSIGN_OR_RETURN(auto* device0, client->LookupDevice(0));
  TF_ASSIGN_OR_RETURN(auto buffer,
                      client->CreateUninitializedBuffer(shape, device0));
  TF_ASSIGN_OR_RETURN(auto executable,
                      ToyExecutable(*client, shape, std::move(set_up_aliases)));
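  // Pre-sizing the vector fills it with one default-constructed (invalid)
  // future; this mirrors the pattern that reproduces the crash.
  std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures(
      1);
  auto ans = executable->Execute({{buffer.get(), buffer.get()}},
                                 /*options=*/{}, returned_futures);
  // If element 0 is still invalid, OnReady() aborts with
  // "Check failed: IsValid()".
  (*returned_futures)[0].OnReady([](xla::Status unused) {
    VLOG(0) << "ExecuteReplicated returned_future->OnReady finished";
  });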
  return ans.status();
}

Status LocalExecute(
    absl::AnyInvocable<void(XlaBuilder&)> set_up_aliases) {
  auto shape = xla::ShapeUtil::MakeScalarShape(xla::F32);
  TF_ASSIGN_OR_RETURN(auto client, GetClient());
  TF_ASSIGN_OR_RETURN(auto* device0, client->LookupDevice(0));
  TF_ASSIGN_OR_RETURN(auto buffer,
                      client->CreateUninitializedBuffer(shape, device0));
  std::optional<xla::PjRtFuture<xla::Status>> returned_future;
  TF_ASSIGN_OR_RETURN(auto executable,
                      ToyExecutable(*client, shape, std::move(set_up_aliases)));
  // ExecuteLocal() is the single-device execute method proposed in
  // https://github.com/openxla/xla/pull/8008.
  auto ans = executable->ExecuteLocal({{buffer.get(), buffer.get()}},
                                      /*options=*/{}, returned_future);

  returned_future->OnReady([](xla::Status unused) {
    VLOG(0) << "ExecuteReplicated returned_future->OnReady finished";
  });
  return ans.status();
}

TEST(PjRtStreamExecutorClientTest, LocalExecute) {
  // f(a, a)
  auto status = LocalExecute([](XlaBuilder& builder) {});
  ASSERT_TRUE(status.ok());
}

TEST(PjRtStreamExecutorClientTest, DonateSameBufferTwice) {
  // f(a, a)
  auto status = ExecuteWithSameInputBuffer([](XlaBuilder& builder) {});
  ASSERT_TRUE(status.ok());
}

}  // namespace
}  // namespace xla

vanbasten23 commented 8 months ago

Yes, we observed the same issue and agree that it needs to be fixed. As a side note, if you comment out https://github.com/pytorch/xla/blob/0cd6f1052db2421b3a77b602aba351c607f33480/torch_xla/csrc/runtime/pjrt_computation_client.cc#L692-L698, SPMD runs correctly.

mars1248 commented 8 months ago

@vanbasten23 If we fix it this way, will it affect other modules? And should we also change OpenXLA to fix the problem completely?

vanbasten23 commented 8 months ago

Hi @mars1248, you're right. https://github.com/pytorch/xla/issues/6225#issuecomment-1875769577 only shows that the execution is correct with those lines commented out; we do need a permanent fix.

We created a proof of concept (POC) that mirrors how users would use it:

import torch
import torch_xla
import torch_xla.runtime as xr
import torch_xla.core.xla_model as xm
import torch_xla.experimental.xla_sharding as xs
from torch_xla.experimental.xla_sharding import Mesh
import numpy as np

# Enable XLA SPMD execution mode.
xr.use_spmd()

# Assuming 4 GPU devices; PJRT_DEVICE=CUDA, GPU_NUM_DEVICES=4
num_devices = xr.global_runtime_device_count()
# mesh shape will be (2,2) in this example
mesh_shape = (num_devices // 2, 2)
device_ids = np.array(range(num_devices))
# axis_names 'x' and 'y' are optional
mesh = Mesh(device_ids, mesh_shape, ('x', 'y'))

t1 = torch.randn(1, 128, device='cpu')
t2 = torch.randn(1, 128, device='cpu')
expected = t1 + t2

xt1 = t1.to(xm.xla_device())
xt2 = t2.to(xm.xla_device())
xs.mark_sharding(xt1, mesh, (0, 1))

# check to see if the spmd computation used all devices
print(torch_xla._XLAC._get_xla_sharding_spec(xt1))

actual = (xt1 + xt2).cpu()
assert torch.allclose(expected, actual)

To make SPMD work, we need a fix similar to https://github.com/pytorch/xla/pull/6266 so that xr.global_runtime_device_count() returns the count of all GPU devices across the hosts. That would also resolve https://github.com/pytorch/xla/issues/5910, which you linked.

Also, thanks for the fix in https://github.com/openxla/xla/pull/8008. My thinking is that once xr.global_runtime_device_count() returns the total device count, https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2830 will return a number greater than 1 (all GPU devices on the host), so the else branch at https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2853 is the one that executes. Shouldn't your PR fix the else branch instead, or address why the future becomes invalid?

mars1248 commented 8 months ago

@vanbasten23 Thank you for your reply; what you said is very reasonable. However, instead of running it on a single machine, we are running distributed training with torchrun and SPMD, like so:

XLA_USE_SPMD=1 GPU_NUM_DEVICES=2 PJRT_DEVICE=CUDA torchrun --nnodes 1 --nproc-per-node 2 test/spmd/test_train_spmd_imagenet.py --pjrt_distributed --fake_data --batch_size 16 --model=resnet50 --sharding=batch

In this case, xr.global_runtime_device_count() would return a value other than 1, but each process would only have access to its local partition, so I proposed a fix like openxla/xla#8008. Is the way we are using it reasonable?

vanbasten23 commented 8 months ago

> In this case, xr.global_runtime_device_count() would return a value other than 1

Right, xr.global_runtime_device_count() is expected to return the count of all GPU devices across the hosts. You'd first need a fix such as https://github.com/pytorch/xla/pull/6266.

> each process would only have access to the local partition

The number of partitions may not be 1. With the fix, the VLOG at https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2840-L2843 shows num_replicas=1 num_partitions=4 num_addressable_devices=4. If I understand correctly, the fix in https://github.com/openxla/xla/pull/8008 focuses on https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2845, whereas I'm wondering whether we should fix the other branch at https://github.com/openxla/xla/blob/3ee0fa53876e8975cae351e24787745932f75867/xla/pjrt/pjrt_stream_executor_client.cc#L2852C11-L2852C11 or address why the future becomes invalid.
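To make the branch discussion concrete, here is a hypothetical C++ model (not XLA code) of which Execute branch a process reaches under the two setups in this thread. ExecutePath, Topology and PathFor are illustrative names. It assumes, as described in the comments above, that the branch depends only on whether the process has exactly one addressable device, and that every process drives the same number of GPUs.

```cpp
// Hypothetical model, not XLA code: which Execute branch a process reaches.
// ASSUMPTION: the branch depends only on whether the process has exactly
// one addressable device.
enum class ExecutePath {
  kSingleDeviceFastPath,  // the branch openxla/xla#8008 targets (~L2845)
  kMultiDevicePath,       // the else branch (~L2852)
};

// Hypothetical description of a launch setup. ASSUMPTION: every process
// drives the same number of GPUs.
struct Topology {
  int num_processes;        // e.g. torchrun --nproc-per-node
  int devices_per_process;  // GPUs addressable from each process

  // What xr.global_runtime_device_count() is expected to report once it
  // counts devices across all processes/hosts.
  int GlobalDeviceCount() const { return num_processes * devices_per_process; }
};

// Picks the branch from the per-process (addressable) device count.
ExecutePath PathFor(const Topology& topology) {
  return topology.devices_per_process == 1
             ? ExecutePath::kSingleDeviceFastPath
             : ExecutePath::kMultiDevicePath;
}
```

With Topology{1, 4} (one process driving 4 GPUs, as in the POC), GlobalDeviceCount() is 4 and PathFor() selects kMultiDevicePath. With Topology{2, 1} (the torchrun repro, assuming each process gets one GPU), GlobalDeviceCount() is 2 but PathFor() selects kSingleDeviceFastPath. Under these assumptions, the two setups exercise different branches.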

mars1248 commented 8 months ago

@vanbasten23 Hello, I found why the future becomes invalid and fixed it in this PR: https://github.com/pytorch/xla/pull/6275. The declaration std::optional<std::vector<xla::PjRtFuture<xla::Status>>> returned_futures(devices.size()); fills returned_futures with default-constructed (invalid) futures, which causes an error on subsequent accesses to returned_futures[0]. So we need to create an empty vector with pre-allocated capacity instead, which fixes the core dump while keeping performance in check.
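The root cause above can be reproduced without XLA. The following is a minimal, self-contained C++ sketch, not the actual PR code: FakeFuture, FakeExecute, PresizedFutures and ReservedFutures are hypothetical names, FakeFuture stands in for xla::PjRtFuture<xla::Status> (invalid when default-constructed), and it assumes, as the explanation implies, that the execute path appends futures with emplace_back() rather than assigning into existing slots.

```cpp
#include <cstddef>
#include <optional>
#include <vector>

// Hypothetical stand-in for xla::PjRtFuture<xla::Status>. A
// default-constructed future is invalid; calling OnReady() on one is what
// trips "Check failed: IsValid()" in the stack trace in this issue.
class FakeFuture {
 public:
  FakeFuture() = default;                   // invalid future
  explicit FakeFuture(int id) : id_(id) {}  // valid future
  bool IsValid() const { return id_.has_value(); }

 private:
  std::optional<int> id_;
};

// Hypothetical execute step. ASSUMPTION: it appends one future per executed
// device with emplace_back() instead of assigning into existing slots.
void FakeExecute(std::size_t num_devices,
                 std::optional<std::vector<FakeFuture>>& returned_futures) {
  if (!returned_futures.has_value()) return;
  for (std::size_t i = 0; i < num_devices; ++i) {
    returned_futures->emplace_back(static_cast<int>(i));
  }
}

// Buggy pattern: the vector starts with num_devices default-constructed
// (invalid) futures, so the real futures land after them.
std::vector<FakeFuture> PresizedFutures(std::size_t num_devices) {
  std::optional<std::vector<FakeFuture>> returned_futures(num_devices);
  FakeExecute(num_devices, returned_futures);
  return *returned_futures;
}

// Fixed pattern: start with an empty vector and only reserve capacity, so
// index 0 holds the first future the execute step produced.
std::vector<FakeFuture> ReservedFutures(std::size_t num_devices) {
  std::optional<std::vector<FakeFuture>> returned_futures =
      std::vector<FakeFuture>();
  returned_futures->reserve(num_devices);
  FakeExecute(num_devices, returned_futures);
  return *returned_futures;
}
```

For example, PresizedFutures(1) returns two futures whose first element is invalid, while ReservedFutures(1) returns a single valid future. reserve() keeps the single allocation the pre-sized version was after without creating placeholder elements.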

mars1248 commented 8 months ago

@vanbasten23 Hi, if you have time, could you help review this PR: https://github.com/pytorch/xla/pull/6275