pytorch / torchtune

PyTorch native finetuning library
https://pytorch.org/torchtune/main/
BSD 3-Clause "New" or "Revised" License

load_dataset fails on distributed recipes for datasets with remote code #1178

Open pbontrager opened 3 months ago

pbontrager commented 3 months ago

Description

When you load a dataset with remote code from the Hugging Face Hub, load_dataset prompts the user for permission to run that code. The prompt only appears the first time the user downloads the dataset, but it causes a crash if the user's first use of the dataset is in a distributed recipe. The likely cause is that load_dataset is called on every distributed process, while only rank 0 actually shows the prompt to the user. The user grants permission, but all the other ranks hang waiting for a response, which causes torch.distributed to crash.
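One common way to keep every rank from hitting the first download at once is to let rank 0 load the dataset first while the other ranks wait at a barrier, and only then have them load from the now-populated cache. The sketch below is not torchtune's code: `rank_zero_first` and `simulate` are hypothetical names, and the approach relies on the observation above that the prompt only appears on the first download. To stay runnable without GPUs, it simulates ranks with threads and uses a `threading.Barrier` in place of `torch.distributed.barrier`.

```python
import threading
from typing import Any, Callable, Dict, List, Tuple


def rank_zero_first(rank: int, barrier: Callable[[], object], load_fn: Callable[[], Any]) -> Any:
    """Hypothetical helper: run load_fn on rank 0 before any other rank.

    Every rank calls barrier() exactly once, as a collective barrier requires.
    """
    if rank == 0:
        result = load_fn()  # first download; any permission prompt appears only here
        barrier()           # release the other ranks once the cache is populated
    else:
        barrier()           # wait until rank 0 has finished downloading
        result = load_fn()  # served from the local cache, so no prompt is expected
    return result


def simulate(world_size: int = 4) -> Tuple[List[Any], List[int]]:
    """Simulate world_size ranks with threads sharing a fake dataset cache."""
    cache: Dict[str, str] = {}
    downloads: List[int] = []  # ranks that actually performed the "download"
    lock = threading.Lock()
    sync = threading.Barrier(world_size)  # stands in for torch.distributed.barrier
    results: List[Any] = [None] * world_size

    def make_loader(rank: int) -> Callable[[], str]:
        def load() -> str:
            with lock:
                if "dataset" not in cache:  # cache miss: this is the first download
                    downloads.append(rank)
                    cache["dataset"] = "cached-data"
                return cache["dataset"]
        return load

    def worker(rank: int) -> None:
        results[rank] = rank_zero_first(rank, sync.wait, make_loader(rank))

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(world_size)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results, downloads


if __name__ == "__main__":
    results, downloads = simulate(4)
    print(results, downloads)
```

In a real recipe, `rank` would come from `torch.distributed.get_rank()`, `barrier` would be `torch.distributed.barrier`, and `load_fn` could be something like `functools.partial(datasets.load_dataset, source, split="train")`, where `source` is a placeholder for the dataset's Hub path. This assumes the prompt is answered through rank 0's terminal. Separately, recent versions of `datasets` accept `trust_remote_code=True` in `load_dataset`, which should skip the prompt on every rank, at the cost of never asking the user.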

Reproduce

This can be reproduced with cnn_dailymail_articles_dataset, which runs remote code.

First, make sure the dataset isn't cached. The following commands remove all cached datasets, but this is the only way to guarantee the relevant cached files are gone:

rm -r ~/.cache/huggingface/datasets
rm -r ~/.cache/huggingface/modules/datasets_modules/datasets

Then run any distributed recipe with this dataset, for example:

tune run --nproc-per-node=2 lora_finetune_distributed --config llama2/7B_lora dataset._component_=torchtune.datasets.cnn_dailymail_articles_dataset

The recipe should reach dataset initialization and then ask for permission to run remote code. Whichever response you give, the recipe crashes, because the processes are out of sync from that point on.

Possible Solutions

RdoubleA commented 3 months ago

Thanks for pointing this out @pbontrager. I've also found this is an issue with grammar_dataset. HF datasets that run remote code may be more prevalent than we thought, so we should prioritize this in the near future. I can look into it.

pbontrager commented 3 months ago

It might be worth checking what setting num_proc in load_dataset would do (see the load_dataset documentation).