kingoflolz / mesh-transformer-jax

Model parallel transformers in JAX and Haiku
Apache License 2.0

training stuck at validation step 1 #218

Open Selimonder opened 2 years ago

Selimonder commented 2 years ago

Hello,

First of all, thank you for the great finetune guide. I followed the guide through and attempted to fine-tune GPT-J on a small dataset of about 30-40 MB.

However, I am stuck at the device_train.py step (step 12 of the guide).

Compilation of the train step, the eval step, and the network succeeds, and the first weights are written to the bucket.

The code seems to freeze at the out = network.eval(inputs) line inside the eval_step function. During compilation, data passes through eval_step, but once actual training starts it hangs.
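Not from the repo, but one generic way to pin down where a silent hang like this sits is Python's stdlib faulthandler, which can periodically dump every thread's stack trace to the logs. The helper names below are hypothetical, a minimal sketch only:

```python
import faulthandler
import sys


def install_hang_dumper(interval_sec=60.0, out=sys.stderr):
    # Hypothetical helper (not part of device_train.py): every
    # `interval_sec` seconds, dump the stack of every thread so the
    # call the process is stuck in shows up in the output.
    faulthandler.dump_traceback_later(interval_sec, repeat=True, file=out)


def cancel_hang_dumper():
    # Stop the periodic dumps once the hang is diagnosed.
    faulthandler.cancel_dump_traceback_later()
```

Calling install_hang_dumper() just before the training loop would show whether the process is blocked inside the compiled eval computation, a device transfer, or bucket I/O.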

Did anyone stumble upon a similar issue?

versae commented 1 year ago

Hi! I've trained GPT-J models in the past, but for some reason I'm now seeing this too. Did you manage to solve it?

rinapch commented 1 year ago

Hey @versae, @Selimonder! Did you manage to resolve this issue somehow? I'm facing the same problem right now.

versae commented 1 year ago

@rinapch for me the key was to select the alpha version when creating the TPU. Stable releases seem to break the implementation, not sure why.
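For reference, this would mean creating the TPU VM with the alpha software version rather than a stable release. A sketch of the command, assuming a v3-8 TPU VM as in the guide's setup; the name and zone are placeholders:

```shell
# Sketch only: create a TPU VM with the alpha software version.
# "my-tpu" and the zone are placeholders; adjust to your project.
gcloud alpha compute tpus tpu-vm create my-tpu \
  --zone=europe-west4-a \
  --accelerator-type=v3-8 \
  --version=v2-alpha
```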

rinapch commented 1 year ago

Thanks a lot, @versae! It really did help 🥳