qihqi closed this issue 1 month ago.
For multi-host inference, it looks like we need to download the model manually. Correct me if I'm wrong.
For multihost there are two options. Either each host downloads its own copy and then discards the parts it doesn't need, OR only host 0 downloads and then uses `jax.device_put` to send the shards to the other hosts. Both schemes should work, but they will be implemented separately.
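A minimal sketch of the first scheme (every host holds a full copy and keeps only the shards its own devices need), assuming the whole checkpoint fits in each host's RAM as a NumPy array. `load_full_checkpoint` and `shard_from_local_copy` are hypothetical helpers, not part of any existing loader. The sketch uses `jax.make_array_from_callback` instead of `jax.device_put`. On a real multi-host setup, `jax.distributed.initialize()` must be called on every host before building the mesh.

```python
import numpy as np
import jax
from jax.sharding import Mesh, NamedSharding, PartitionSpec


def load_full_checkpoint(shape):
    # Hypothetical stand-in for reading the downloaded weights from local disk.
    return np.arange(np.prod(shape), dtype=np.float32).reshape(shape)


def shard_from_local_copy(full_weights, mesh):
    # Shard the first axis of the weights across the mesh axis "x".
    sharding = NamedSharding(mesh, PartitionSpec("x"))

    def data_callback(index):
        # JAX calls this once per device addressable by this host, passing the
        # slice that device owns. Shards for other hosts' devices are never
        # requested, so the rest of this host's copy is simply left unused.
        return full_weights[index]

    return jax.make_array_from_callback(full_weights.shape, sharding, data_callback)


if __name__ == "__main__":
    # jax.devices() returns the global device list, so the mesh spans all hosts.
    mesh = Mesh(np.array(jax.devices()), axis_names=("x",))
    n = len(jax.devices())
    weights = load_full_checkpoint((8 * n, 4))  # first axis divisible by n
    arr = shard_from_local_copy(weights, mesh)
    print(arr.shape, arr.sharding)
```

The second scheme (only host 0 downloads) avoids redundant downloads but needs a cross-host transfer step, so its code path differs from the one above.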
Thanks for clarifying!