Hello everyone, practitioner here.

I am looking to train a very serious non-LLM model, and the training is expected to be computationally demanding, so I am looking for maximum speed. I know that Google's TPUs are said to be among the fastest options for training and inference: some 197 TFLOPS at only $0.60/hr with interruptible (spot) pricing.

Is this library TPU-optimized? Is it much faster than the other existing libraries? What should I compare it against (aside from what is said in the blog post)?

Thanks, @MRiabov
I haven't thoroughly tested this on TPUs; however, I don't think there are any hardware-specific details in this implementation, so the same code should compile for TPU via XLA. It should be faster than libraries that are not pure JAX.
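Since XLA compiles the same JAX program for CPU, GPU, or TPU, the practical way to answer "is it faster on my hardware" is an end-to-end timing on the target device. Below is a minimal microbenchmark sketch, assuming a Cloud TPU VM with `jax[tpu]` installed; the `step` function is a hypothetical stand-in for one step of the library under test, not its actual API.

```python
# Hypothetical smoke test / microbenchmark; `step` is a placeholder, not part
# of the library discussed above.
import time

import jax
import jax.numpy as jnp

# On a Cloud TPU VM with `pip install "jax[tpu]"`, this should list TpuDevice
# entries; the same code falls back to GPU/CPU elsewhere.
print(jax.devices())

@jax.jit
def step(params, x):
    # Stand-in for one training/inference step of the library under test.
    return jnp.tanh(x @ params)

key = jax.random.PRNGKey(0)
params = jax.random.normal(key, (1024, 1024))
x = jax.random.normal(key, (4096, 1024))

# Warm-up call so XLA compile time is excluded from the measurement.
step(params, x).block_until_ready()

t0 = time.perf_counter()
for _ in range(100):
    out = step(params, x)
out.block_until_ready()  # JAX dispatch is async; sync before stopping the clock
print(f"{(time.perf_counter() - t0) / 100 * 1e3:.3f} ms per step")
```

Note the `block_until_ready()` call: JAX dispatches work asynchronously, so timing without it measures only dispatch overhead, not the actual compute. Running the same script against each candidate library's step function would give a like-for-like comparison beyond the blog post's numbers.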