google / jetstream-pytorch

PyTorch/XLA integration with JetStream (https://github.com/google/JetStream) for LLM inference
Apache License 2.0

Fix ray conflict changes #100

Closed: FanhaiLu1 closed this 1 month ago

FanhaiLu1 commented 1 month ago

This PR adds the following changes:

1. Move `torch_xla2.default_env()` into a function. The module-level `jax_mode = torch_xla2.default_env()` blocks during JAX multi-controller initialization at startup (see the sketch after this list).
2. Ray engine creation is different from the default run-server path; it will have separate prefill and decode engines later.
3. Removed a duplicated `JetEngineEnvironment`.
4. `shard_on_batch` and ragged attention are not supported in Ray multi-host mode for now.
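
For context, a minimal sketch of the pattern item 1 describes, assuming the module-level call is simply deferred into a helper; the helper name is illustrative, not the exact code in this PR:

```python
import torch_xla2

# Before (problematic): evaluated at import time on every process that loads
# the module, including the Ray head, which triggers JAX initialization early.
#
#   jax_mode = torch_xla2.default_env()


def get_jax_mode():
  """Illustrative helper: fetch the torch_xla2 default env lazily."""
  # After: the call only happens when an engine is actually created,
  # so merely importing the module no longer touches JAX.
  return torch_xla2.default_env()
```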

wang2yn84 commented 1 month ago

Can you help me understand how `jax_mode = torch_xla2.default_env()` blocks during JAX multi-controller initialization at startup?

Can you help me understand why that is?

FanhaiLu1 commented 1 month ago

> Can you help me understand how `jax_mode = torch_xla2.default_env()` blocks during JAX multi-controller initialization at startup?
>
> Can you help me understand why that is?

There is a JAX call under this function (or deeper in it). Any JAX function call tries to initialize the multi-controller environment (through an MPI-style barrier), which means it has to wait for all the chips to be ready. So in a Ray multi-host setup, if there is a JAX call on the head node, it waits for all the chips to be ready, but only the head node's chips are ready at that point, and the whole application gets stuck there.

In the current use case, this happens as soon as the Ray head loads the class: it makes the JAX call and gets stuck there, even before the main function starts executing.
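
A minimal sketch of that failure mode, under assumed names (`TPUWorker` and `init_engine` are hypothetical, not the actual jetstream-pytorch code):

```python
import jax
import ray


@ray.remote
class TPUWorker:  # hypothetical worker actor
  # BAD: a class/module-level JAX call like the line below runs when the Ray
  # head loads this class definition. JAX then waits for all hosts to join the
  # multi-controller runtime, but only the head is up at that point, so the
  # program hangs before main() ever runs.
  #
  #   _devices = jax.devices()

  def init_engine(self):
    # GOOD: the JAX call is deferred until this method runs inside the worker
    # process on its own TPU host, after Ray has scheduled all workers.
    return len(jax.local_devices())
```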