jon-hotaisle opened 1 month ago
@anthonix bump.
Will try to reproduce -- it's on the list of things to do when I have some spare cycles.
Oh, I'm blind (and probably dumb). val_loss must be 0, hence the nan. So it must be something in gpt2_validate() returning all zeros.
In the meantime, can you verify that some other training works, like AMD's TinyLlama code they recently released? Or their JAX GPT2 training?
Just doing a bit of debugging.
"val loss" output nan, so I figured start there...
But digging higher up, val_num_batches is set to 20, so I'm not sure how this is turning into nan so easily. Feels like something else is up...