AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!

Training more than one epoch #914

Open · peregilk opened this issue 1 month ago

peregilk commented 1 month ago

@aireenmei Referring you here because I think this issue is touched on in #571, where you write:

I did not implement the auto restart because some users may not want their model to see repetitive data. I can add the multi-epoch support to our backlog. Meanwhile it should be straightforward to change the shard update logic here: https://github.com/google/maxtext/blob/main/MaxText/input_pipeline/_input_pipeline_utils.py#L105
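
For reference, here is a minimal sketch of what I understand that change to mean. This is not the actual code in _input_pipeline_utils.py; the class and field names are made up for illustration, and 64 hosts is only inferred from the shard stride in the log below:

```python
# Hypothetical sketch, not the actual MaxText shard-update logic.
# Idea: when a host exhausts its assigned shards, wrap around for another
# epoch instead of stopping and emitting padding-only batches.

class HostShardCursor:
    def __init__(self, host_id: int, num_hosts: int, num_shards: int):
        self.host_id = host_id        # this data-loading host
        self.num_hosts = num_hosts    # total number of data-loading hosts
        self.num_shards = num_shards  # total number of file shards
        self.current = host_id        # each host starts on its own shard

    def next_shard(self) -> int:
        # Hosts advance in strides of num_hosts so no two hosts read the
        # same shard within an epoch (host 3 -> 3, 67, 131, 195, ... below).
        candidate = self.current + self.num_hosts
        # The modulo is the multi-epoch part: after the last stride the host
        # cycles back to its first shard instead of running out of data.
        self.current = candidate % self.num_shards
        return self.current
```

With 256 shards and a 64-shard stride, host 3 would wrap from shard 195 back to shard 3 rather than being left without data; whether that repetition is acceptable is exactly the trade-off mentioned in the quote.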

The behaviour now seems to have changed, and it is arguably even more confusing. I am not sure exactly what has changed in the code here.

What I am trying to do is switch datasets during training, here starting at step 160k. This is a fairly small special-task dataset, and I am studying its effect. The dataset has 256 shards, and one epoch is roughly 350 steps.

Here is what happens, with my comments:

# Perfectly normal. Switching to next shard. Weights and loss are fine
Updating host 3 dataset 0, was on shard 3
New shard is 67
completed step: 160086, seconds: 4.090, TFLOP/s/device: 116.004, Tokens/s/device: 2003.075, total_weights: 2080798, loss: 1.113

# Still normal
Updating host 3 dataset 0, was on shard 67
New shard is 131
completed step: 160177, seconds: 4.090, TFLOP/s/device: 115.995, Tokens/s/device: 2002.925, total_weights: 2078579, loss: 1.072

# Still normal
Updating host 3 dataset 0, was on shard 131
New shard is 195
completed step: 160268, seconds: 4.090, TFLOP/s/device: 115.989, Tokens/s/device: 2002.811, total_weights: 2079952, loss: 1.049

# Here things are starting to go south. The host starts generating all-0 paddings
completed step: 160359, seconds: 4.090, TFLOP/s/device: 116.001, Tokens/s/device: 2003.031, total_weights: 2077782, loss: 1.036

# Runs for a while, but then the total_weights start dropping, and the loss starts to drop
completed step: 160367, seconds: 4.091, TFLOP/s/device: 115.971, Tokens/s/device: 2002.507, total_weights: 2034296, loss: 1.030
completed step: 160368, seconds: 4.090, TFLOP/s/device: 116.002, Tokens/s/device: 2003.040, total_weights: 1860858, loss: 1.028
completed step: 160369, seconds: 4.090, TFLOP/s/device: 115.995, Tokens/s/device: 2002.928, total_weights: 1207504, loss: 1.038
completed step: 160370, seconds: 4.090, TFLOP/s/device: 115.991, Tokens/s/device: 2002.854, total_weights: 616193, loss: 1.038
completed step: 160371, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.734, total_weights: 184994, loss: 1.037
completed step: 160372, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.739, total_weights: 46490, loss: 1.058
completed step: 160373, seconds: 4.091, TFLOP/s/device: 115.976, Tokens/s/device: 2002.600, total_weights: 32596, loss: 0.989
completed step: 160374, seconds: 4.091, TFLOP/s/device: 115.978, Tokens/s/device: 2002.634, total_weights: 32491, loss: 1.041

# A bit later
completed step: 160460, seconds: 4.090, TFLOP/s/device: 115.987, Tokens/s/device: 2002.787, total_weights: 32673, loss: 0.980
completed step: 160461, seconds: 4.091, TFLOP/s/device: 115.970, Tokens/s/device: 2002.484, total_weights: 32503, loss: 1.043
completed step: 160462, seconds: 4.090, TFLOP/s/device: 115.984, Tokens/s/device: 2002.736, total_weights: 1904, loss: 1.068
completed step: 160463, seconds: 4.091, TFLOP/s/device: 115.966, Tokens/s/device: 2002.420, total_weights: 0, loss: 0.000
completed step: 160464, seconds: 4.090, TFLOP/s/device: 115.990, Tokens/s/device: 2002.845, total_weights: 0, loss: 0.000

This behaviour is a bit unpredictable, especially since some shards can be smaller, which makes it hard to know when the first host will run out of shards. Running out of shards seems to hurt the model.

What is your advice here?

aireenmei commented 1 month ago

Hi @peregilk, the new behavior is documented here: https://github.com/AI-Hypercomputer/maxtext/blob/main/getting_started/Data_Input_Pipeline.md#huggingface-pipeline-in-multihost.
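
In short, the padding batches are harmless. Roughly speaking (this is a simplified sketch, not the exact loss code), the per-token losses are multiplied by per-token weights before averaging, and padding tokens carry weight 0, so an exhausted host contributes nothing to the gradients; the total_weights in your log is just the sum of those weights:

```python
import jax.numpy as jnp

def masked_mean_loss(per_token_loss, weights):
    # Simplified sketch of a weighted (masked) loss, not MaxText's exact code.
    # Padding tokens have weight 0, so they add nothing to the numerator.
    total_weights = weights.sum()
    # When a host has run out of data the whole batch is padding, so
    # total_weights is 0 and the reported loss is 0.000 -- and there is no
    # gradient signal from that batch either, so the model is not harmed.
    loss = (per_token_loss * weights).sum() / jnp.maximum(total_weights, 1.0)
    return loss, total_weights

# Example: two real tokens and two padding tokens.
loss, tw = masked_mean_loss(jnp.array([1.2, 0.9, 0.0, 0.0]),
                            jnp.array([1.0, 1.0, 0.0, 0.0]))
# loss ~= 1.05, tw == 2.0
```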

peregilk commented 1 month ago

@aireenmei Thanks a lot for the explanation. I thought the drop in weights and loss was hurting the model and was wondering why it did not show up in my evaluations. Now it makes total sense. Thanks.

peregilk commented 1 month ago

@aireenmei A couple of minor follow-ups, attached to this thread since they are related. I followed the instructions on the page above and discovered two minor issues:

aireenmei commented 1 month ago

Thanks for reporting. Yes, setting eval_steps is recommended; it is no longer for debugging only. I'll update that.
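
For anyone following along, the override goes on the training command line like any other config key. A minimal example (the config path, run_name, and the value 50 are placeholders; check base.yml in your MaxText version for the exact defaults):

```sh
python3 MaxText/train.py MaxText/configs/base.yml run_name=my_run eval_steps=50
```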