Open antonioarbues opened 1 year ago
Made some progress on this. It turns out I did not have the GPU version of JAX installed, and that was the source of the error. I would recommend mentioning this in the readme :)
Also, I would remove the device_config fields from the yaml files if they are not used, as they are quite confusing.
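For anyone hitting the same problem, here is a minimal sketch of how to check which backend JAX actually picked up. It only uses `jax.default_backend()` and `jax.devices()`; the helper name `jax_backend_report` is made up for illustration and is not part of this repo. I am also assuming that a CPU-only install shows up as the `cpu` backend, which is what happened in my case.

```python
import importlib.util


def jax_backend_report():
    """Return (backend_name, device_names), or (None, []) if JAX is not installed.

    Hypothetical helper for illustration; not part of this repository.
    """
    if importlib.util.find_spec("jax") is None:
        return None, []
    import jax

    # default_backend() names the platform JAX will run on by default.
    backend = jax.default_backend()
    # devices() lists the devices visible on that default backend.
    devices = [str(d) for d in jax.devices()]
    return backend, devices


backend, devices = jax_backend_report()
if backend is None:
    print("JAX is not installed in this environment")
elif backend in ("gpu", "cuda", "rocm"):
    print("JAX is running on the GPU:", devices)
else:
    # Typically 'cpu' when only the CPU build of jaxlib is installed.
    print(f"JAX fell back to '{backend}':", devices)
```

If this reports `cpu` on a machine that has a GPU, the CPU-only build is the likely culprit, and reinstalling JAX with GPU support following the official installation guide should fix it.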
It seems that the device_config parameters in the yaml files are not used anywhere. How can I train on GPU? If I try to set the GPU device the JAX way, as an environment variable, with:
it gives me the following error: