vwxyzjn / cleanba

CleanRL's implementation of DeepMind's Podracer Sebulba Architecture for Distributed DRL

Install jax with CUDA #4

Closed by ChufanSuki 1 year ago

ChufanSuki commented 1 year ago

Running poetry install doesn't install JAX with CUDA support for me. I tested with the pyproject.toml below, and it does seem to install JAX with CUDA support.

[tool.poetry]
name = "test"
version = "0.1.0"
description = ""
authors = ["Chufan Chen <allenplato28@gmail.com>"] 
readme = "README.md"

[tool.poetry.dependencies]
python = "^3.8"
jaxlib = {version =  "0.3.25+cuda11.cudnn82", source = "jax"}
jax = "0.3.25"

[[tool.poetry.source]]
name = "jax"
url = "https://storage.googleapis.com/jax-releases/jax_cuda_releases.html"
default = false
secondary = false

[build-system]
requires = ["poetry-core"]
build-backend = "poetry.core.masonry.api"

The pyproject.toml in cleanrl may also need modification.
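One way to check which build ended up in the environment is to inspect the installed jaxlib version string. The following is a minimal sketch, not part of cleanba: the helper names `cuda_tag` and `installed_jaxlib_cuda_tag` are hypothetical, and it assumes the wheel carries a local version label such as the `+cuda11.cudnn82` suffix pinned above. Builds that do not use such a suffix will report no tag even if they support CUDA.

```python
from importlib import metadata
from typing import Optional


def cuda_tag(version: str) -> Optional[str]:
    """Return the local version label (the part after '+') if it mentions CUDA."""
    _, _, local = version.partition("+")
    return local if "cuda" in local.lower() else None


def installed_jaxlib_cuda_tag() -> Optional[str]:
    """Look up the installed jaxlib version; None if absent or without a CUDA label."""
    try:
        return cuda_tag(metadata.version("jaxlib"))
    except metadata.PackageNotFoundError:
        return None


if __name__ == "__main__":
    print("jaxlib CUDA tag:", installed_jaxlib_cuda_tag())
    try:
        import jax

        # These report what JAX selected at runtime, e.g. a CPU or GPU backend.
        print("default backend:", jax.default_backend())
        print("devices:", jax.devices())
    except ImportError:
        print("jax is not installed in this environment")
```

Running this inside the Poetry environment (for example via poetry run python) should show whether the CUDA wheel was picked up. The runtime backend reported by JAX is the more reliable signal, since the version label only reflects how the wheel was named.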

vwxyzjn commented 1 year ago

Thanks for raising this issue. This was intentional: the user might want to install the TPU build of JAX, or JAX with a different CUDA version. So instead we make a recommendation in the installation section of the README:

https://github.com/vwxyzjn/cleanba#installation

ChufanSuki commented 1 year ago

Sorry, I'm so used to running poetry install that I skipped the installation section of the README.