AI-Hypercomputer / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Enable jax compilation flags for jpt #199

Closed vivianrwu closed 3 weeks ago

vivianrwu commented 3 weeks ago
```yaml
args:
  - --model_id=google/gemma-7b-it
  - --override_batch_size=32
  - --enable_model_warmup=True
  - --internal_jax_compilation_cache_dir=gs://vivianrwu-jetstream-ckpts/pytorch/jax_cache
```

Validated that the compilation cache ends up in the GCS bucket.
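Presumably the `--internal_jax_compilation_cache_dir` flag wires up JAX's persistent compilation cache. A minimal sketch of enabling it directly via the JAX config API; the temporary local directory here is an assumption for illustration, whereas the deployment above points at a `gs://` bucket instead:

```python
import tempfile

import jax
import jax.numpy as jnp

# Stand-in for gs://.../jax_cache; any writable path works for a local test.
cache_dir = tempfile.mkdtemp()
jax.config.update("jax_compilation_cache_dir", cache_dir)

@jax.jit
def scale(x):
    return x * 2.0

# First call triggers compilation; the compiled artifact can then be
# persisted to cache_dir and reused across processes.
result = scale(jnp.arange(4.0))
```

With the cache directory set, subsequent runs that hit the same program shape can load the compiled executable instead of recompiling, which is what makes warmup-plus-cache attractive for serving.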