coreweave / ml-containers

MIT License

[feature-request] Support for JAX container #42

Open · sbhavani opened 9 months ago

sbhavani commented 9 months ago

Concise Description: I'd like to use JAX for distributed training of LLMs. Additionally, the new release of Keras supports JAX as a backend alongside TensorFlow.

Describe the solution you'd like: Either a separate JAX container, or jaxlib included in a TF container, since the TF ecosystem (data loading, serving, etc.) supports JAX.

Describe alternatives you've considered: I could install JAX on top of the PyTorch container.

sbhavani commented 2 months ago

Any updates on a JAX container?