Lightning-AI / pytorch-lightning

Pretrain, finetune and deploy AI models on multiple GPUs, TPUs with zero code changes.
https://lightning.ai
Apache License 2.0

TPU v3-8 deadlocks when using datasets larger than 2^15 on 8 devices #18176

Open SebastianLoef opened 1 year ago

SebastianLoef commented 1 year ago

Bug description

Our TPU v3-8 deadlocks when using all 8 TPU cores on large datasets. Specifically, this happens with datasets larger than 2^15: make the dataset one item larger than that and we get a deadlock.

The deadlock occurs somewhere between line 222 and line 235 in fit_loop.py. Everything works fine when using only a single core.

Using lightning 2.0.6 | tpu-vm-pt-2.0

What version are you seeing the problem on?

v2.0

How to reproduce the bug

import lightning as L
from torch.utils.data import DataLoader

train_dataloader = DataLoader(
    train_dataset,
    batch_size=args.batch_size,
    shuffle=True,
    num_workers=args.num_workers,
    # pin_memory=True,
    # persistent_workers=True,
    drop_last=True,
)
...

trainer = L.Trainer(
    # callbacks=checkpoint_callbacks,
    # logger=wandb_logger,
    max_epochs=args.epochs,
    accelerator=args.accelerator,
    devices=args.devices,
    precision=args.precision,
    num_sanity_val_steps=0,
    log_every_n_steps=10,
    check_val_every_n_epoch=args.check_val_every_n_epoch,
    # profiler="xla",
    strategy="xla",
)
trainer.fit(
    model,
    train_dataloaders=train_dataloader,
    # val_dataloaders=val_dataloader,
)

Error messages and logs

E0727 11:55:45.723604 1727432 coredump_hook.cc:414] RAW: Remote crash data gathering hook invoked.
E0727 11:55:45.723633 1727432 coredump_hook.cc:453] RAW: Skipping coredump since rlimit was 0 at process start.
E0727 11:55:45.723659 1727432 client.cc:278] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0727 11:55:45.723670 1727432 coredump_hook.cc:512] RAW: Sending fingerprint to remote end.
E0727 11:55:45.723677 1727432 coredump_socket.cc:120] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0727 11:55:45.723692 1727432 coredump_hook.cc:518] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0727 11:55:45.723700 1727432 coredump_hook.cc:580] RAW: Dumping core locally.
E0727 11:55:46.275513 1727432 process_state.cc:784] RAW: Raising signal 11 with default behavior

Environment

No response

More info

No response

cc @carmocca @JackCaoG @Liyang90 @gkroiz @steventk-g

awaelchli commented 1 year ago

Hi @SebastianLoef

Is this an in-memory dataset? This might be related to #18131, but I'm not sure.

SebastianLoef commented 1 year ago

@awaelchli No, our dataset consists of mp3 files that are loaded on the fly using torchaudio.load.

We tried two different datasets, and both fail at the same size: 2^15 + 1.

EDIT: The dataset file locations are stored in a pandas dataframe, if that's useful.

awaelchli commented 1 year ago

The dataset file locations are stored in a pandas dataframe, if that's useful.

Does your process get stuck at the fit() call, when processes are being forked? A quick random guess from me: try moving your dataloader into the LightningModule.train_dataloader() hook instead of passing it into the trainer.fit method. If this works, it could hint at a memory-sharing issue with the data frame when forking processes.

SebastianLoef commented 1 year ago

Your quick random guess worked! It doesn't deadlock this way. However, with this solution each core has its own complete copy of the dataset, so 1 epoch actually becomes the equivalent of 8 epochs.

awaelchli commented 1 year ago

Your quick random guess worked! It doesn't deadlock this way

Ok, that's good to know. It looks like you hit some limit with memory access in forked processes. You might want to open an issue on the xla repo about this, but my suspicion is that this could even be just a basic limitation in the multiprocessing package.

using this solution each core has its own complete copy of the dataset

I think that would also be the case if you passed your data the way you had before. Are you perhaps working with an iterable-style dataset? That would be one possible explanation. If yes: for these types of datasets, Lightning can't automatically take care of distributed sampling (happy to give more instructions if this is the case). If no: we have to look into this case further.
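As a rough idea of what manual sharding for an iterable-style source means, here is a minimal, framework-free sketch. The shard helper, rank, and world_size are hypothetical names used for illustration; this is not Lightning's or PyTorch's API, and in a real job the rank and world size would come from the distributed environment.

```python
from itertools import islice
from typing import Iterable, Iterator, TypeVar

T = TypeVar("T")


def shard(stream: Iterable[T], rank: int, world_size: int) -> Iterator[T]:
    """Yield every world_size-th item of the stream, starting at offset rank."""
    return islice(stream, rank, None, world_size)


# Each of 4 hypothetical processes keeps a disjoint slice of the same stream.
shards = [list(shard(range(10), rank, 4)) for rank in range(4)]
```

Without a step like this, every process iterates over the full stream, which is the "each core has its own complete copy" situation described above.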

SebastianLoef commented 1 year ago

If no: we have to look into this case further.

It's a no, we use a map-style dataset... But it might be that we're interpreting the progress bar wrong. If I remember correctly, when passing the dataloader into the fit function, the number of steps displayed in the progress bar is len(dataset) / (batch_size * devices), and with this solution it shows len(dataset) / batch_size steps. But I may be remembering incorrectly, I'll double check tonight.
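For reference, the two step counts being compared can be worked out with a small sketch. It assumes the data is split evenly across devices with padding (as a distributed sampler with default settings would do) and that incomplete batches are dropped, as with drop_last=True in the report. The batch size of 32 is a hypothetical value; the thread does not state the real one.

```python
import math


def steps_per_epoch(dataset_len: int, batch_size: int, devices: int,
                    drop_last: bool = True) -> int:
    # With padding, every device gets ceil(dataset_len / devices) samples.
    per_device = math.ceil(dataset_len / devices)
    if drop_last:
        # Incomplete final batch is discarded.
        return per_device // batch_size
    return math.ceil(per_device / batch_size)


# Data sharded across 8 devices: roughly len(dataset) / (batch_size * devices).
sharded = steps_per_epoch(2**15 + 1, batch_size=32, devices=8)

# Every device iterating the full dataset: roughly len(dataset) / batch_size.
unsharded = steps_per_epoch(2**15 + 1, batch_size=32, devices=1)
```

Under these assumptions, sharded is 128 and unsharded is 1024, so a progress bar total that is 8 times larger than expected would indicate that the data is not being split across cores.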

SebastianLoef commented 1 year ago

I was wrong, I apologize. The number of steps in the progressbar reads the same for both solutions.

Continuing this thread though, we ran into another problem that is possibly related. Using 8 cores does not work when using pretrained weights from torchvision. As usual, it works with 1 core.

from torchvision.models import resnet50, ResNet50_Weights
backbone = resnet50(weights=ResNet50_Weights.IMAGENET1K_V2)

It works fine when we remove the weights argument.

SebastianLoef commented 1 year ago

Update.

This is not only happening for torchvision models. Loading any model state_dict breaks using 8 cores.