huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

FLAX core dump error on CloudTPU when running run_clm_flax.py #14497

Open dumitrescustefan opened 2 years ago

dumitrescustefan commented 2 years ago

Hi, I'm having a weird problem trying to train a GPT-Neo model from scratch on a v3-8 Cloud TPU. It looks similar to the closed issue here. I'm getting:

https://symbolize.stripped_domain/r/?trace=7fb5dbf8a3f4,7fb5dbfe020f,7f&map=
*** SIGTERM received by PID 64823 (TID 64823) on cpu 26 from PID 63364; stack trace: ***
PC: @     0x7fb5dbf8a3f4  (unknown)  do_futex_wait.constprop.0
    @     0x7fb52fa377ed        976  (unknown)
    @     0x7fb5dbfe0210  440138896  (unknown)
    @               0x80  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7fb5dbf8a3f4,7fb52fa377ec,7fb5dbfe020f,7f&map=44c8b163be936ec2996e56972aa94d48:7fb521e7d000-7fb52fd90330
E1122 14:13:36.933620   64823 coredump_hook.cc:255] RAW: Remote crash gathering disabled for SIGTERM.
E1122 14:13:36.960024   64823 process_state.cc:776] RAW: Raising signal 15 with default behavior

This happens at random points while the dataset is being preprocessed/loaded.

The environment is clean, set up according to the Quickstart Flax guide on Google's help page, and also according to the instructions here. JAX is installed correctly and sees 8 TPU devices. I tried both the standard pip install and the local install that some people suggested in the issue above, and I still get the same behavior.

This error does not kill the training. So, question 1: how do I get rid of this error?

Something else happens that might be related. Training on a small 300MB Wiki dataset only produces the error above, and training progresses. However, when training on the full 40GB dataset, at some point during the first epoch I get:

list([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, .... (many 1s) .. 1, 1, 1])]]' of type <class 'numpy.ndarray'> is not a valid JAX type.

This error kills the training. I found this related issue, but its last suggestion, increasing max_seq_len, does not apply here: the preprocessing should automatically concatenate the texts and cut them into blocks of the model's maximum length (which is set in the config file). The dataset itself is clean and does not contain very long words, unusual characters, or anything else weird.
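For reference, the concatenate-and-chunk preprocessing described above can be sketched in plain Python as follows. This is a minimal sketch of the general technique, not the exact code in run_clm_flax.py; the function name `group_into_blocks` and its `block_size` parameter are illustrative. It drops the trailing remainder so that every block has exactly `block_size` tokens, which is what allows a batch of blocks to be stacked into a rectangular array.

```python
from itertools import chain


def group_into_blocks(examples, block_size):
    """Concatenate tokenized texts and split them into fixed-size blocks.

    Hypothetical helper for illustration. `examples` maps a feature name
    (e.g. "input_ids", "attention_mask") to a list of token lists, one per
    text, and must contain an "input_ids" key.
    """
    # Concatenate all texts of each feature into one long token list.
    concatenated = {k: list(chain.from_iterable(v)) for k, v in examples.items()}
    total_length = len(concatenated["input_ids"])
    # Drop the remainder that does not fill a complete block, so every
    # block has exactly block_size tokens.
    total_length = (total_length // block_size) * block_size
    result = {
        k: [tokens[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, tokens in concatenated.items()
    }
    # For causal language modeling, the labels are a copy of the inputs.
    result["labels"] = [list(block) for block in result["input_ids"]]
    return result


examples = {
    "input_ids": [[1, 2, 3], [4, 5, 6, 7, 8]],
    "attention_mask": [[1, 1, 1], [1, 1, 1, 1, 1]],
}
blocks = group_into_blocks(examples, block_size=3)
print(blocks["input_ids"])
```

With 8 tokens and `block_size=3`, the last two tokens are dropped, so the output contains two full blocks and never a short one.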

Thus, question 2: Any pointers on how to solve this second error?

Unfortunately I can't share the dataset as it's private 😞, so I'm not sure how to help reproduce this error. I've put both questions in this single issue because there's a chance they are related (?).

Thanks a bunch!

Update: here is the output of run_clm_flax.py. Because there's a limit on how much text can be pasted, I've deleted a few chunks of repeated lines from the output.

LysandreJik commented 2 years ago

Pinging @patil-suraj :)

dumitrescustefan commented 2 years ago

Update: after reading this issue, I tried setting the number of preprocessing workers to 1, and after a long time, preprocessing finished without any crashes. So that 'solves' problem 1.

However, problem 2 still shows up, so at least it isn't related to problem 1.

Here is the error:

Training...:   5%|▌         | 2319/43228 [1:05:47<19:18:26,  1.70s/it]
/home/stefan/transformers/examples/flax/language-modeling/run_clm_flax.py:202: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
  batch = {k: np.array(v) for k, v in batch.items()}

Epoch ... :   0%|          | 0/100 [1:05:47<?, ?it/s]
Traceback (most recent call last):
  File "/home/stefan/transformers/examples/flax/language-modeling/run_clm_flax.py", line 677, in <module>
    main()
  File "/home/stefan/transformers/examples/flax/language-modeling/run_clm_flax.py", line 618, in main
    state, train_metric = p_train_step(state, batch)
  File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
    return fun(*args, **kwargs)
  File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/api.py", line 1946, in cache_miss
    out_tree, out_flat = f_pmapped_(*args, **kwargs)
  File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/api.py", line 1801, in f_pmapped
    for arg in args: _check_arg(arg)
  File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/api.py", line 2687, in _check_arg
    raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument '[[list([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, ... (many 1s) ... 1, 1, 1 [etc - many lists of 1s]

Any idea what might cause this? Thanks!
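The VisibleDeprecationWarning in the log above suggests that `np.array(v)` received rows of different lengths, which NumPy turns into an object array of Python lists (recent NumPy releases raise a ValueError instead), and `pmap` then rejects that array as "not a valid JAX type". One way to surface the problem earlier, with a clearer message, is to check every feature before conversion. This is a sketch assuming each batch is a dict of lists of token-id lists; `to_rectangular_batch` is a hypothetical helper, not part of run_clm_flax.py, which calls `np.array(v)` directly.

```python
import numpy as np


def to_rectangular_batch(batch, dtype=np.int32):
    """Convert a dict of token-id lists to NumPy arrays, failing loudly on ragged rows.

    Hypothetical helper for illustration only.
    """
    arrays = {}
    for key, rows in batch.items():
        lengths = sorted({len(row) for row in rows})
        if len(lengths) > 1:
            # Rows of different lengths cannot form a rectangular array;
            # report which feature is ragged and which lengths occur.
            raise ValueError(f"feature {key!r} has ragged row lengths {lengths}")
        arrays[key] = np.asarray(rows, dtype=dtype)
    return arrays


good = to_rectangular_batch({"input_ids": [[1, 2], [3, 4]], "attention_mask": [[1, 1], [1, 1]]})
print(good["input_ids"].shape)

try:
    to_rectangular_batch({"attention_mask": [[1, 1, 1], [1, 1]]})
except ValueError as err:
    print(err)
```

Logging the offending lengths this way would show whether short or long blocks are reaching the training loop, which narrows down where in the preprocessing the ragged rows come from.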

dumitrescustefan commented 2 years ago

Update: setting preprocessing_num_workers = 1 seems to solve the core dump. Preprocessing and loading the dataset with 1 worker takes a long time on a 96-core machine :), but I'm not seeing the dumps anymore.

Regarding the "Argument ... is not a valid JAX type" error, I'm not sure whether I "fixed" the problem, but I've now managed to train an epoch without crashing. What I did was set truncation=True in run_clm_flax.py. This loses some text whenever a line is longer than the model's max length, but hey, it's running. I'm not much affected by this, since GPT-Neo has a length of 2048, but if I had to train a model with the standard 512 length, a lot of text would be needlessly lost unless it was split manually beforehand to avoid this error. Again, this is strange, because the code seems to chunk the tokenized text into seq_len blocks, so this shouldn't be a problem, yet setting truncation=True in the tokenizer seems to fix it. Also, this is not related to the core dumps: after setting workers = 1, the JAX error still happened until I enabled truncation.
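To estimate how much text truncation would discard at a given maximum length (e.g. 2048 for GPT-Neo versus 512), one can count the tokens that fall beyond the limit in each tokenized line. A minimal sketch, assuming you already have the per-line token counts; `tokens_lost_to_truncation` is a hypothetical helper, and the example lengths are made up for illustration.

```python
def tokens_lost_to_truncation(line_lengths, max_length):
    """Return (tokens dropped, fraction of all tokens dropped) when every
    line is truncated to at most max_length tokens."""
    total = sum(line_lengths)
    lost = sum(max(0, n - max_length) for n in line_lengths)
    return lost, (lost / total if total else 0.0)


# Made-up per-line token counts, for illustration only.
lengths = [300, 700, 2500, 5000]
for max_length in (512, 2048):
    lost, fraction = tokens_lost_to_truncation(lengths, max_length)
    print(f"max_length={max_length}: {lost} tokens lost ({fraction:.1%})")
```

Running this over the real token counts of a dataset would show whether truncation is acceptable or whether long lines should be split before tokenization.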

So, I've kind of "fixed" my problems; please close this issue if you think it's not helpful. I'm leaving this here for other people who bump into the same things.

github-actions[bot] commented 2 years ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.

stefan-it commented 2 years ago

I've re-opened this issue because I've been seeing this problem for a long time. The reported error still occurs with the latest Datasets and Transformers versions on TPU.

agemagician commented 2 years ago

I also had the same issue with another dataset while training a T5 model. This problem seems to be related to Datasets: I removed all of the T5 training code except the data generation part, and still got the same "SIGTERM" error on a TPU v4 VM.

I have tested it with Python 3.8 and Python 3.7, and the same error occurs with both.

@stefan-it @dumitrescustefan did you find a solution other than setting preprocessing_num_workers to 1? That is extremely slow.

@patil-suraj Is there any solution to this problem?

Ontopic commented 2 years ago

I think this may be caused by mismatched versions of jax, libtpu, torch, xla, or tf. It would help to know the TPU VM version used (e.g. v2-alpha), the jax and jaxlib versions, whether torch and xla were changed, and the TPU driver version.
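As a starting point for comparing environments, the installed versions of the Python packages mentioned above can be collected with the standard library's `importlib.metadata`. This is a sketch: the distribution names in `PACKAGES` (e.g. `libtpu-nightly`, `torch-xla`) are assumptions and may differ on a given TPU VM, and the TPU VM software version and driver version have to be checked separately because they are not pip packages.

```python
from importlib import metadata

# Distribution names are assumptions; adjust them to what is installed on the VM.
PACKAGES = (
    "jax",
    "jaxlib",
    "libtpu-nightly",
    "flax",
    "torch",
    "torch-xla",
    "tensorflow",
    "transformers",
    "datasets",
)


def installed_versions(packages=PACKAGES):
    """Map each distribution name to its installed version, or None if it is missing."""
    versions = {}
    for name in packages:
        try:
            versions[name] = metadata.version(name)
        except metadata.PackageNotFoundError:
            versions[name] = None
    return versions


if __name__ == "__main__":
    for name, version in installed_versions().items():
        print(f"{name:15} {version or 'not installed'}")
```

Pasting this output alongside the TPU VM version makes it easier to spot a jax/jaxlib/libtpu combination that differs between a working and a failing setup.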