google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Add ray multiple host support #63

Closed. FanhaiLu1 closed this 4 months ago.

FanhaiLu1 commented 4 months ago

This PR enables the PyTorch engine to run on multiple hosts on TPU Pod slices.

MVP Goal:

The current PR is the MVP version of multi-host support, with two goals (a rough sketch of the flow follows this list):

  1. Load and shard the weights across multiple hosts

  2. Compute meaningful decode results
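
A rough, hypothetical sketch of that flow (class and method names here are illustrative, not the actual ones in this PR): one Ray actor per TPU host loads and shards its slice of the weights, and the driver fans decode calls out to all hosts.

```python
import ray

ray.init()

@ray.remote(resources={"TPU": 4})  # assumption: 4 TPU chips per v5e host
class PyTorchRayWorker:
    def load_and_shard_weights(self, checkpoint_path):
        # Each host loads the checkpoint and keeps only the shards that
        # belong to its local devices (details depend on the engine).
        self.weights = f"sharded weights from {checkpoint_path}"
        return "loaded"

    def decode(self, tokens):
        # Run one decode step against the local shard; partial results are
        # combined across hosts by the runtime's collectives.
        return f"decoded {len(tokens)} tokens"

# One worker per host in the pod slice (4 hosts for v5e-16).
workers = [PyTorchRayWorker.remote() for _ in range(4)]
ray.get([w.load_and_shard_weights.remote("/path/to/ckpt") for w in workers])
print(ray.get([w.decode.remote([1, 2, 3]) for w in workers]))
```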

Result validation:

Weight and sharding

With this PR, the weights and sharding on v5e-16 (4 hosts, 16 chips total) use roughly 50% of the per-chip memory compared with the 8-chip run, as expected since the same weights are spread over twice as many chips.

memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_10(process=3,(2,2,0,0))
memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_11(process=3,(3,2,0,0))
memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_14(process=3,(2,3,0,0))
memory using 804.8 MiB / 15.7 GiB (4.990588%) on TPU_15(process=3,(3,3,0,0))
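
The per-chip numbers can be spot-checked with something like the snippet below (a sketch, assuming the engine runs on a JAX backend where `Device.memory_stats()` is available):

```python
import jax

# Print per-chip HBM usage, similar to the log lines above.
for dev in jax.local_devices():
    stats = dev.memory_stats()  # may be None on some backends
    if stats:
        used_mib = stats["bytes_in_use"] / 2**20
        limit_gib = stats["bytes_limit"] / 2**30
        pct = 100 * stats["bytes_in_use"] / stats["bytes_limit"]
        print(f"memory using {used_mib:.1f} MiB / {limit_gib:.1f} GiB ({pct:.6f}%) on {dev}")
```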

The weights and sharding on v5e-8 (8 chips):

memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_4(process=0,(0,2,0,0))
memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_5(process=0,(1,2,0,0))
memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_6(process=0,(0,3,0,0))
memory using 1.7 GiB / 15.7 GiB (10.765424%) on TPU_7(process=0,(1,3,0,0))

Meaningful Result

With this PR, here are the first two lines of the output on v5e-16 (4 hosts, 16 chips total). The results are meaningful and read well to a human:


to find purpose and fulfillment.

I believe that everyone has a unique purpose and that it is up to each individual to discover and pursue theirs.

The first two lines of the output on v5e-8 (8 chips):

to find purpose, happiness, and fulfillment. Here are some reasons why:

1. Purpose: Having a sense of purpose gives life meaning and direction. It helps individuals set goals and work towards achieving them, which can lead to a sense of accomplishment and fulfillment.

Caveats

As this is an MVP, there may still be limitations in performance and accuracy.

FanhaiLu1 commented 4 months ago

> This looks great!! Mostly nits, but IIUC this same approach would work for JAX/JetStream right?

Thanks, Allen, for reviewing it! Yes, JAX/JetStream should be able to use the same approach in general. We could extract the common parts of the code (for example, the master code could be exactly the same for both JAX and PyTorch) so that the JAX and PyTorch parts share the same code base.
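
As an illustration of that idea (hypothetical names, not code from this PR): the shared master only needs handles to worker actors that expose the same interface, so the same driver could front either engine.

```python
import ray

class MultiHostDriver:
    """Backend-agnostic master: works with any Ray actor class that exposes
    load_and_shard_weights() and decode(), whether PyTorch- or JAX-backed."""

    def __init__(self, worker_cls, num_hosts):
        self.workers = [worker_cls.remote() for _ in range(num_hosts)]

    def load_params(self, checkpoint_path):
        return ray.get(
            [w.load_and_shard_weights.remote(checkpoint_path) for w in self.workers]
        )

    def decode(self, tokens):
        return ray.get([w.decode.remote(tokens) for w in self.workers])
```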

allenwang28 commented 4 months ago

Could we also add a basic unit test for Ray? One way is to take advantage of Ray's "fake" cluster if you look at this file for instance.
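
For reference, a minimal sketch of what such a test could look like with `ray.cluster_utils.Cluster` (one reading of the "fake" cluster suggestion; the real test and worker names would come in the follow-up PR):

```python
import unittest

import ray
from ray.cluster_utils import Cluster


class RayClusterSmokeTest(unittest.TestCase):
    def setUp(self):
        # Start an in-process head node plus one extra simulated node.
        self.cluster = Cluster(
            initialize_head=True, connect=True, head_node_args={"num_cpus": 1}
        )
        self.cluster.add_node(num_cpus=1)

    def tearDown(self):
        ray.shutdown()
        self.cluster.shutdown()

    def test_remote_call(self):
        @ray.remote
        def ping():
            return "pong"

        self.assertEqual(ray.get(ping.remote()), "pong")


if __name__ == "__main__":
    unittest.main()
```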

FanhaiLu1 commented 4 months ago

> Could we also add a basic unit test for Ray? One way is to take advantage of Ray's "fake" cluster if you look at this file for instance.

Ray's "fake" cluster test looks great. Can I add the test in another PR (I tested the current PR end to end manually)? I also plan to add unit tests for the RayWorkers in the next PR.

allenwang28 commented 4 months ago

> Ray's "fake" cluster test looks great. Can I add the test in another PR (I tested the current PR end to end manually)? I also plan to add unit tests for the RayWorkers in the next PR.

Sounds good to me! Thanks for patiently addressing all of the comments, LGTM!