
Distributed Training with TensorFlow Java #369

Open danilojsl opened 3 years ago

danilojsl commented 3 years ago


Describe the feature and the current behavior/state. TensorFlow in Python has the tf.distribute.Strategy API to distribute training across multiple GPUs or multiple machines.

Will this change the current API? How? Yes, it will add a new awesome feature.

Who will benefit from this feature?

Any other info: https://www.tensorflow.org/guide/distributed_training

Craigacp commented 3 years ago

We had a brief look into this several months ago, and it's basically a huge pile of Python (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/python/distribute) that we'd need to replicate in Java. Doing out-of-band distribution using MPI allreduce directly from Java would be relatively easy, though it would have much worse performance than the native TF solution, as it wouldn't do direct GPU-to-GPU copies.
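
For concreteness, here is a minimal sketch of what that out-of-band approach could look like, assuming the Open MPI Java bindings (mpi.jar) are on the classpath; extracting the gradients from TF-Java tensors and feeding the averaged values back into the optimizer are elided, and `localGradients` is a placeholder:

```java
// Out-of-band gradient averaging via MPI allreduce (Open MPI Java bindings).
import mpi.MPI;
import mpi.MPIException;

public class AllReduceGradients {
    public static void main(String[] args) throws MPIException {
        MPI.Init(args);
        int worldSize = MPI.COMM_WORLD.getSize();

        // Placeholder for per-worker gradients copied out of TF-Java tensors.
        float[] localGradients = {0.1f, -0.2f, 0.3f};

        // Sum gradients across all workers in place, then average.
        MPI.COMM_WORLD.allReduce(localGradients, localGradients.length,
                MPI.FLOAT, MPI.SUM);
        for (int i = 0; i < localGradients.length; i++) {
            localGradients[i] /= worldSize;
        }
        // The averaged gradients would then be applied on each worker via a
        // TF-Java session or function call.

        MPI.Finalize();
    }
}
```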

maziyarpanahi commented 3 years ago

Hi @Craigacp

Is this also true for inference? In prediction, each inference is isolated from the others, so it seems easier to batch inputs and send them to multiple GPU devices in parallel. (Just trying to see if inference over multiple GPU devices can happen in tensorflow-java.)

rnett commented 3 years ago

Most of the ops are there (https://github.com/tensorflow/java/tree/55547dd20b14e1e9cd592a8789e780a0be3ae507/tensorflow-core/tensorflow-core-api/src/gen/java/org/tensorflow/op/collective), but I'm not sure if you can use them manually yet (we have no way to create groups or instances), or whether they can send GPU <-> GPU (with NCCL, I assume).

Of course, if you have multiple GPUs on a single machine, you can just use the device settings.
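
As an illustration of that single-machine case, a minimal sketch of device scoping with TF-Java's Ops.withDevice, assuming a GPU build of tensorflow-core and that GPUs 0 and 1 exist:

```java
// Pin ops to specific GPUs on one machine via device-scoped Ops instances.
import org.tensorflow.DeviceSpec;
import org.tensorflow.Graph;
import org.tensorflow.op.Ops;
import org.tensorflow.op.core.Constant;
import org.tensorflow.types.TFloat32;

public class DevicePlacementSketch {
    public static void main(String[] args) {
        try (Graph graph = new Graph()) {
            Ops tf = Ops.create(graph);

            // Ops created through this scope are placed on GPU 0.
            Ops gpu0 = tf.withDevice(DeviceSpec.newBuilder()
                    .deviceType(DeviceSpec.DeviceType.GPU)
                    .deviceIndex(0)
                    .build());
            Constant<TFloat32> a = gpu0.constant(new float[] {1f, 2f});

            // ...and these on GPU 1.
            Ops gpu1 = tf.withDevice(DeviceSpec.newBuilder()
                    .deviceType(DeviceSpec.DeviceType.GPU)
                    .deviceIndex(1)
                    .build());
            Constant<TFloat32> b = gpu1.constant(new float[] {3f, 4f});
            gpu1.math.add(a, b); // executes on GPU 1
        }
    }
}
```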

maziyarpanahi commented 3 years ago

Thanks @rnett

We should do some testing. I wasn't sure if I could load something like BERT (whose ops and device assignment are out of my hands, since it's a SavedModel) and somehow use ConfigProto/Session to distribute it over multiple GPU devices.

I'll see if the device settings/scope can be applied to a loaded SavedModel.
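
One possible approach here, sketched under the assumption of TF-Java 0.3.x-era APIs and a hypothetical model path: load the same SavedModel once per GPU, with each bundle's ConfigProto restricted to a single visible device:

```java
// One SavedModelBundle per GPU, each session seeing only its own device.
import org.tensorflow.SavedModelBundle;
import org.tensorflow.proto.framework.ConfigProto;
import org.tensorflow.proto.framework.GPUOptions;

public class PerGpuBundles {
    static SavedModelBundle loadOnGpu(String dir, String gpuIndex) {
        ConfigProto config = ConfigProto.newBuilder()
                .setGpuOptions(GPUOptions.newBuilder()
                        .setVisibleDeviceList(gpuIndex) // e.g. "0" or "1"
                        .build())
                .build();
        return SavedModelBundle.loader(dir)
                .withTags("serve")
                .withConfigProto(config)
                .load();
    }

    public static void main(String[] args) {
        try (SavedModelBundle onGpu0 = loadOnGpu("/path/to/bert", "0");
             SavedModelBundle onGpu1 = loadOnGpu("/path/to/bert", "1")) {
            // Each bundle's session only sees its own GPU, so batches can be
            // dispatched to the two sessions from separate threads.
        }
    }
}
```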

rnett commented 3 years ago

That's more of a TF-core thing; I don't know if there's support for it, although it seems like a common enough use case. If you don't find a way, you may be able to create a ConcreteFunction out of your inference call, and then call that on different GPUs.
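
A minimal sketch of that function-based inference path, following the TF-Java 0.3.x signatures (later versions return a SessionFunction from SavedModelBundle.function); the signature key, input name, and token values are model-specific placeholders:

```java
// Wrap a SavedModel signature as a callable function and invoke it.
import java.util.Map;
import org.tensorflow.ConcreteFunction;
import org.tensorflow.SavedModelBundle;
import org.tensorflow.Tensor;
import org.tensorflow.ndarray.StdArrays;
import org.tensorflow.types.TInt32;

public class FunctionInference {
    public static void main(String[] args) {
        try (SavedModelBundle model = SavedModelBundle.load("/path/to/bert", "serve")) {
            ConcreteFunction serving = model.function("serving_default");
            try (TInt32 inputIds =
                         TInt32.tensorOf(StdArrays.ndCopyOf(new int[][] {{101, 2023, 102}}))) {
                Map<String, Tensor> inputs = Map.of("input_ids", inputIds);
                Map<String, Tensor> outputs = serving.call(inputs);
                outputs.values().forEach(Tensor::close);
            }
        }
    }
}
```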

maziyarpanahi commented 3 years ago

I'll give it a shot and see if sending each partition of inputs to a different available GPU device, even in simple round-robin fashion, can help.
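
For illustration, a toy sketch of such a round-robin dispatcher in plain Java; the Function<float[], float[]> predictors stand in for per-GPU TF-Java sessions (e.g. the bundles loaded as above):

```java
// Round-robin dispatch of batches over per-GPU model handles.
import java.util.List;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;

public class RoundRobinDispatcher {
    private final List<Function<float[], float[]>> perGpuModels;
    private final ExecutorService pool;
    private final AtomicInteger next = new AtomicInteger();

    RoundRobinDispatcher(List<Function<float[], float[]>> perGpuModels) {
        this.perGpuModels = perGpuModels;
        this.pool = Executors.newFixedThreadPool(perGpuModels.size());
    }

    /** Sends the batch to the next GPU-bound model in round-robin order. */
    Future<float[]> predict(float[] batch) {
        int gpu = Math.floorMod(next.getAndIncrement(), perGpuModels.size());
        return pool.submit(() -> perGpuModels.get(gpu).apply(batch));
    }
}
```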

Thanks again @rnett. Given the explosion of pretrained models for TF, this may become a feature in tensorflow-java one day.

saudet commented 3 years ago

> Is this also true for inference? In prediction, each inference is isolated from the others, so it seems easier to batch inputs and send them to multiple GPU devices in parallel. (Just trying to see if inference over multiple GPU devices can happen in tensorflow-java.)

I wonder how easily we can do that with a proper inference server like Triton, whose efficient-but-not-too-user-friendly-yet C API can be used from Java: https://github.com/bytedeco/javacpp-presets/tree/master/tritonserver

@jackyh What do you think?

Craigacp commented 3 years ago

If we're talking about different packages, you can use ONNX Runtime sessions, one per GPU, in the same JVM. But let's get it working properly in TF-Java.
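
A minimal sketch of that ONNX Runtime pattern, assuming the onnxruntime_gpu artifact and a placeholder model path:

```java
// One ONNX Runtime session per GPU in the same JVM.
import ai.onnxruntime.OrtEnvironment;
import ai.onnxruntime.OrtException;
import ai.onnxruntime.OrtSession;

public class OnnxPerGpuSessions {
    public static void main(String[] args) throws OrtException {
        OrtEnvironment env = OrtEnvironment.getEnvironment();

        OrtSession.SessionOptions onGpu0 = new OrtSession.SessionOptions();
        onGpu0.addCUDA(0); // place this session on GPU 0
        OrtSession.SessionOptions onGpu1 = new OrtSession.SessionOptions();
        onGpu1.addCUDA(1); // and this one on GPU 1

        try (OrtSession s0 = env.createSession("/path/to/model.onnx", onGpu0);
             OrtSession s1 = env.createSession("/path/to/model.onnx", onGpu1)) {
            // Dispatch batches to s0/s1 from separate threads.
        }
    }
}
```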

A while ago there was a compilation issue when building TF-Java models on multiple GPUs, as the device placement algorithm got a bit confused by some of the optimisers. I guess I should check whether that's still true now that we've upgraded TF multiple times.