google / paxml

Pax is a JAX-based machine learning framework for training large-scale models. Pax allows for advanced and fully configurable experimentation and parallelization, and has demonstrated industry-leading model FLOP utilization rates.
Apache License 2.0

Add GPU scripts and dependencies #10

Closed: ashors1 closed this pull request 1 year ago

ashors1 commented 1 year ago

This PR adds a contrib/gpu folder, which contains the following:

Additionally, optional GPU dependencies are added to setup.py.

Note that this PR depends on the Praxis PR that allows LayerNorm reductions to be computed in FP32.

zhangqiaorjc commented 1 year ago

This PR introduces many new files. I suspect I will need to do a manual import instead and credit you as the author.

zhangqiaorjc commented 1 year ago

I created placeholder files in https://github.com/google/paxml/tree/main/paxml/contrib/gpu

@ashors1, could you copy your PR there? I couldn't figure out how to do this automatically.

ashors1 commented 1 year ago

I've moved the files to paxml/contrib/gpu as requested. Please let me know if there are any issues with this.