huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

Training of GPT-2 hangs during checkpoint stage #28662

Closed: jchauhan closed this issue 6 months ago

jchauhan commented 8 months ago

System Info

Env

- `transformers` version: 4.38.0.dev0
- Platform: Linux-5.4.0-1043-gcp-x86_64-with-glibc2.31
- Python version: 3.10.0
- Huggingface_hub version: 0.20.3
- Safetensors version: 0.4.2
- Accelerate version: 0.26.1
- Accelerate config:    not found
- PyTorch version (GPU?): 2.1.2+cu121 (False)
- Tensorflow version (GPU?): not installed (NA)
- Flax version (CPU?/GPU?/TPU?): not installed (NA)
- Jax version: not installed
- JaxLib version: not installed
- Using GPU in script?: TPU
- Using distributed or parallel set-up in script?: xla_spawn script

GCP TPU v2-8 architecture

Libraries installed

absl-py                  2.1.0
accelerate               0.26.1
aiohttp                  3.9.1
aiosignal                1.3.1
annotated-types          0.6.0
asttokens                2.4.1
async-timeout            4.0.3
attrs                    23.2.0
bitsandbytes             0.42.0
cachetools               5.3.2
certifi                  2023.11.17
charset-normalizer       3.3.2
cloud-tpu-client         0.10
datasets                 2.16.1
decorator                5.1.1
deepspeed                0.13.0
dill                     0.3.7
evaluate                 0.4.1
exceptiongroup           1.2.0
executing                2.0.1
filelock                 3.13.1
frozenlist               1.4.1
fsspec                   2023.10.0
google-api-core          1.34.0
google-api-python-client 1.8.0
google-auth              2.26.2
google-auth-httplib2     0.2.0
googleapis-common-protos 1.62.0
hjson                    3.1.0
httplib2                 0.22.0
huggingface-hub          0.20.3
idna                     3.6
install                  1.3.5
ipython                  8.20.0
jedi                     0.19.1
Jinja2                   3.1.3
joblib                   1.3.2
libtpu-nightly           0.1.dev20230825+default
loralib                  0.1.2
MarkupSafe               2.1.4
matplotlib-inline        0.1.6
mpmath                   1.3.0
multidict                6.0.4
multiprocess             0.70.15
networkx                 3.2.1
ninja                    1.11.1.1
numpy                    1.26.3
nvidia-cublas-cu12       12.1.3.1
nvidia-cuda-cupti-cu12   12.1.105
nvidia-cuda-nvrtc-cu12   12.1.105
nvidia-cuda-runtime-cu12 12.1.105
nvidia-cudnn-cu12        8.9.2.26
nvidia-cufft-cu12        11.0.2.54
nvidia-curand-cu12       10.3.2.106
nvidia-cusolver-cu12     11.4.5.107
nvidia-cusparse-cu12     12.1.0.106
nvidia-nccl-cu12         2.18.1
nvidia-nvjitlink-cu12    12.3.101
nvidia-nvtx-cu12         12.1.105
oauth2client             4.1.3
packaging                23.2
pandas                   2.2.0
parso                    0.8.3
peft                     0.7.2.dev0
pexpect                  4.9.0
pillow                   10.2.0
pip                      21.2.3
prompt-toolkit           3.0.43
protobuf                 3.20.3
psutil                   5.9.8
ptyprocess               0.7.0
pure-eval                0.2.2
py-cpuinfo               9.0.0
pyarrow                  15.0.0
pyarrow-hotfix           0.6
pyasn1                   0.5.1
pyasn1-modules           0.3.0
pydantic                 2.5.3
pydantic_core            2.14.6
Pygments                 2.17.2
pynvml                   11.5.0
pyparsing                3.1.1
python-dateutil          2.8.2
pytz                     2023.3.post1
PyYAML                   6.0.1
regex                    2023.12.25
requests                 2.31.0
responses                0.18.0
rsa                      4.9
safetensors              0.4.2
scikit-learn             1.4.0
scipy                    1.12.0
setuptools               57.4.0
six                      1.16.0
sklearn                  0.0
stack-data               0.6.3
sympy                    1.12
threadpoolctl            3.2.0
tokenizers               0.15.1
torch                    2.1.2
torch-xla                2.1.0
torchvision              0.16.2
tqdm                     4.66.1
traitlets                5.14.1
transformers             4.38.0.dev0
triton                   2.1.0
typing_extensions        4.9.0
tzdata                   2023.4
uritemplate              3.0.1
urllib3                  2.1.0
wcwidth                  0.2.13
xxhash                   3.4.1
yarl                     1.9.4

Who can help?

- text models: @ArthurZucker and @younesbelkada
- trainer: @muellerzr and @pacman100

Reproduction

  1. Procure a GCP TPU v2-8 VM.
  2. Set up Transformers in a virtual environment.
  3. Run a training command similar to the one below:
```bash
export PJRT_DEVICE=TPU
python ./transformers/examples/pytorch/xla_spawn.py --num_cores 8 ./transformers/examples/pytorch/language-modeling/run_clm.py \
    --model_name_or_path "gpt2" \
    --train_file data.txt \
    --per_device_train_batch_size 2 \
    --per_device_eval_batch_size 2 \
    --do_train \
    --output_dir my-gpt \
    --overwrite_output_dir \
    --log_level debug \
    --save_steps 1000 \
    --cache_dir ./cache/ \
    --num_train_epochs 40
```
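To find out where each process is stuck once checkpointing hangs, one option is Python's standard-library `faulthandler` module, which can periodically dump the stack of every thread in a process. The sketch below is not part of `run_clm.py`; `enable_hang_dumps` and `disable_hang_dumps` are hypothetical helper names, and the 600-second default interval is an arbitrary assumption, chosen to be longer than the roughly 5 minutes of training so that dumps mostly appear once the run is stuck.

```python
import faulthandler


def enable_hang_dumps(interval_s: float = 600.0) -> None:
    # Hypothetical helper: dump the Python stack of every thread to stderr
    # every `interval_s` seconds. With repeat=True the dump fires on a timer
    # regardless of progress, so pick an interval longer than normal training.
    faulthandler.dump_traceback_later(interval_s, repeat=True)


def disable_hang_dumps() -> None:
    # Cancel the pending timed dump, e.g. once training has finished cleanly.
    faulthandler.cancel_dump_traceback_later()
```

Assuming you are free to edit your local copy of `run_clm.py`, calling `enable_hang_dumps()` at the start of its `main()` function would make each of the 8 processes started by `xla_spawn.py` write its own stacks to stderr; the frames should show whether a process is blocked inside the save path or waiting on the other processes.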

Expected behavior

The training run, including checkpointing, should complete within a reasonable time (about 15 minutes). Training itself takes about 5 minutes; however, checkpointing and saving the model do not complete.
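The expectation above (the whole run finishing in roughly 15 minutes) can be turned into an automatic check, for example when trying different library versions. A minimal sketch using the standard `subprocess` module; `run_with_deadline` is a hypothetical helper, and treating "still running after the deadline" as a hang is an assumption of this sketch.

```python
import subprocess
import sys


def run_with_deadline(cmd: list[str], deadline_s: float) -> bool:
    """Run `cmd`; return True if it exits within `deadline_s` seconds.

    The exit code is not checked here; this only detects runs that never finish.
    """
    try:
        subprocess.run(cmd, timeout=deadline_s, check=False)
        return True
    except subprocess.TimeoutExpired:
        # subprocess.run kills the direct child and waits for it
        # before re-raising TimeoutExpired.
        print(f"still running after {deadline_s} s; treating as a hang", file=sys.stderr)
        return False
```

For example, `run_with_deadline(["bash", "train.sh"], 15 * 60)`, where `train.sh` is a hypothetical script holding the command from the reproduction steps. Only the direct child is killed on timeout, so processes it started (such as the 8 TPU workers) may keep running and need separate cleanup.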

ArthurZucker commented 8 months ago

I would recommend checking #26724 and trying the solution there; it might be the same problem. If saving still does not work, the cause may be concurrency in the saving code, which was recently changed. cc @muellerzr 🤗
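One way a concurrency problem can turn into a silent hang is a save step that ends in a cross-process rendezvous but is only reached by some of the processes. The sketch below is a generic, stdlib-only simulation, not Trainer or torch_xla code: `simulate_save` is a hypothetical name, threads stand in for the 8 TPU processes, and a `threading.Barrier` with a timeout stands in for the rendezvous so that the hang shows up as a timeout instead.

```python
import threading


def simulate_save(num_ranks: int, only_master_calls_save: bool, timeout: float = 0.5) -> list[str]:
    """Simulate a collective checkpoint save across `num_ranks` workers.

    Every rank is expected to reach the rendezvous; only rank 0 would write.
    """
    rendezvous = threading.Barrier(num_ranks, timeout=timeout)
    results = ["" for _ in range(num_ranks)]

    def worker(rank: int) -> None:
        if only_master_calls_save and rank != 0:
            # Non-master ranks skip the save call entirely, so they never
            # arrive at the rendezvous inside it.
            results[rank] = "skipped save"
            return
        # Rank 0 would write the checkpoint file at this point (omitted).
        try:
            rendezvous.wait()  # every rank must arrive before anyone continues
            results[rank] = "saved"
        except threading.BrokenBarrierError:
            # Raised when the wait times out because other ranks never arrived.
            results[rank] = "timed out waiting for other ranks"

    threads = [threading.Thread(target=worker, args=(r,)) for r in range(num_ranks)]
    for t in threads:
        t.start()
    for t in threads:
        t.join()
    return results
```

With `only_master_calls_save=False`, every rank reports `"saved"`; with `True`, rank 0 reports a timeout and the others report `"skipped save"`. If the real rendezvous has no timeout, the same mismatch would look like an indefinite hang at the checkpoint step; whether that is what happens in this issue is not established here.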

github-actions[bot] commented 7 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.