openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Support autosharding in OSS #7248

cheshire opened this issue 10 months ago (status: Open)

cheshire commented 10 months ago

In order to support autosharding in OSS, we need to fix the non-determinism issues that come from using `absl::flat_hash_map`. The iteration order of `StableHashMap` in `auto_sharding_strategy.h` needs to become deterministic (internally it follows insertion order, since it is backed by a linked hash map).
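
For illustration, here is a minimal standalone snippet (not XLA code) showing why iterating over an `absl::flat_hash_map` is a problem for determinism:

```cpp
#include <cstdio>
#include <string>

#include "absl/container/flat_hash_map.h"

// Minimal illustration (not XLA code): the iteration order of
// absl::flat_hash_map is unspecified and can differ between builds and
// processes, so a pass that iterates over one and makes decisions in that
// order can produce run-to-run differences.
int main() {
  absl::flat_hash_map<std::string, int> strategies = {
      {"fc1", 0}, {"fc2", 1}, {"fc3", 2}};
  for (const auto& [name, id] : strategies) {
    std::printf("%s -> %d\n", name.c_str(), id);  // order is not guaranteed
  }
  return 0;
}
```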

lausannel commented 3 months ago

Hi @nluehr @cheshire, I would like to ask whether this issue has been resolved. I changed `::absl::flat_hash_map<Key, Value>` to `::absl::btree_map<Key, Value>` to make the iteration order of `StableHashMap` deterministic.

https://github.com/openxla/xla/blob/d06d46e3f83f897089bc2e0246cd439d9694c0c7/xla/hlo/experimental/auto_sharding/auto_sharding_strategy.h#L45
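
For reference, the change described above amounts to swapping the alias behind `StableHashMap` (a sketch based on the alias at the linked line; the enclosing namespaces are approximated and surrounding code is elided):

```cpp
#include "absl/container/btree_map.h"

namespace xla {
namespace spmd {

// Before: absl::flat_hash_map, whose iteration order is unspecified.
// After: absl::btree_map, which iterates in key order and is therefore
// deterministic (note this is sorted order, not the insertion order that
// the internal linked hash map provides).
template <typename Key, typename Value>
using StableHashMap = ::absl::btree_map<Key, Value>;

}  // namespace spmd
}  // namespace xla
```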

I've also observed additional non-deterministic behavior during the auto-sharding pass in the OSS build of XLA. Could you explain why the solver parameter string is restricted to PLATFORM_GOOGLE? And are there any other known sources of non-determinism in the auto-sharding pass that I should be aware of?

https://github.com/openxla/xla/blob/000657e838427efb9515224622d051c00fc7b62c/xla/hlo/experimental/auto_sharding/auto_sharding_solver.cc#L407-L422
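
For context, the guarded block feeds backend-specific parameters to the OR-tools solver. A hedged sketch of what pinning the solver to deterministic settings might look like in OSS (parameter names taken from OR-tools' `SatParameters`; the exact string and backend used inside the PLATFORM_GOOGLE guard may differ):

```cpp
#include "ortools/linear_solver/linear_solver.h"

// Hypothetical illustration only: pin the solver to a single worker and a
// fixed seed so repeated solves of the same model return the same solution.
// The concrete parameters used in auto_sharding_solver.cc may differ.
void ConfigureDeterministicSolver(operations_research::MPSolver* solver) {
  solver->SetSolverSpecificParametersAsString(
      "num_workers:1 random_seed:42 interleave_search:false");
}
```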

For a simple DNN consisting of three MLP layers, I've noticed that the sharding strategy chosen for the layers varies slightly each time I run the auto-sharding pass.

Code to reproduce:

```python
import torch_xla.runtime as xr
from torch.distributed._tensor import DeviceMesh, distribute_module
from torch_xla.distributed.spmd import auto_policy
from torch import nn
import torch_xla.core.xla_model as xm
import torch
import time
import os
import random
import numpy as np
import torch_xla
import torch.distributed as dist
import torch_xla.runtime as xr
from torch_xla.distributed.spmd import (  # type:ignore[import]
    xla_distribute_module,
)
import torch_xla.distributed.spmd as xs
from torch_xla.distributed.spmd import Mesh
import argparse


def parse_args():
    parser = argparse.ArgumentParser(description="Auto Sharding Test Arguments")
    parser.add_argument("--input_size", type=int, default=1000)
    parser.add_argument("--hidden_size1", type=int, default=20000)
    parser.add_argument("--hidden_size2", type=int, default=10000)
    parser.add_argument("--output_size", type=int, default=1)
    parser.add_argument("--add-manual-sharding", action="store_true")
    return parser.parse_args()


args = parse_args()

xr.use_spmd(auto=True)

seed = 43
torch.manual_seed(seed)
torch.cuda.manual_seed_all(seed)
torch.manual_seed(seed)
xm.set_rng_state(seed)

os.environ['PJRT_DEVICE'] = os.getenv('PJRT_DEVICE', 'CUDA')
os.environ['XLA_USE_SPMD'] = os.getenv('XLA_USE_SPMD', '1')
os.environ['XLA_DISABLE_FUNCTIONALIZATION'] = os.getenv('XLA_DISABLE_FUNCTIONALIZATION', '1')

num_devices = xr.global_runtime_device_count()
print(f"{num_devices=}")
assert num_devices == 4
mesh_shape = (4, 1)
device_ids = np.array(range(num_devices))
mesh = Mesh(device_ids, mesh_shape, ('data', 'model'))


class MyModule(nn.Module):
    def __init__(self):
        super(MyModule, self).__init__()
        self.fc1 = nn.Linear(args.input_size, args.hidden_size1, bias=False)    # 1000 * 20000
        self.fc2 = nn.Linear(args.hidden_size1, args.hidden_size2, bias=False)  # 20000 * 10000
        self.fc3 = nn.Linear(args.hidden_size2, args.output_size, bias=False)   # 10000 * 1
        self.relu = nn.ReLU()

    def forward(self, x):
        x = self.relu(self.fc1(x))
        x = self.relu(self.fc2(x))
        x = self.fc3(x)
        return x


device = xm.xla_device()
device_count = xr.global_runtime_device_count()
device_mesh = DeviceMesh("xla", list(range(device_count)))

criteria = nn.MSELoss()
model = MyModule().to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)

for i in range(1):
    optimizer.zero_grad()
    input = torch.randn(100, args.input_size).to(device)
    label = torch.randn(100, args.output_size).to(device)
    output = model(input)
    loss = criteria(output, label)
    loss.backward()
    optimizer.step()
    xm.mark_step()
    print(f"{output=}")
```

For instance, on some runs both fc1 and fc3 are sharded across the four devices, while on other runs only fc3 is sharded.

cheshire commented 2 months ago

@pratikfegade thoughts?

pratikfegade commented 2 months ago

I am not sure I see anything that's obviously going wrong here. While there can be other sources of non-determinism, as mentioned above, the solver is a big one. Could we verify that the solver reaches completion and optimality in the above case? Meanwhile, I can run the OSS version and try to reproduce the non-determinism myself.
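
One way to check that (a hypothetical helper, assuming the OR-tools `MPSolver` interface; not code from `auto_sharding_solver.cc`):

```cpp
#include "absl/log/log.h"
#include "ortools/linear_solver/linear_solver.h"

// Hypothetical helper: solve and warn if the solver stopped before proving
// optimality, e.g. because it hit a time limit. A merely-feasible solution
// can differ from run to run even when the model itself is deterministic.
bool SolveAndCheckOptimality(operations_research::MPSolver& solver) {
  const operations_research::MPSolver::ResultStatus status = solver.Solve();
  if (status != operations_research::MPSolver::OPTIMAL) {
    LOG(WARNING) << "Auto-sharding solver finished with status "
                 << static_cast<int>(status)
                 << " (not OPTIMAL); results may vary across runs.";
    return false;
  }
  return true;
}
```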

pratikfegade commented 2 months ago

@mmoffitt for visibility as well

mmoffitt commented 2 months ago

We should be able to remove those PLATFORM_GOOGLE guards ... that should definitely help with the determinism issues.

I'll attempt to update the code today.