Open manh3152924 opened 1 month ago
@jonb377 any ideas?
We depend on the device coordinates to determine the physical topology, and v3 will report the same coordinates for each core with a different core_on_chip:
# On v3-8
>>> xr.global_runtime_device_attributes()
[{'core_on_chip': 0, 'coords': [0, 0, 0], 'num_cores': 1, 'name': 'TPU:0'},
{'core_on_chip': 1, 'num_cores': 1, 'coords': [0, 0, 0], 'name': 'TPU:1'},
{'coords': [1, 0, 0], 'core_on_chip': 0, 'num_cores': 1, 'name': 'TPU:2'},
{'coords': [1, 0, 0], 'core_on_chip': 1, 'num_cores': 1, 'name': 'TPU:3'},
{'core_on_chip': 0, 'num_cores': 1, 'coords': [0, 1, 0], 'name': 'TPU:4'},
{'num_cores': 1, 'coords': [0, 1, 0], 'core_on_chip': 1, 'name': 'TPU:5'},
{'num_cores': 1, 'coords': [1, 1, 0], 'core_on_chip': 0, 'name': 'TPU:6'},
{'core_on_chip': 1, 'num_cores': 1, 'coords': [1, 1, 0], 'name': 'TPU:7'}]
It looks like we would need to special-case v2 and v3 like JAX does: https://github.com/google/jax/blob/main/jax/experimental/mesh_utils.py#L545-L551
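As a rough sketch (not the actual fix, just an illustration of the idea), the attributes above could be grouped by chip coordinates with core_on_chip as an extra axis, similar to what the JAX helper does:

from collections import defaultdict
import torch_xla.runtime as xr

attrs = xr.global_runtime_device_attributes()
by_chip = defaultdict(dict)
for a in attrs:
    # Two v3 cores share the same 'coords'; 'core_on_chip' distinguishes them.
    by_chip[tuple(a['coords'])][a['core_on_chip']] = a['name']

# On a v3-8 this yields a 2x2x1 grid of chips with 2 cores each.
for coords, cores in sorted(by_chip.items()):
    print(coords, [cores[c] for c in sorted(cores)])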
@manh3152924 Since you're running single-slice, can you try using Mesh instead of HybridMesh for now? HybridMesh is primarily for the multislice use case, which is supported on v4+.
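A minimal single-slice setup with Mesh might look like the sketch below (the axis names and mesh shape are just an example; on older torch_xla versions the sharding module may be torch_xla.experimental.xla_sharding rather than torch_xla.distributed.spmd):

import numpy as np
import torch_xla.runtime as xr
import torch_xla.distributed.spmd as xs

xr.use_spmd()

# 8 devices on a v3-8; build a simple data-parallel mesh over all of them.
num_devices = xr.global_runtime_device_count()
device_ids = np.arange(num_devices)
mesh = xs.Mesh(device_ids, mesh_shape=(num_devices, 1), axis_names=('data', 'model'))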
@jonb377 I used Mesh before and got a not-enough-RAM error; you can check my Kaggle notebook here. My goal is to use all 8 TPU cores and their combined memory, but it seems only one core with 16 GB of RAM is being used. Do you have any suggestions?
I don't seem to have access to the Kaggle notebook; could you please share that or a minimal repro?
@jonb377 Sorry, my mistake. It is public now. You can re-check it here
@jonb377, do you have any ideas?
Thanks for sharing, I've had a look and made the following observations:
The notebook calls xmp.spawn with SPMD, which isn't supported. We should call train directly from the main process with xr.use_spmd().
I wasn't able to run a repro, but could you try invoking train directly instead of through xmp.spawn?
if __name__ == '__main__':
    print('Load')
    train_dataloader, tokenizer, FLAGS = get_dataset()
    xr.use_spmd()
    # With SPMD, we should let one process control all devices. Multiprocessing isn't fully supported.
    # xmp.spawn(train, args=(train_dataloader, tokenizer, FLAGS))
    train(train_dataloader, tokenizer, FLAGS)
@jonb377 I changed the code to call train directly in __main__, but it still hits the OOM that you said is a problem with the dataloader. I don't fully understand, though: I only use about 5k-10k of the 100k samples (which is actually small; the code runs with 1k samples), so it seems the data is not being partitioned across all TPU cores, and then I get an OOM error.
Hey @mxmm2123, sorry for the delay - I realize the dataloader may be a red herring since we just rely on the dataloader to call mark_step.
It would be helpful to determine whether the OOM happens before or after tracing the first step. You can either do an HLO dump with the environment variable XLA_FLAGS=--xla_dump_to=/tmp/xla_dump or just add a print in the train loop:
for step, data in enumerate(xla_train_loader):
    print(f'tracing {step=}')
    input_ids, labels, attention_mask = data.input_ids, data.labels, data.attention_mask
    optimizer.zero_grad()
    ...
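(A sketch of the env-var option, assuming a notebook setup rather than a launch script: the flag needs to be set before torch_xla is imported, e.g.)

import os
# Set before torch_xla/XLA initializes so the compiler picks up the dump flag.
os.environ['XLA_FLAGS'] = '--xla_dump_to=/tmp/xla_dump'
import torch_xla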
Assuming the OOM is happening after tracing the first train step, you can try reducing the batch size.
@jonb377 Thanks. As you recommended, it worked when I decreased the batch size to 32. Basically, the problem was solved by calling train directly in __main__ and using a smaller batch size. One more question: if xr.use_spmd() controls all devices in one process, is DistributedSampler still necessary?
train_sampler = torch.utils.data.distributed.DistributedSampler(
    data_train,
    num_replicas=xm.xrt_world_size(),  # Equals 1 when using xr.use_spmd()
    rank=xm.get_ordinal(),             # 1
    shuffle=True)  # this is responsible for distributing data across 8 cores

rng = torch.Generator().manual_seed(42)
training_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=FLAGS['BATCH_SIZE'],
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    sampler=train_sampler,
    drop_last=True,
    generator=rng)

sharding_spec = xs.ShardingSpec(mesh, (('dcn', 'data'), None))
xla_train_loader = pl.MpDeviceLoader(
    training_loader,
    device=xm.xla_device(),  # It equals 1
    input_sharding=sharding_spec,
    device_prefetch_size=4)
Great! You can try increasing the batch size up to maybe 128. If you hit the runtime OOM again, you'll just need to decrease it.
With single-host SPMD, the main process will load the global batch and distribute it across all local devices, so you don't need the DistributedSampler. input_sharding in MpDeviceLoader is responsible for sharding the global batch across all devices.
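A sketch of that setup, reusing data_train, FLAGS, tokenizer, and mesh from your notebook (the partition spec's axis name has to match whatever mesh you build):

import torch
import torch_xla.core.xla_model as xm
import torch_xla.distributed.parallel_loader as pl
import torch_xla.distributed.spmd as xs
from transformers import DataCollatorWithPadding

# Plain DataLoader, no DistributedSampler; the main process loads the global batch.
training_loader = torch.utils.data.DataLoader(
    data_train,
    batch_size=FLAGS['BATCH_SIZE'],  # global batch size across all 8 cores
    collate_fn=DataCollatorWithPadding(tokenizer=tokenizer),
    shuffle=True,
    drop_last=True)

# Shard the batch dimension of each input across the 'data' axis of the mesh.
sharding_spec = xs.ShardingSpec(mesh, ('data', None))
xla_train_loader = pl.MpDeviceLoader(
    training_loader,
    device=xm.xla_device(),
    input_sharding=sharding_spec)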
@jonb377 Thank you, that's very helpful.
❓ Questions and Help
I received an error when trying to create an SPMD mesh in a Kaggle notebook while following the Hugging Face optimum-tpu flow.
full error:
The Kaggle TPU v3-8 has 8 cores (2x4), so I don't know why I get this error. What is the problem? Thanks for your help!