AI-Hypercomputer / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Make jpt the default cli - remove other entry point scripts #188

Closed qihqi closed 1 month ago

qihqi commented 1 month ago

> For multiple host inference, it looks like we need to download the model manually. Correct me if I'm wrong.

For multihost, there are two options: either each host downloads its own copy and then discards the parts it doesn't need; or only host 0 downloads, then uses `jax.device_put` to send the shards to the other hosts. Both schemes should work, but they will be implemented separately.
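The `jax.device_put` scheme described above can be sketched on a single host by simulating multiple devices. This is not the repo's actual implementation, just an illustration of how a weight tensor (the `weights` array below is a hypothetical stand-in for one checkpoint parameter) can be distributed as shards via a `NamedSharding`:

```python
import os
# Simulate 4 devices on a single CPU host, for illustration only.
# On real multihost TPU this flag is unnecessary.
os.environ["XLA_FLAGS"] = "--xla_force_host_platform_device_count=4"

import jax
import jax.numpy as jnp
from jax.sharding import Mesh, NamedSharding, PartitionSpec as P

# Hypothetical weight tensor standing in for one model parameter.
weights = jnp.arange(8 * 4, dtype=jnp.float32).reshape(8, 4)

# Build a 1-D mesh over all devices and shard the first axis across it.
mesh = Mesh(jax.devices(), axis_names=("x",))
sharding = NamedSharding(mesh, P("x", None))

# device_put distributes the shards: each device ends up holding
# a (2, 4) slice of the (8, 4) array.
sharded = jax.device_put(weights, sharding)

print(len(sharded.addressable_shards))        # one shard per local device
print(sharded.addressable_shards[0].data.shape)
```

In the real multihost setting, only host 0 would hold the full `weights` array; the other hosts would participate in the `device_put` with their local devices receiving their shards.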

FanhaiLu1 commented 1 month ago

> For multiple host inference, it looks like we need to download the model manually. Correct me if I'm wrong.
>
> For multihost, there are two options: either each host downloads its own copy and then discards the parts it doesn't need; or only host 0 downloads, then uses `jax.device_put` to send the shards to the other hosts. Both schemes should work, but they will be implemented separately.

Thanks for clarifying!