google / yggdrasil-decision-forests

A library to train, evaluate, interpret, and productionize decision forest models such as Random Forest and Gradient Boosted Decision Trees.
https://ydf.readthedocs.io/
Apache License 2.0
481 stars, 52 forks

[Feature Request] GPU Acceleration #122

Open ZeroCool2u opened 2 months ago

ZeroCool2u commented 2 months ago

Hi there, it's unclear whether Yggdrasil supports GPU or TPU acceleration. It seems that if you do fine-tuning in JAX, acceleration may be possible once the model is converted to a JAX function. But it's not clear whether that's intentional/expected or not.

rstz commented 2 months ago

Hi,

YDF does not support training using GPU or TPU acceleration yet. Our team has experimented in this direction, but we have not yet found a strong (business) incentive to productionize it. Please let us know if you need support and we'll be happy to discuss options.

When converted to a JAX function, the model can run on GPU or TPU (or CPU) for serving and/or fine-tuning. Note that the non-JAX inference on CPU can be quite fast (~1 microsecond) with the right model / configuration. If inference speed is the main concern, it's probably worth considering CPU inference first.

ZeroCool2u commented 2 months ago

Hi rstz, thanks for the info!

We're confident in CPU inference speed for YDF and would like to avoid GPU/TPU usage for economic reasons in that scenario.

Our use case is more focused on time-series-style problems, where standard cross-validation isn't viable and we'd have to use something like scikit-learn's TimeSeriesSplit to train and evaluate across multiple splits.
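For anyone following along, here is a rough sketch of the kind of expanding-window split described above. It is plain Python and does not call scikit-learn or YDF; the function name `expanding_window_splits` and its parameters are made up for illustration. It only aims to mirror the general idea behind TimeSeriesSplit: every test fold comes after its training data, and the training window grows with each split.

```python
from typing import Iterator, List, Tuple


def expanding_window_splits(
    n_samples: int, n_splits: int
) -> Iterator[Tuple[List[int], List[int]]]:
    """Yield (train_indices, test_indices) pairs in chronological order.

    Hypothetical helper for illustration only. Each test fold lies strictly
    after its training data, and the training window grows with every split.
    """
    if n_splits < 1:
        raise ValueError("n_splits must be at least 1")
    # Equal-sized test folds taken from the end of the series.
    test_size = n_samples // (n_splits + 1)
    if test_size == 0:
        raise ValueError("not enough samples for the requested number of splits")
    for i in range(n_splits):
        test_start = n_samples - (n_splits - i) * test_size
        train = list(range(test_start))
        test = list(range(test_start, test_start + test_size))
        yield train, test


splits = list(expanding_window_splits(n_samples=10, n_splits=3))
for train_idx, test_idx in splits:
    # A model would be trained on the train_idx rows and evaluated on the
    # test_idx rows here, once per split.
    print(f"train={train_idx} test={test_idx}")
```

Since a separate model has to be trained for every split, total training time scales roughly with the number of splits, which is why training speed matters more to us here than inference speed.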

Thank you for answering our question, it's much appreciated!