google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0
21 stars 12 forks

Fixes tests. Can now run on CPU by default. #95

Closed wang2yn84 closed 1 month ago

wang2yn84 commented 1 month ago

Don't know if I'm the only one who has this issue. Even with jax.config.update, the tests still default to TPU.

lsy323 commented 1 month ago

Tests can be run on CPU by setting JAX_PLATFORMS=cpu. CI is doing this https://github.com/google/jetstream-pytorch/blob/main/.github/workflows/unit_tests.yaml#L79
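For reference, a minimal sketch of that approach from Python rather than the shell: export JAX_PLATFORMS=cpu in the child process environment before pytest starts, so JAX never initializes the TPU backend. The tests/ path is an assumption, not taken from this repo's layout.

```python
import os
import subprocess
import sys

# Run the test suite with JAX pinned to CPU (mirrors the CI workflow's
# env var approach; "tests/" is a hypothetical path).
env = dict(os.environ, JAX_PLATFORMS="cpu")
subprocess.run([sys.executable, "-m", "pytest", "tests/"], env=env)
```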

FanhaiLu1 commented 1 month ago

Don't know if I'm the only one who has this issue. Even with jax.config.update, the tests still default to TPU.

It should run on CPU once set up with jax.config.update("jax_platform_name", "cpu"). I verified this on my TPU host at the time.

wang2yn84 commented 1 month ago

Tests can be run on CPU by setting JAX_PLATFORMS=cpu. CI is doing this https://github.com/google/jetstream-pytorch/blob/main/.github/workflows/unit_tests.yaml#L79

I realized that later when I searched the codebase. I'm fine with reverting the PR if setting the env variable for the tests is easier, but jax.config.update("jax_platform_name", "cpu") doesn't work for me.

wang2yn84 commented 1 month ago

Don't know if I'm the only one who has this issue. Even with jax.config.update, the tests still default to TPU.

It should run on CPU once set up with jax.config.update("jax_platform_name", "cpu"). I verified this on my TPU host at the time.

I did the following experiment:

>>> import jax
>>> jax.default_backend()
'tpu'
>>> jax.config.update("jax_platform_name", "cpu")
>>> jax.default_backend()
'tpu'

That doesn't mean it never works, though. See another try:

>>> import jax
>>> jax.config.update("jax_platform_name", "cpu")
>>> jax.default_backend()
'cpu'

It seems JAX caches the backend the first time it is queried, so calling jax.config.update after that has no effect. Setting the env var is a more stable solution.
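The ordering constraint observed above can be sketched as follows: the platform must be fixed before jax is imported (or at least before the first backend lookup). This is a hedged illustration of the env-var pattern, with the jax calls left commented out so it does not require jax to be installed.

```python
import os

# Pin the platform before importing jax; once the backend has been
# initialized (e.g. by jax.default_backend()), later config updates
# are ignored, as the REPL transcripts in this thread show.
os.environ["JAX_PLATFORMS"] = "cpu"

# Only now import jax; on a TPU host this should then report 'cpu'
# (requires jax to be installed):
# import jax
# assert jax.default_backend() == "cpu"
```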