pytorch / xla

Enabling PyTorch on XLA Devices (e.g. Google TPU)
https://pytorch.org/xla

Problem with mesh shape in HybridMesh on TPU #7102

Open manh3152924 opened 1 month ago

manh3152924 commented 1 month ago

❓ Questions and Help

I received an error when trying to create an SPMD mesh in a Kaggle notebook while following Hugging Face optimum-tpu:

import os
import numpy as np

import torch_xla
import torch_xla.core.xla_model as xm
import torch_xla.distributed.xla_multiprocessing as xmp
from torch_xla.distributed.fsdp import checkpoint_module
from torch_xla.distributed.fsdp.utils import apply_xla_patch_to_nn_linear
import torch_xla.distributed.parallel_loader as pl
import torch_xla.core.xla_env_vars as xenv
import torch_xla.debug.metrics as met
import torch_xla.distributed.spmd.xla_sharding as xs
from torch_xla.distributed.spmd.xla_sharding import Mesh, HybridMesh
from torch_xla.distributed.spmd.xla_sharded_tensor import XLAShardedTensor
import torch_xla.runtime as xr

# Configure the TPU runtime first; enabling SPMD may initialize the runtime.
os.environ['USE_TORCH'] = 'True'
os.environ["PJRT_DEVICE"] = "TPU"
os.environ['TPU_NUM_DEVICES'] = '8'
os.environ[xenv.TPU_VISIBLE_CHIPS] = '0,1,2,3'
os.environ[xenv.TPU_PROCESS_BOUNDS] = '1,1,1'
xr.use_spmd()

num_devices = xr.global_runtime_device_count()  # 8
model_axis = 1
assert xr.device_type() == 'TPU', "Only TPU is supported"
# dcn_axis = model_args.spmd_dcn_parallelism  # 1
dcn_axis = 1
data_axis = num_devices // model_axis // dcn_axis
# Mesh data setup: logical axes are (dcn, data, model)
ici_mesh_shape = (1, data_axis, model_axis)
dcn_mesh_shape = (dcn_axis, 1, 1)
axis_names = ('dcn', 'data', 'model')
print('ici', ici_mesh_shape)
print('dcn', dcn_mesh_shape)
# Note that we do not pass the spmd_mesh to the model because it is not JSON-serializable.
spmd_mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=axis_names)

Full error:

ici (1, 8, 1)
dcn (1, 1, 1)
---------------------------------------------------------------------------
NotImplementedError                       Traceback (most recent call last)
Cell In[28], line 41
     39 print('dcn', dcn_mesh_shape)
     40 # Note that we do not pass the spmd_mesh to the model because it is not JSON-serializable.
---> 41 spmd_mesh = HybridMesh(ici_mesh_shape=ici_mesh_shape, dcn_mesh_shape=dcn_mesh_shape, axis_names=axis_names)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py:188, in HybridMesh.__init__(self, ici_mesh_shape, dcn_mesh_shape, axis_names)
    185   mesh = self._create_hybrid_device_mesh(self.ici_mesh_shape,
    186                                          self.dcn_mesh_shape)
    187 else:
--> 188   mesh = self._create_device_mesh(self.ici_mesh_shape)
    189 device_ids = mesh.flatten()
    190 super().__init__(device_ids, mesh_shape, axis_names)

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py:323, in HybridMesh._create_device_mesh(self, mesh_shape, devices)
    319   raise ValueError(
    320       f'Number of devices {len(devices)} must equal the product '
    321       f'of mesh_shape {mesh_shape}')
    322 physical_mesh = self._get_physical_tpu_mesh(devices)
--> 323 device_mesh, assignment = self._create_device_mesh_for_nd_torus(
    324     physical_mesh, mesh_shape)
    325 return device_mesh

File /usr/local/lib/python3.10/site-packages/torch_xla/distributed/spmd/xla_sharding.py:286, in HybridMesh._create_device_mesh_for_nd_torus(self, physical_mesh, mesh_shape)
    282   else:
    283     # If the num_axes for loop did not break, i.e. none of the candidates work
    284     # goto here with this while-else construct.
    285     if logical_axis_size > 1:
--> 286       raise NotImplementedError(
    287           'Failed to find assignment for logical_axis_index'
    288           f' {logical_axis_index} of size {logical_axis_size} with remaining'
    289           f' assignable mesh {assignable_physical_mesh}. The size of each'
    290           ' axis in your logical mesh must be equal to the product of'
    291           ' some subset of the physical mesh axis sizes. E.g logical mesh (4,'
    292           ' 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.'
    293       )
    294 # Flatten the assignment
    295 transpose: List[int] = []

NotImplementedError: Failed to find assignment for logical_axis_index 1 of size 8 with remaining assignable mesh [2, 2, 0]. The size of each axis in your logical mesh must be equal to the product of some subset of the physical mesh axis sizes. E.g logical mesh (4, 16) is compatible with physical mesh 4x4x4 since 4=4 and 16=4x4.

The Kaggle TPU v3-8 has 8 cores (2x4), so I don't understand why I get this error. What is the problem? Thanks for your help!
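The rule quoted in the error message can be checked by hand. Below is a simplified sketch (assign_logical_axes is a hypothetical helper, not the torch_xla implementation) that greedily gives each logical axis of size greater than 1 a set of unused physical axes whose sizes multiply to the logical size. It accepts the (4, 16) on 4x4x4 example from the message, rejects the logical mesh (1, 8, 1) from the code above when the physical mesh has only 2x2x1 = 4 slots (consistent with the remaining assignable mesh [2, 2, 0] in the error), and accepts it on 2x2x2.

```python
import itertools
import math


def assign_logical_axes(logical_shape, physical_shape):
    """Simplified greedy check of the rule in the error message (sketch).

    Each logical axis of size > 1 must equal the product of the sizes of some
    still-unused physical axes. Returns, per logical axis, the tuple of
    physical axis indices it uses, or None if some axis cannot be placed.
    """
    remaining = list(physical_shape)
    assignment = []
    for size in logical_shape:
        if size == 1:
            assignment.append(())  # size-1 axes need no physical axis
            continue
        free = [i for i, s in enumerate(remaining) if s > 0]
        match = None
        # Try combinations of more physical axes first.
        for k in range(len(free), 0, -1):
            match = next((c for c in itertools.combinations(free, k)
                          if math.prod(remaining[i] for i in c) == size), None)
            if match is not None:
                break
        if match is None:
            return None  # torch_xla raises the NotImplementedError shown above here
        for i in match:
            remaining[i] = 0  # mark this physical axis as used
        assignment.append(match)
    return assignment


print(assign_logical_axes((4, 16), (4, 4, 4)))    # [(0,), (1, 2)]
print(assign_logical_axes((1, 8, 1), (2, 2, 1)))  # None
print(assign_logical_axes((1, 8, 1), (2, 2, 2)))  # [(), (0, 1, 2), ()]
```

A size-8 logical axis needs physical axes whose sizes multiply to 8, which a mesh with only 4 slots cannot provide.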

JackCaoG commented 1 month ago

@jonb377 any ideas?

jonb377 commented 1 month ago

We rely on the device coordinates to determine the physical topology, but v3 reports the same coordinates for both cores on a chip and distinguishes them only by core_on_chip:

# On v3-8
>>> xr.global_runtime_device_attributes()
[{'core_on_chip': 0, 'coords': [0, 0, 0], 'num_cores': 1, 'name': 'TPU:0'},
{'core_on_chip': 1, 'num_cores': 1, 'coords': [0, 0, 0], 'name': 'TPU:1'},
{'coords': [1, 0, 0], 'core_on_chip': 0, 'num_cores': 1, 'name': 'TPU:2'},
{'coords': [1, 0, 0], 'core_on_chip': 1, 'num_cores': 1, 'name': 'TPU:3'},
{'core_on_chip': 0, 'num_cores': 1, 'coords': [0, 1, 0], 'name': 'TPU:4'},
{'num_cores': 1, 'coords': [0, 1, 0], 'core_on_chip': 1, 'name': 'TPU:5'},
{'num_cores': 1, 'coords': [1, 1, 0], 'core_on_chip': 0, 'name': 'TPU:6'},
{'core_on_chip': 1, 'num_cores': 1, 'coords': [1, 1, 0], 'name': 'TPU:7'}]

It looks like we would need to special-case v2 and v3, as JAX does: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L545-L551
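The collision can be reproduced without a TPU. This sketch is not torch_xla's code: it takes the eight device dicts printed above and compares two hypothetical helpers, mesh_from_coords (one slot per coordinate, so the two cores of a chip land in the same slot) and mesh_from_coords_and_core (use core_on_chip as the third physical axis, which is roughly the idea behind the linked JAX special case). List positions stand in for device ids.

```python
import numpy as np

# Device attributes reported on a v3-8 (copied from the output above, 'num_cores' omitted).
V3_8_DEVICES = [
    {'core_on_chip': 0, 'coords': [0, 0, 0], 'name': 'TPU:0'},
    {'core_on_chip': 1, 'coords': [0, 0, 0], 'name': 'TPU:1'},
    {'core_on_chip': 0, 'coords': [1, 0, 0], 'name': 'TPU:2'},
    {'core_on_chip': 1, 'coords': [1, 0, 0], 'name': 'TPU:3'},
    {'core_on_chip': 0, 'coords': [0, 1, 0], 'name': 'TPU:4'},
    {'core_on_chip': 1, 'coords': [0, 1, 0], 'name': 'TPU:5'},
    {'core_on_chip': 0, 'coords': [1, 1, 0], 'name': 'TPU:6'},
    {'core_on_chip': 1, 'coords': [1, 1, 0], 'name': 'TPU:7'},
]


def mesh_from_coords(devices):
    """One slot per distinct chip coordinate; cores sharing a chip collide."""
    dims = tuple(max(d['coords'][i] for d in devices) + 1 for i in range(3))
    mesh = np.full(dims, -1, dtype=int)
    for index, d in enumerate(devices):
        mesh[tuple(d['coords'])] = index  # the second core on a chip overwrites the first
    return mesh


def mesh_from_coords_and_core(devices):
    """v2/v3 special case: use core_on_chip as the third physical axis."""
    dims_xy = tuple(max(d['coords'][i] for d in devices) + 1 for i in range(2))
    cores_per_chip = max(d['core_on_chip'] for d in devices) + 1
    mesh = np.full(dims_xy + (cores_per_chip,), -1, dtype=int)
    for index, d in enumerate(devices):
        x, y, _ = d['coords']
        mesh[x, y, d['core_on_chip']] = index
    return mesh


naive = mesh_from_coords(V3_8_DEVICES)
fixed = mesh_from_coords_and_core(V3_8_DEVICES)
print(naive.shape, naive.size)  # (2, 2, 1) 4
print(fixed.shape, fixed.size)  # (2, 2, 2) 8
```

With only coordinates, 8 devices map onto a 2x2x1 mesh with 4 slots, so a logical axis of size 8 has nowhere to go; adding core_on_chip as an axis restores all 8 slots.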

@manh3152924 Since you're running single-slice, can you try using Mesh instead of HybridMesh for now? HybridMesh is primarily for the multislice use case, which is supported on v4+.
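For a single slice, a plain Mesh only needs the flat device ids, the logical mesh shape, and the axis names; it does not look up the physical topology, so the v3 coordinate issue does not come up. A minimal sketch, with a hypothetical helper single_slice_mesh_args that keeps a size-1 'dcn' axis so that partition specs written for the hybrid layout, such as (('dcn', 'data'), None), still name valid axes:

```python
import numpy as np


def single_slice_mesh_args(num_devices, model_axis=1):
    """Arguments for a plain (non-hybrid) Mesh on one slice (sketch).

    Keeps a size-1 'dcn' axis so sharding specs written for the HybridMesh
    layout still refer to valid axis names.
    """
    if num_devices % model_axis != 0:
        raise ValueError(f'{num_devices} devices cannot be split with model_axis={model_axis}')
    data_axis = num_devices // model_axis
    device_ids = np.arange(num_devices)
    mesh_shape = (1, data_axis, model_axis)
    axis_names = ('dcn', 'data', 'model')
    return device_ids, mesh_shape, axis_names


device_ids, mesh_shape, axis_names = single_slice_mesh_args(8)
print(mesh_shape)  # (1, 8, 1)

# On the TPU host (requires torch_xla), the call would then be roughly:
#   import torch_xla.runtime as xr
#   from torch_xla.distributed.spmd.xla_sharding import Mesh
#   xr.use_spmd()
#   device_ids, mesh_shape, axis_names = single_slice_mesh_args(xr.global_runtime_device_count())
#   spmd_mesh = Mesh(device_ids, mesh_shape, axis_names)
```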

manh3152924 commented 1 month ago

@jonb377 I used Mesh before and ran out of memory; you can check my Kaggle notebook here. My goal is to use all 8 TPU cores and their combined memory, but it seems to use only one core with 16 GB of RAM. Do you have any suggestions?

jonb377 commented 1 month ago

I don't seem to have access to the Kaggle notebook. Could you please share it, or a minimal repro?

mxmm2123 commented 1 month ago

@jonb377 Sorry, my mistake. It is public now. You can check it again here.

mxmm2123 commented 1 month ago

@jonb377, do you have any ideas?

jonb377 commented 1 month ago

Thanks for sharing, I've had a look and made the following observations:

  1. The runtime OOM is coming from the dataloader.
  2. Likely the issue: the program is mixing xmp.spawn with SPMD, which isn't supported. We should call train directly from the main process with xr.use_spmd().

I wasn't able to run a repro, but could you try invoking train directly instead of through xmp.spawn?

if __name__ == '__main__':
    print('Load')
    train_dataloader, tokenizer, FLAGS = get_dataset()
    xr.use_spmd()
    # With SPMD, we should let one process control all devices. Multiprocessing isn't fully supported.
    #xmp.spawn(train, args=(train_dataloader, tokenizer, FLAGS))
    train(train_dataloader, tokenizer, FLAGS)

mxmm2123 commented 1 month ago

@jonb377 I changed the code to call train directly in __main__, but it still runs out of memory, which matches your point that the problem is in the dataloader. What I don't understand is that I only use about 5k-10k of the 100k samples (which is small; the code runs with 1k samples). It seems the data is not partitioned across all TPU cores, which then causes the OOM error.

jonb377 commented 1 month ago

Hey @mxmm2123, sorry for the delay - I realize the dataloader may be a red herring since we just rely on the dataloader to call mark_step.

It would be helpful to determine if the OOM is happening before or after tracing the first step, you can either do an HLO dump with the environment variable XLA_FLAGS=--xla_dump_to=/tmp/xla_dump or just add a print in the train loop:

        for step, data in enumerate(xla_train_loader):
            print(f'tracing {step=}')
            input_ids, labels, attention_mask = data.input_ids, data.labels, data.attention_mask
            optimizer.zero_grad()
            ...
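If you prefer the HLO dump, the flag can also be set from Python. A minimal sketch, assuming XLA_FLAGS is read when the XLA runtime starts (so it is set before torch_xla is imported) and that compiled graphs are written as files to the dump directory; dumped_modules is a hypothetical helper:

```python
import os
from pathlib import Path

DUMP_DIR = '/tmp/xla_dump'

# Append the dump flag to any flags already set; do this before importing torch_xla.
os.environ['XLA_FLAGS'] = (os.environ.get('XLA_FLAGS', '') + f' --xla_dump_to={DUMP_DIR}').strip()


def dumped_modules(dump_dir=DUMP_DIR):
    """Names of the files XLA wrote to dump_dir (empty if nothing was dumped yet)."""
    path = Path(dump_dir)
    if not path.is_dir():
        return []
    return sorted(p.name for p in path.iterdir() if p.is_file())


# ... import torch_xla and run the training loop here ...
# After the OOM:
# print(len(dumped_modules()), 'files in', DUMP_DIR)
```

An empty directory after the failure suggests nothing was compiled yet; files there only show that at least one graph was compiled, so compare them with the tracing prints to tell whether the train step itself got that far.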

Assuming the OOM is happening after tracing the first train step:

mxmm2123 commented 1 month ago

@jonb377 Thanks. As you recommended, it worked when I decreased the batch size to 32. Basically, the problem was solved by calling train directly in __main__ and using a smaller batch size. Since xr.use_spmd() controls all devices from one process, do we still need a DistributedSampler?

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd.xla_sharding as xs
from transformers import DataCollatorWithPadding

# data_train, tokenizer, FLAGS and mesh are defined elsewhere in the notebook.
train_sampler = torch.utils.data.distributed.DistributedSampler(
    data_train,
    num_replicas=xm.xrt_world_size(),  # equals 1 when using xr.use_spmd()
    rank=xm.get_ordinal(),             # 0: there is only one process
    shuffle=True)  # intended to distribute data across the 8 cores

rng = torch.Generator().manual_seed(42)
training_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=FLAGS['BATCH_SIZE'],
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    sampler=train_sampler,
    drop_last=True,
    generator=rng)

sharding_spec = xs.ShardingSpec(mesh, (('dcn', 'data'), None))
xla_train_loader = pl.MpDeviceLoader(
    training_loader,
    device=xm.xla_device(),  # a single device handle under SPMD
    input_sharding=sharding_spec,
    device_prefetch_size=4)

jonb377 commented 1 month ago

Great! You can try increasing the batch size up to maybe 128. If you hit the runtime OOM again, you'll just need to decrease it.

With single-host SPMD, the main process will load the global batch and distribute it across all local devices, so you don't need the DistributedSampler. input_sharding in MpDeviceLoader is responsible for sharding the global batch across all devices.
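To make the global-batch behaviour concrete: with the partition spec (('dcn', 'data'), None) on the (1, 8, 1) mesh, dimension 0 of each input is split across dcn x data = 8 devices and dimension 1 is replicated, so BATCH_SIZE is the global batch and each core sees BATCH_SIZE / 8 rows. A pure-Python sketch of that arithmetic (local_shard_shape is a hypothetical helper, the sequence length 512 is an assumed example, and the helper simply rejects dimensions that do not split evenly):

```python
import math


def local_shard_shape(global_shape, mesh_shape, axis_names, partition_spec):
    """Per-device shape of a tensor sharded with a partition spec (sketch).

    partition_spec has one entry per tensor dimension: None (replicated),
    a mesh axis name, or a tuple of axis names whose sizes are multiplied.
    """
    axis_sizes = dict(zip(axis_names, mesh_shape))
    local = []
    for dim, spec in zip(global_shape, partition_spec):
        if spec is None:
            shards = 1
        elif isinstance(spec, str):
            shards = axis_sizes[spec]
        else:
            shards = math.prod(axis_sizes[name] for name in spec)
        if dim % shards:
            raise ValueError(f'dimension of size {dim} does not split evenly into {shards} shards')
        local.append(dim // shards)
    return tuple(local)


mesh_shape = (1, 8, 1)
axis_names = ('dcn', 'data', 'model')
spec = (('dcn', 'data'), None)
print(local_shard_shape((32, 512), mesh_shape, axis_names, spec))   # (4, 512)
print(local_shard_shape((128, 512), mesh_shape, axis_names, spec))  # (16, 512)
```

In the loader itself, this means the sampler can be dropped, for example by passing shuffle=True (with the same generator) to DataLoader instead of sampler=train_sampler.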

mxmm2123 commented 1 month ago

@jonb377 Thank you, it's useful to me.