danilojsl opened this issue 3 years ago
We had a brief look into this several months ago, and it's basically a huge pile of Python (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/distribute) we'd need to replicate into Java. Doing out-of-band distribution using MPI allreduce directly from Java would be relatively easy, though it would have much worse performance than the native TF solution as it wouldn't do direct GPU-to-GPU copies.
Hi @Craigacp
Is this also true for inference? In prediction, each inference is isolated from the others so it seems easier to batch inputs and send them to multiple GPU devices at the same time in parallel. (just trying to see if inference over multiple GPU devices can happen in tensorflow-java)
Most of the ops are there (https://github.com/tensorflow/java/tree/55547dd20b14e1e9cd592a8789e780a0be3ae507/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective), but I'm not sure if you can use them manually yet (we have no way to create groups or instances) or whether they can send GPU <-> GPU (with NCCL, I assume).
Of course, if you have multiple GPUs on a single machine, you can just use the device settings.
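For the single-machine case, a minimal sketch of that device-scope approach could look like the following (assuming two visible GPUs; `DeviceSpec` and `Ops.withDevice` are the TF-Java entry points I'd expect to use here, but check the exact API against your version):

```java
import org.tensorflow.DeviceSpec;
import org.tensorflow.Graph;
import org.tensorflow.Session;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.op.math.Add;
import org.tensorflow.types.TFloat32;

public class DevicePlacementSketch {
  public static void main(String[] args) {
    try (Graph graph = new Graph()) {
      Ops tf = Ops.create(graph);

      // One Ops scope per GPU; ops created through a scope are pinned to that device.
      Ops gpu0 = tf.withDevice(DeviceSpec.newBuilder()
          .deviceType(DeviceSpec.DeviceType.GPU).deviceIndex(0).build());
      Ops gpu1 = tf.withDevice(DeviceSpec.newBuilder()
          .deviceType(DeviceSpec.DeviceType.GPU).deviceIndex(1).build());

      Constant<TFloat32> a = gpu0.constant(new float[] {1f, 2f, 3f});
      Constant<TFloat32> b = gpu1.constant(new float[] {4f, 5f, 6f});
      Add<TFloat32> sum = tf.math.add(a, b); // TF inserts the cross-device copies as needed

      try (Session session = new Session(graph);
           TFloat32 result = (TFloat32) session.runner().fetch(sum).run().get(0)) {
        System.out.println(result.getFloat(0));
      }
    }
  }
}
```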
Thanks @rnett
We should do some testing; I wasn't sure if I could load something like BERT (where the ops and device assignment are out of my hands, since it's a SavedModel) and somehow use ConfigProto/Session to distribute it over multiple GPU devices.
I'll see if the device settings/scope can be applied to a loaded SavedModel.
That's more of a TF-core thing; I don't know if there's support for it, although it seems like a common enough use case. If you don't find a way, you may be able to create a ConcreteFunction out of your inference call, and then call that on different GPUs.
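If that route works, the shape of it might be something like this (a sketch assuming a TF-Java version where `SavedModelBundle.function(...)` returns a callable `ConcreteFunction`; the signature name, input key, and shape are placeholders for whatever your model exports):

```java
import java.util.Map;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.Shape;
import org.tensorflow.types.TFloat32;

public class FunctionCallSketch {
  public static void main(String[] args) {
    try (SavedModelBundle model = SavedModelBundle.load("/path/to/saved_model", "serve")) {
      // Extract the inference signature as a callable function.
      ConcreteFunction servingFn = model.function("serving_default");

      try (TFloat32 input = TFloat32.tensorOf(Shape.of(1, 128))) {
        // The input/output keys must match the model's signature; these are hypothetical.
        Map<String, Tensor> outputs = servingFn.call(Map.of("inputs", input));
        outputs.values().forEach(Tensor::close);
      }
    }
  }
}
```

Where each call actually runs would still be governed by the device assignments baked into the SavedModel or the session configuration it executes under, which is the open question above.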
I'll give it a shot and see if sending each partition of inputs to a different available GPU device, even in a simple round-robin, can help.
Thanks again @rnett. Given the explosion of pretrained models for TF, this may become a feature in tensorflow-java one day.
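One way to prototype that round-robin without any new TF-Java features is to load the SavedModel once per GPU, each time with a ConfigProto whose visible_device_list exposes only one device, and then dispatch batches across the resulting bundles. A rough sketch, assuming hypothetical feed/fetch names and noting that the proto package path differs between TF-Java versions:

```java
import java.util.ArrayList;
import java.util.List;
import java.util.concurrent.atomic.AtomicInteger;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GPUOptions;

public class RoundRobinInference implements AutoCloseable {

  private final List<SavedModelBundle> bundles = new ArrayList<>();
  private final AtomicInteger next = new AtomicInteger();

  public RoundRobinInference(String modelPath, int numGpus) {
    for (int gpu = 0; gpu < numGpus; gpu++) {
      // Each bundle only sees a single GPU through visible_device_list.
      ConfigProto config = ConfigProto.newBuilder()
          .setGpuOptions(GPUOptions.newBuilder()
              .setVisibleDeviceList(Integer.toString(gpu))
              .setAllowGrowth(true))
          .build();
      bundles.add(SavedModelBundle.loader(modelPath)
          .withTags("serve")
          .withConfigProto(config)
          .load());
    }
  }

  /** Sends one batch to the next bundle in round-robin order. */
  public Tensor predict(Tensor batch) {
    SavedModelBundle bundle =
        bundles.get(Math.floorMod(next.getAndIncrement(), bundles.size()));
    // Feed/fetch names are hypothetical; use the ones from your model's signature.
    return bundle.session().runner()
        .feed("serving_default_input", batch)
        .fetch("StatefulPartitionedCall")
        .run()
        .get(0);
  }

  @Override
  public void close() {
    bundles.forEach(SavedModelBundle::close);
  }
}
```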
Is this also true for inference? In prediction, each inference is isolated from the others so it seems easier to batch inputs and send them to multiple GPU devices at the same time in parallel. (just trying to see if inference over multiple GPU devices can happen in tensorflow-java)
I wonder how easily we can do that with a proper inference server like Triton, whose efficient-but-not-too-user-friendly-yet C API can be used from Java: https://github.com/bytedeco/javacpp-presets/tree/master/tritonserver
@jackyh What do you think?
If we're talking about different packages, you can use ONNX Runtime with one session per GPU in the same JVM. But let's get it working properly in TF-Java.
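For reference, that ONNX Runtime setup is only a few lines with the `ai.onnxruntime` Java API (a sketch; `addCUDA(deviceId)` pins each session's CUDA execution provider to one GPU):

```java
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class OrtPerGpuSessions {
  public static void main(String[] args) throws OrtException {
    OrtEnvironment env = OrtEnvironment.getEnvironment();

    // One session per GPU, all sharing the same environment inside one JVM.
    OrtSession[] sessions = new OrtSession[2];
    for (int gpu = 0; gpu < sessions.length; gpu++) {
      OrtSession.SessionOptions options = new OrtSession.SessionOptions();
      options.addCUDA(gpu); // CUDA execution provider bound to device `gpu`
      sessions[gpu] = env.createSession("/path/to/model.onnx", options);
    }
    // Dispatch batches to sessions[i % sessions.length] from here on.
  }
}
```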
A while ago there was a compilation issue when building TF-Java on multiple GPUs, as the device placement algorithm got a bit confused by some of the optimisers. I guess I should check if that's still true now that we've upgraded TF multiple times.
System information
Describe the feature and the current behavior/state. TensorFlow in Python has the tf.distribute.Strategy API to distribute training across multiple GPUs or multiple machines.
Will this change the current API? How? Yes, it will add an awesome new feature.
Who will benefit with this feature?
Any Other info. https://www.tensorflow.org/guide/distributed_training