YUE-FAN opened this issue 3 months ago
+1
My current workaround is to catch the FileNotFoundError and recreate the data_iterator. To get back to the batch you were on, you can then run load_next_batch multiple times:
import datetime

# create_data_iterator and load_next_batch are MaxText's training helpers.
def create_skipped_iterator(config, mesh, step):
    while True:
        print(f'Starting skipping: step={step}, time={datetime.datetime.now()}')
        try:
            data_iterator, _ = create_data_iterator(config, mesh)
            # Replay the first `step` batches to get back to where training stopped.
            for _ in range(step):
                _ = load_next_batch(data_iterator, None, config)
            break
        except FileNotFoundError:
            print("Encountered FileNotFoundError during skipping :(, will create a new data_iterator")
            continue
    print(f'Finished skipping: step={step}, time={datetime.datetime.now()}')
    return data_iterator
but that's an error-prone and very slow solution.
That's a nice workaround :) thanks a lot!
I made a small modification so that you don't have to recreate the data iterator every time and skip all the data from the beginning. If the data iterator's state is checkpointed with Grain, your function can be made quite efficient with the change below. This code currently works for me :)
import datetime
import json

import jax
from google.cloud import storage

def create_skipped_iterator(config, mesh, step):
    while True:
        print(f'>>> Starting skipping: step={step}, time={datetime.datetime.now()}')
        try:
            # Last step at which the data iterator state was checkpointed.
            start_step = int(step // config.checkpoint_period * config.checkpoint_period)
            storage_client = storage.Client()
            bucket = storage_client.get_bucket(config.gcp_bucket)
            path = f'{config.run_name}/checkpoints/{start_step}/iter/process_{int(jax.process_index())}-of-{int(jax.process_count())}.json'
            print(f'>>> load dataloader state from path {path}')
            blob = bucket.blob(path)
            # Round-trip through json so set_state receives JSON-encoded bytes.
            state = json.loads(blob.download_as_string(client=None))
            state = json.dumps(state, indent=4).encode()
            storage_client.close()
            data_iterator, _ = create_data_iterator(config, mesh)
            # Restore the Grain iterator to the checkpointed position, see
            # https://github.com/google/grain/blob/main/grain/_src/python/data_loader_test.py#L496C28-L496C77
            data_iterator.local_iterator.set_state(state)
            print(f'>>> skipping from step={start_step + 1} to step={step - 1}')
            for _ in range(start_step + 1, step):
                _ = load_next_batch(data_iterator, None, config)
            break
        except FileNotFoundError:
            print(">>> Encountered FileNotFoundError during skipping :(, will create a new data_iterator")
            continue
    print(f'>>> Finished skipping: step={step}, time={datetime.datetime.now()}')
    return data_iterator
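For completeness, here is a minimal sketch of the save side that the function above assumes: at every checkpoint_period step, each process writes its Grain iterator state to the same GCS path the restore reads from. save_iterator_state is a hypothetical helper name, not part of MaxText:

import jax
from google.cloud import storage

def save_iterator_state(config, data_iterator, step):
    # Hypothetical helper: persist this process's Grain iterator state so it can
    # later be restored by create_skipped_iterator above.
    if step % config.checkpoint_period != 0:
        return
    state = data_iterator.local_iterator.get_state()  # JSON-encoded bytes from Grain
    path = (f'{config.run_name}/checkpoints/{step}/iter/'
            f'process_{int(jax.process_index())}-of-{int(jax.process_count())}.json')
    client = storage.Client()
    bucket = client.get_bucket(config.gcp_bucket)
    bucket.blob(path).upload_from_string(state)
    client.close()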
PS: it just occurred to me while writing this reply that maybe a better workaround is to keep a copy of the data iterator's state from the previous step inside the train loop (state = data_iterator.local_iterator.get_state()). In case of a FileNotFoundError, we just repeat this single step from that state until it passes:
example_batch = load_next_batch(data_iterator, example_batch, config)

becomes:
try:
    # Snapshot the iterator state so this single step can be replayed.
    state = data_iterator.local_iterator.get_state()
    example_batch = load_next_batch(data_iterator, example_batch, config)
except FileNotFoundError:
    while True:
        try:
            # Rewind to the snapshot and retry the step until it succeeds.
            data_iterator.local_iterator.set_state(state)
            example_batch = load_next_batch(data_iterator, None, config)
            break
        except FileNotFoundError:
            data_iterator.local_iterator.set_state(state)
            continue
Though I haven't tested this code yet...
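If it turns out to work, the same retry could be wrapped in a small helper so the train loop stays a one-liner. load_next_batch_with_retry below is just a hypothetical name for a sketch of the logic above:

def load_next_batch_with_retry(data_iterator, example_batch, config):
    # Snapshot the Grain iterator state so this single step can be replayed.
    state = data_iterator.local_iterator.get_state()
    try:
        return load_next_batch(data_iterator, example_batch, config)
    except FileNotFoundError:
        while True:
            try:
                # Rewind to the snapshot and retry just this one step.
                data_iterator.local_iterator.set_state(state)
                return load_next_batch(data_iterator, None, config)
            except FileNotFoundError:
                continue

# In the train loop:
# example_batch = load_next_batch_with_retry(data_iterator, example_batch, config)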
Hi,
I was testing multi-host training on a v4-16 TPU VM. Training normally runs smoothly, but it sometimes crashes at load_next_batch with a FileNotFoundError from process 0. The command for running the job is python3 MaxText/train.py MaxText/configs/gpt2.yml run_name=gpt2 base_output_directory=gs://maxtext_multihost_job steps=120000 dataset_type=hf hf_path=YUE-FAN/openwebtext_gcp hf_data_dir=data tokenizer_path=EleutherAI/gpt-neox-20b eval_interval=4000 hf_eval_split=validation enable_checkpointing=True eval_batch_num=558 per_device_batch_size=32 eval_per_device_batch_size=32 checkpoint_period=10000 logits_via_embedding=True normalize_embedding_logits=True.
I have very limited knowledge of Python multiprocessing, but it seems to be a problem with reading shared memory? The problem does not always occur, only from time to time. Any assistance here would be appreciated. Thanks!