wang2yn84 closed this 1 month ago.
Tests can be run on CPU by setting JAX_PLATFORMS=cpu. CI is doing this: https://github.com/google/jetstream-pytorch/blob/main/.github/workflows/unit_tests.yaml#L79
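Here is a minimal sketch of applying the same setting to local test runs so they don't depend on where jax.config.update is called. It assumes pytest and a hypothetical conftest.py at the repository root; pytest imports that file before collecting test modules, so the variable is set before any test imports jax.

```python
# conftest.py (hypothetical placement at the repository root).
# pytest loads this file before importing test modules, so JAX_PLATFORMS is
# already set by the time any test module runs "import jax".
import os

# setdefault keeps an explicit value passed from the shell
# (e.g. JAX_PLATFORMS=tpu pytest ...) and otherwise defaults to CPU,
# the same setting the CI workflow uses.
os.environ.setdefault("JAX_PLATFORMS", "cpu")
```

Using setdefault instead of plain assignment still lets someone run the suite on TPU on purpose by exporting the variable themselves.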
I don't know if I'm the only one who has this issue. Even with jax.config.update, the test still defaults to TPU.
It should run on CPU once it's set up with jax.config.update("jax_platform_name", "cpu"). I verified this on my TPU host at the time.
I only noticed the CI setting later, when I searched the codebase. I'm fine with reverting the PR if setting the environment variable for the tests is easier. But jax.config.update("jax_platform_name", "cpu") doesn't work for me.
I did the following experiment:
>>> import jax
>>> jax.default_backend()
'tpu'
>>> jax.config.update("jax_platform_name", "cpu")
>>> jax.default_backend()
'tpu'
That doesn't mean it never works, though. Here's another try:
>>> import jax
>>> jax.config.update("jax_platform_name", "cpu")
>>> jax.default_backend()
'cpu'
It seems the backend is cached on the first call, so calling jax.config.update afterward no longer has any effect. Setting the environment variable is a more reliable solution.
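Since the backend appears to be fixed once a process first initializes it, one way to follow that advice from Python (for example, in a helper script) is to start a fresh interpreter with the variable already set. This is a minimal sketch; cpu_env and run_on_cpu are hypothetical helper names, not part of jetstream-pytorch or JAX.

```python
import os
import subprocess
import sys
from typing import Mapping, Optional, Sequence


def cpu_env(base: Optional[Mapping[str, str]] = None) -> dict:
    """Return a copy of `base` (default: the current environment) with
    JAX_PLATFORMS forced to "cpu". The original mapping is not modified."""
    env = dict(os.environ if base is None else base)
    env["JAX_PLATFORMS"] = "cpu"
    return env


def run_on_cpu(args: Sequence[str]) -> subprocess.CompletedProcess:
    """Run `python <args>` in a new interpreter whose environment has
    JAX_PLATFORMS=cpu, so the setting is in place before jax is imported
    there. Raises CalledProcessError if the child exits non-zero."""
    return subprocess.run([sys.executable, *args], env=cpu_env(), check=True)
```

For example, run_on_cpu(["-m", "pytest", "tests"]) would run a (hypothetical) tests directory on CPU, and the child's backend choice is unaffected by anything the parent process has already initialized.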