dumitrescustefan opened 2 years ago
Pinging @patil-suraj :)
Update: after reading this issue I tried setting the number of preprocessing workers to 1, and after a lot of time, preprocessing finished without any crashes. So that 'solves' problem 1.
However, problem 2 still shows up. At least it's not related to problem 1.
Here is the error:
Training...: 5%|▌ | 2319/43228 [1:05:47<19:18:26, 1.70s/it]
/home/stefan/transformers/examples/flax/language-modeling/run_clm_flax.py:202: VisibleDeprecationWarning: Creating an ndarray from ragged nested sequences (which is a list-or-tuple of lists-or-tuples-or ndarrays with different lengths or shapes) is deprecated. If you meant to do this, you must specify 'dtype=object' when creating the ndarray.
batch = {k: np.array(v) for k, v in batch.items()}
Epoch ... : 0%| | 0/100 [1:05:47<?, ?it/s]
Traceback (most recent call last):
File "/home/stefan/transformers/examples/flax/language-modeling/run_clm_flax.py", line 677, in <module>
main()
File "/home/stefan/transformers/examples/flax/language-modeling/run_clm_flax.py", line 618, in main
state, train_metric = p_train_step(state, batch)
File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/traceback_util.py", line 162, in reraise_with_filtered_traceback
return fun(*args, **kwargs)
File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/api.py", line 1946, in cache_miss
out_tree, out_flat = f_pmapped_(*args, **kwargs)
File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/api.py", line 1801, in f_pmapped
for arg in args: _check_arg(arg)
File "/home/stefan/dev/lib/python3.8/site-packages/jax/_src/api.py", line 2687, in _check_arg
raise TypeError(f"Argument '{arg}' of type {type(arg)} is not a valid JAX type.")
jax._src.traceback_util.UnfilteredStackTrace: TypeError: Argument '[[list([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, [etc - many lists of 1s]
Any idea what might cause this? Thanks!
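For context, the VisibleDeprecationWarning and the later TypeError share a root cause: when the sequences in a batch have different lengths, np.array produces an object array of Python lists, which JAX cannot accept as an argument. A minimal illustration in plain NumPy (no TPU needed):

```python
import numpy as np

# Two token sequences of different lengths, as you get when examples
# in a batch were not padded/truncated to a common size.
ragged = [[1, 1, 1], [1, 1, 1, 1, 1]]

arr = np.array(ragged, dtype=object)  # explicit dtype avoids the warning
print(arr.dtype)   # object -> not a valid JAX type
print(arr.shape)   # (2,) -- an array of list objects, not a 2-D array

# A rectangular batch, by contrast, becomes a normal numeric array
# that jax.pmap happily accepts:
rect = np.array([[1, 1, 1], [1, 1, 1]])
print(rect.dtype)  # an integer dtype (platform dependent)
print(rect.shape)  # (2, 3)
```

So the "Argument ... of type <class 'numpy.ndarray'> is not a valid JAX type" error is the symptom of at least one over-long or under-long example slipping through into a batch.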
Update: setting preprocessing_num_workers = 1 seems to solve the core dump. Preprocessing and loading the dataset on the 96-core machine with a single worker takes a long time :) but I'm not seeing the dumps anymore.
Regarding the "Argument is not a valid JAX type" error, I'm not sure whether I truly "fixed" it, but I've now managed to train for an epoch without crashing. What I did was set truncation=True in the tokenizer call in run_clm_flax.py. This costs me some text whenever a line is longer than the model's max length, but hey, it's running. It doesn't affect me much since GPT-Neo has a 2048-token context, but for a model with the standard 512 size a lot of text would be needlessly lost unless the input were split manually beforehand to avoid this error. This is strange, because the code already seems to chunk the tokenized text into seq_len blocks, so truncation shouldn't matter, yet setting truncation=True in the tokenizer seems to fix it. It's also unrelated to the core dumps: after setting workers = 1, the JAX error still occurred until I enabled truncation.
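For reference, here is a minimal sketch of the workaround described above. The truncation and max_length parameters are part of the Hugging Face tokenizer call API; the function name and the max_len default are illustrative, not the exact code from run_clm_flax.py:

```python
# Hypothetical sketch: pass truncation=True when tokenizing so that no
# single example exceeds the model's context length before grouping.
def tokenize_function(examples, tokenizer, max_len=2048):
    return tokenizer(
        examples["text"],
        truncation=True,     # drop tokens beyond max_length
        max_length=max_len,  # e.g. 2048 for GPT-Neo
    )
```

With this in place, every tokenized example is at most max_len tokens long, so the later batching step can always produce rectangular arrays.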
So, I kind-of "fixed" my problems, please close this issue if you think it's not helpful. Leaving this here for other people if they bump into these things.
This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.
Please note that issues that do not follow the contributing guidelines are likely to be ignored.
I've re-opened this issue because I've been seeing this problem for a long time. The reported error still occurs with the latest Datasets and Transformers versions on TPU.
I also hit the same issue with another dataset while training a T5 model. The problem seems to be related to datasets: I stripped the T5 training code down to just the data-generation part and still got the same "SIGTERM" error on a TPU v4 VM.
I have tested with Python 3.8 and Python 3.7; the same error occurs in both.
@stefan-it @dumitrescustefan did you find a solution other than setting preprocessing_num_workers to 1? That workaround is extremely slow.
@patil-suraj Is there any solution to this problem?
I think this may be a version mismatch between jax, libtpu, torch, xla, or tf. It can depend on the VM image used (e.g. v2-alpha), the jax and jaxlib versions, whether torch and xla were changed, and the TPU driver version.
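To spot such mismatches quickly, a small helper like the following prints the installed versions of the packages in question. The package names listed are the usual PyPI ones and are an assumption here; adjust them for your environment:

```python
# Print versions of potentially mismatched packages; anything not
# installed is reported as missing rather than raising an error.
from importlib import metadata

for pkg in ["jax", "jaxlib", "libtpu-nightly", "torch", "tensorflow", "flax"]:
    try:
        print(pkg, metadata.version(pkg))
    except metadata.PackageNotFoundError:
        print(pkg, "not installed")
```

Pasting this output into a bug report makes it much easier for maintainers to rule version skew in or out.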
Hi, I'm having a weird problem trying to train a GPT-Neo model from scratch on a v3-8 Cloud TPU, similar to the closed issue here. I'm randomly getting a core-dump error during preprocessing/loading of the dataset.
The env is clean, set up according to the Quickstart Flax guide from Google's help page, as well as from here. JAX is installed correctly and sees 8 TPU devices. I tried the standard pip install as well as the local install, as some people suggested in the issue above, and still get the same behavior.
This error does not kill the training. So, question number 1: how do I get rid of this error?
Something else happens that might be related: Running a dummy 300MB Wiki dataset for training only produces the error above, but training progresses. However, when running the full 40GB dataset, at a point during the first epoch I get:
list([1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, .... (many 1s) .. 1, 1, 1])]]' of type <class 'numpy.ndarray'> is not a valid JAX type.
This error kills the training. I've found this related issue, but the last suggestion there, increasing max_seq_len, does not apply here, as the preprocessing should automatically concatenate and chunk the text to the model length (and it is set in the config file). The dataset itself is clean and does not contain long words, odd characters, or anything weird. Thus, question 2: any pointers on how to solve this second error?
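For reference, the concatenate-and-chunk step mentioned above works roughly like this. It is a simplified sketch of the group_texts pattern used in the Hugging Face language-modeling examples, not the exact source:

```python
# Simplified sketch: concatenate all token sequences in a batch end to
# end, then split the result into fixed-size blocks, dropping the
# remainder so every block is exactly block_size tokens long.
def group_texts(examples, block_size):
    concatenated = {k: sum(examples[k], []) for k in examples}
    total_length = len(concatenated[next(iter(examples))])
    total_length = (total_length // block_size) * block_size
    return {
        k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
        for k, t in concatenated.items()
    }
```

If this step runs as intended, every resulting block has the same length, which is exactly why a ragged-batch error at training time is surprising.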
Unfortunately I cannot share the dataset as it's private :disappointed: so I don't know how to help reproduce this error. I've put both questions in this single issue because there's a chance they are related.
Thanks a bunch!
Update: here is the output of run_clm_flax.py. Because there's a limit on how much you can paste online, I've deleted a few chunks of repeating lines from the output.