erfanzar / EasyDeL

Accelerate your training with this open-source library. Optimize performance with streamlined training and serving options in JAX. 🚀
https://easydel.readthedocs.io/en/latest/
Apache License 2.0
167 stars, 19 forks

Add gradient norm logging, fix metric collection on multi-worker setup #135

Closed: yhavinga closed this 2 months ago

yhavinga commented 3 months ago

These metrics can help identify potential issues and show where the optimizer settings could be improved.
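As a rough illustration of the kind of metric this PR adds, the sketch below computes a global gradient norm plus one L2 norm per parameter leaf from a gradient pytree using plain JAX. The helper names `grad_norm_metrics` and `_key_name`, the `grad_norm/...` metric keys, and the nested-dict parameter layout are all hypothetical; this is not EasyDeL's implementation.

```python
import jax
import jax.numpy as jnp


def _key_name(k):
    # Key-path entries are DictKey (.key), GetAttrKey (.name) or SequenceKey (.idx).
    for attr in ("key", "name", "idx"):
        if hasattr(k, attr):
            return str(getattr(k, attr))
    return str(k)


def grad_norm_metrics(grads):
    """Hypothetical helper: global L2 norm and one L2 norm per gradient leaf."""
    # Flatten the pytree while keeping the key path of every leaf.
    leaves_with_paths, _ = jax.tree_util.tree_flatten_with_path(grads)
    metrics = {}
    total_sq = jnp.zeros((), dtype=jnp.float32)
    for path, leaf in leaves_with_paths:
        # Build a readable name such as "dense/kernel" from the key path.
        name = "/".join(_key_name(k) for k in path)
        sq = jnp.sum(jnp.square(leaf.astype(jnp.float32)))
        metrics[f"grad_norm/{name}"] = jnp.sqrt(sq)
        total_sq = total_sq + sq
    # Global norm over all leaves, equivalent to optax.global_norm(grads).
    metrics["grad_norm/global"] = jnp.sqrt(total_sq)
    return metrics


# Toy usage: mean-squared-error loss of a single dense layer.
def loss_fn(params, x, y):
    pred = x @ params["dense"]["kernel"] + params["dense"]["bias"]
    return jnp.mean((pred - y) ** 2)


params = {"dense": {"kernel": jnp.ones((3, 1)), "bias": jnp.zeros((1,))}}
x = jnp.ones((4, 3))
y = jnp.zeros((4, 1))
grads = jax.grad(loss_fn)(params, x, y)
print(grad_norm_metrics(grads))  # every gradient entry is 6.0, so the global norm is 12.0
```

Because the loop only depends on the (static) pytree structure, a helper like this can be called inside a jitted train step, and the resulting dict can be merged into whatever metrics are sent to wandb.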

For review. Perhaps `get_layer_names()` would be better placed in FJFormer. On the other hand, keeping it here for now makes it easy to, e.g., fiddle a bit with how layer names are represented in wandb. Also, the train loss function now returns `state, loss__, metrics`, with accuracy included in the returned metrics dict, whereas the eval loss function still returns `loss, accuracy`. I'm not sure that is a big problem, since the train and eval loss functions are separate functions anyway.
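A minimal sketch of the two return conventions described above, assuming a toy linear classifier, plain SGD, and a params dict standing in for the train state. The names `train_step`, `eval_step`, `_loss_and_accuracy`, and `LEARNING_RATE` are hypothetical and do not mirror EasyDeL's actual loss functions.

```python
import jax
import jax.numpy as jnp

LEARNING_RATE = 0.1  # hypothetical fixed SGD step size


def _loss_and_accuracy(params, x, labels):
    # Toy linear classifier standing in for the real model forward pass.
    logits = x @ params["dense"]["kernel"] + params["dense"]["bias"]
    log_probs = jax.nn.log_softmax(logits)
    loss = -jnp.mean(jnp.take_along_axis(log_probs, labels[:, None], axis=-1))
    accuracy = jnp.mean(jnp.argmax(logits, axis=-1) == labels)
    return loss, accuracy


@jax.jit
def train_step(params, x, labels):
    """Train convention: return (state, loss, metrics); accuracy lives in metrics."""
    (loss, accuracy), grads = jax.value_and_grad(_loss_and_accuracy, has_aux=True)(
        params, x, labels
    )
    # Plain SGD update; the updated params play the role of the new state.
    new_params = jax.tree_util.tree_map(lambda p, g: p - LEARNING_RATE * g, params, grads)
    global_norm = jnp.sqrt(sum(jnp.sum(g ** 2) for g in jax.tree_util.tree_leaves(grads)))
    metrics = {"accuracy": accuracy, "grad_norm/global": global_norm}
    return new_params, loss, metrics


@jax.jit
def eval_step(params, x, labels):
    """Eval convention: still return (loss, accuracy) directly."""
    return _loss_and_accuracy(params, x, labels)


params = {"dense": {"kernel": jnp.zeros((2, 2)), "bias": jnp.zeros((2,))}}
x = jnp.array([[1.0, 0.0], [0.0, 1.0]])
labels = jnp.array([0, 1])

params, train_loss, train_metrics = train_step(params, x, labels)
eval_loss, eval_accuracy = eval_step(params, x, labels)
print(train_loss, train_metrics["accuracy"], eval_loss, eval_accuracy)
```

Putting accuracy into a dict on the train side means new metrics (such as gradient norms) can be added without changing the tuple shape callers unpack, while the eval side keeps its fixed two-value return.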