Closed: FanhaiLu1 closed this 1 month ago.
Can you help me understand how `jax_mode = torch_xla2.default_env()` blocks the JAX multi-controller in the init state, and why that happens?
There is a JAX call under `torch_xla2.default_env()` (or deeper). Any JAX function call will try to initialize the multi-controller environment (through an MPI barrier), which means it has to wait until all the chips are ready. So in a multi-host Ray setup, a JAX function call on the head node waits for every chip to be ready; but at that point only the head node's chips are ready, so the whole application gets stuck there.
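A toy analogy of why a lone head node hangs, assuming the init step behaves like an all-hosts barrier. Python's `threading.Barrier` stands in for the collective initialization (it is not JAX's actual mechanism), `NUM_HOSTS` and `head_node_init` are hypothetical names, and a timeout is used so the demo fails fast instead of blocking forever.

```python
import threading

NUM_HOSTS = 4  # hypothetical: number of hosts whose chips must join the init

# Stand-in for the multi-controller init barrier: every host must arrive.
init_barrier = threading.Barrier(NUM_HOSTS)


def head_node_init(timeout: float) -> bool:
    """Return True if all hosts arrived, False if the wait timed out.

    Only the head node calls this, mirroring the Ray head running JAX code
    before the worker hosts have started.
    """
    try:
        init_barrier.wait(timeout=timeout)
        return True
    except threading.BrokenBarrierError:
        # Without a timeout, this wait would block indefinitely.
        return False


if __name__ == "__main__":
    print(head_node_init(timeout=0.2))  # the other hosts never arrive
```

With no timeout, `wait()` would simply never return, which matches the "stuck" behavior described above.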
In the current use case, this happens as soon as the Ray head loads the class: it calls JAX and gets stuck there before the main function even starts executing.
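A minimal sketch of why loading the class is enough to trigger the call: statements in a Python class body run when the class is defined (for example on import), before `main()` runs. `fake_default_env` is a hypothetical stand-in for `torch_xla2.default_env()` that only records when it is invoked.

```python
calls = []  # records the order in which things run


def fake_default_env():
    """Hypothetical stand-in for torch_xla2.default_env(); records each call."""
    calls.append("init")
    return object()  # placeholder for the real environment object


class EagerEngine:
    # Class-body statements execute at class-definition time, so this
    # "init" happens before any main() is reached.
    jax_mode = fake_default_env()


def main():
    calls.append("main")


if __name__ == "__main__":
    main()
    print(calls)  # "init" appears before "main"
```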
This PR makes the following changes:
1: Move `torch_xla2.default_env()` into a function, because `jax_mode = torch_xla2.default_env()` at class level blocks the JAX multi-controller in the init state.
2: Ray engine creation differs from the default run-server path; it will have separate prefill and decode engines later.
3: Removed the duplicated `JetEngineEnvironment`.
4: `shard_on_batch` and ragged attention are not supported in Ray multi-host mode for now.