Closed phlippe closed 6 months ago
https://github.com/jwpartyka/uva-fomo/blob/ae72f8b29a44be313a124273fdd9450a5cc55632/trainer.py#L141
Recommendation to fetch all metrics from the device at once (i.e. metrics = jax.device_get(metrics)) instead of one by one. Fetching them individually waits on a separate device-to-host transfer for each metric, which could stall the process a bit.
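A minimal sketch of the difference, assuming the metrics are stored as a dict of scalar JAX arrays. The dict and its keys are hypothetical stand-ins and are not taken from trainer.py:

```python
import jax
import jax.numpy as jnp

# Hypothetical metrics pytree, standing in for what a train step might return.
metrics = {
    "loss": jnp.asarray(0.5),
    "accuracy": jnp.asarray(0.75),
}

# One by one: each float() call waits on its own device-to-host transfer.
one_by_one = {name: float(value) for name, value in metrics.items()}

# All at once: jax.device_get accepts the whole pytree and returns host values.
host_metrics = jax.device_get(metrics)
all_at_once = {name: float(value) for name, value in host_metrics.items()}
```

Both variants give the same numbers. The second one moves the whole structure to the host in a single call, so the later float() conversions no longer touch the device.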
Thanks! It should be fine now.