luchris429 / purejaxrl

Really Fast End-to-End Jax RL Implementations
Apache License 2.0

[Question] Is PureJaxRL TPU-optimal? #27

Open MRiabov opened 5 months ago

MRiabov commented 5 months ago

Hello everyone, practitioner here,

I am looking to train a large non-LLM model. Training is expected to be computationally demanding, so I need as much speed as possible.

I know that Google's TPUs are said to be the fastest hardware for training and inference: some 197 TFLOPS at only $0.60/hr with interruptible (spot) pricing.
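Taking the figures quoted above at face value, the implied peak compute per dollar works out as shown below. These are theoretical peak numbers, not sustained throughput. RL training loops typically run well below peak, and spot pricing varies by region and over time.

```latex
197 \times 10^{12}\,\tfrac{\text{FLOP}}{\text{s}} \times 3600\,\tfrac{\text{s}}{\text{hr}}
  \approx 7.09 \times 10^{17}\,\tfrac{\text{FLOP}}{\text{hr}},
\qquad
\frac{7.09 \times 10^{17}\,\text{FLOP/hr}}{\$0.60/\text{hr}}
  \approx 1.18 \times 10^{18}\,\text{FLOP per dollar (peak)}
```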

Is this library TPU-optimized? Is it much faster than other existing libraries? What should I compare it against (aside from the comparisons in the blog post)?

Thanks, @MRiabov

luchris429 commented 5 months ago

I haven't thoroughly tested this on TPUs; however, I don't think there are any hardware-specific details in this implementation. It should still run faster than RL libraries that are not written in pure JAX.
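A minimal sketch, not taken from PureJaxRL, of why a pure-JAX training loop is not tied to particular hardware. The whole loop is expressed with `jax.lax.scan` and `jax.vmap` and compiled with `jax.jit`. Because nothing names a device, XLA compiles it for whatever backend JAX finds (CPU, GPU, or TPU). `toy_rollout` and `steps_per_second` are hypothetical stand-ins for an environment rollout and a timing harness. Running the same script on different hosts is one rough way to compare throughput. The numbers it prints are not benchmarks of PureJaxRL.

```python
import time

import jax
import jax.numpy as jnp


def toy_rollout(key, num_steps=128):
    """Hypothetical stand-in for an environment rollout: a noisy linear
    update whose step loop uses jax.lax.scan so it can be jit-compiled."""

    def step(state, step_key):
        noise = jax.random.normal(step_key, state.shape)
        new_state = 0.99 * state + noise
        return new_state, new_state.sum()  # carry, per-step "reward"

    step_keys = jax.random.split(key, num_steps)
    init_state = jnp.zeros((4,))
    _, rewards = jax.lax.scan(step, init_state, step_keys)
    return rewards.sum()


# vmap over many parallel "environments", then jit the whole computation.
# No device is named anywhere: XLA targets the default backend.
batched_rollout = jax.jit(jax.vmap(toy_rollout))


def steps_per_second(fn, keys, num_steps=128, repeats=5):
    """Time a jitted function, excluding compilation, and wait for JAX's
    asynchronous dispatch to finish via block_until_ready()."""
    fn(keys).block_until_ready()  # first call triggers compilation
    start = time.perf_counter()
    for _ in range(repeats):
        fn(keys).block_until_ready()
    elapsed = time.perf_counter() - start
    return keys.shape[0] * num_steps * repeats / elapsed


if __name__ == "__main__":
    num_envs = 1024
    keys = jax.random.split(jax.random.PRNGKey(0), num_envs)
    print("backend:", jax.default_backend())
    print("devices:", jax.devices())
    print(f"~{steps_per_second(batched_rollout, keys):,.0f} env steps/s")
```

Excluding the first (compiling) call and calling `block_until_ready()` matter for any fair comparison. Otherwise JAX's asynchronous dispatch makes the loop appear to finish before the device has done the work.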