Concise Description:
I'd like to use JAX for distributed training of LLMs. The new release of Keras also supports JAX as a backend alongside TF.
Describe the solution you'd like
I'd like either a separate JAX container or jaxlib included in a TF container, since the TF ecosystem (data loading, serving, etc.) supports JAX.
Describe alternatives you've considered
I could install JAX on top of the PyTorch container.
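For anyone trying that alternative, here is a rough sketch of how one might check whether a given container already has what is needed. The helper name `describe_jax_setup` is hypothetical, and it assumes the Keras release in question picks its backend from the `KERAS_BACKEND` environment variable (as Keras 3 does, defaulting to TensorFlow). It only inspects the environment and does not install anything.

```python
import importlib.util
import os


def describe_jax_setup():
    """Report whether jax/jaxlib are importable and which Keras backend is requested.

    Hypothetical helper for illustration; not part of any container tooling.
    """
    report = {
        "jax_installed": importlib.util.find_spec("jax") is not None,
        "jaxlib_installed": importlib.util.find_spec("jaxlib") is not None,
        # Assumption: Keras reads KERAS_BACKEND at import time, defaulting to TensorFlow.
        "keras_backend": os.environ.get("KERAS_BACKEND", "tensorflow"),
    }
    if report["jax_installed"] and report["jaxlib_installed"]:
        try:
            import jax

            # Lists the accelerators (or CPU fallback) that JAX can see.
            report["devices"] = [str(d) for d in jax.devices()]
        except Exception as exc:  # e.g. mismatched jax/jaxlib builds
            report["error"] = repr(exc)
    return report


if __name__ == "__main__":
    for key, value in describe_jax_setup().items():
        print(f"{key}: {value}")
```

Running this inside the TF or PyTorch container would show whether jaxlib is present and whether JAX sees the GPUs, which is the gap a dedicated JAX container would close.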