RobertTLange / gymnax-blines

Baselines for gymnax 🤖
Apache License 2.0
57 stars, 13 forks

device_config not used? #10

Status: Open. antonioarbues opened this issue 1 year ago.

antonioarbues commented 1 year ago

It seems that the device_config parameters in the YAML files are not used anywhere. How can I train on a GPU?

If I try to select the GPU the JAX way, by setting an environment variable:

JAX_PLATFORMS="cuda" python train.py -config agents/CartPole-v1/ppo.yaml

it gives me the following error:

RuntimeError: Unable to initialize backend 'cuda': module 'jaxlib.xla_extension' has no attribute 'GpuAllocatorConfig' (set JAX_PLATFORMS='' to automatically choose an available backend)
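JAX reads the platform setting when it first initializes its backends, so it has to be set before the first device query or array operation. A minimal sketch of doing the same from inside Python, assuming a reasonably recent JAX release, uses `jax.config.update("jax_platforms", ...)`. It requests `"cpu"` so that it runs on any machine; on a machine with a CUDA-enabled jaxlib you would pass `"cuda"` instead, and with a CPU-only jaxlib that request fails with the same error as above.

```python
import jax
import jax.numpy as jnp

# Must run before JAX initializes any backend, i.e. before the first
# jax.devices() call or array operation. "cpu" keeps the sketch runnable
# everywhere; use "cuda" only if a CUDA-enabled jaxlib is installed.
jax.config.update("jax_platforms", "cpu")

x = jnp.ones((3, 3))
y = (x @ x).sum()  # computed on the selected platform

backend = jax.default_backend()
print(backend, float(y))  # prints: cpu 27.0
```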

antonioarbues commented 1 year ago

Made some progress on that. It turns out I did not have the GPU (CUDA) build of JAX installed, and that was the source of the error. I would recommend mentioning this in the README :)
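A quick way to confirm which JAX build is active is to ask JAX which backend and devices it sees. The sketch below relies only on `jax.default_backend()` and `jax.devices()`; with a CPU-only jaxlib it reports `cpu`, while a working CUDA build should report `gpu`. Per the JAX installation guide, the CUDA build is installed through a pip extra such as `pip install -U "jax[cuda12]"` on Linux, but the exact extra depends on the JAX and CUDA versions, so treat that command as an example rather than the required one. The helper name `describe_jax_setup` is hypothetical.

```python
import jax


def describe_jax_setup() -> dict:
    """Report the JAX version, the default backend and its visible devices."""
    devices = jax.devices()  # devices of the default backend
    return {
        "jax_version": jax.__version__,
        "default_backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in devices],
        # CUDA devices report their platform as "gpu".
        "has_gpu": any(d.platform == "gpu" for d in devices),
    }


info = describe_jax_setup()
for key, value in info.items():
    print(f"{key}: {value}")
```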

Also, I would remove the device_config field from the YAML files if it is not used, as it is quite confusing.
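If the field were kept instead, one way to make it meaningful would be to read it in the training script and pin computation with the `jax.default_device` context manager. The sketch below is hypothetical: the `device_config` layout (a single `platform` key) and the `pick_device` helper are assumptions for illustration, not the repository's actual schema, and the config is given as an already-parsed dict rather than loaded from a YAML file.

```python
import jax
import jax.numpy as jnp

# Hypothetical parsed YAML config; the real gymnax-blines files may differ.
config = {"device_config": {"platform": "gpu"}}


def pick_device(device_config: dict):
    """Return the first device of the requested platform, else a CPU device.

    `platform` is an assumed key name used only for illustration.
    """
    platform = device_config.get("platform", "cpu")
    try:
        return jax.devices(platform)[0]
    except RuntimeError:
        # Raised when the requested backend is unavailable,
        # e.g. "gpu" with a CPU-only jaxlib.
        return jax.devices("cpu")[0]


device = pick_device(config["device_config"])
with jax.default_device(device):
    # Arrays created inside this block are placed on `device` by default.
    x = jnp.arange(4.0) * 2.0

print(device.platform, x.tolist())
```

Falling back to CPU keeps training runnable on machines without a GPU; a stricter variant could raise instead, so that a misconfigured install is not silently ignored.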