helmholtz-analytics / heat

Distributed tensors and Machine Learning framework with GPU and MPI acceleration in Python
https://heat.readthedocs.io/
MIT License

Implement AD functionality for MPI operations #482

Open d1saster opened 4 years ago

d1saster commented 4 years ago

Feature functionality

It is planned to implement a thin provisioning layer on top of the existing MPI communication API that allows nearly seamless integration of the MPI calls with the PyTorch automatic differentiation (AD) system.

Additional context

The rationale for implementing AD is that it is one of the distinguishing features of every ML library compared to a linear algebra library.

Several issues need to be resolved in order to provide an AD-capable, MPI-like API. The PyTorch AD machinery, and indeed AD in general, makes strong assumptions about the functions that can be differentiated.
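To illustrate one such assumption (a sketch, not taken from the issue): autograd only tracks operations on tensors, so a communication call that copies data through a raw buffer, as mpi4py's uppercase buffer-based methods do, silently drops the gradient history. The NumPy round trip below is a stand-in for such a call, assuming the communication layer hands raw buffers to MPI; no MPI is involved here.

```python
import torch

x = torch.tensor([1.0, 2.0], requires_grad=True)
y = 3 * x  # tracked by autograd: y.grad_fn is set

# Stand-in for a buffer-based MPI call: the data leaves PyTorch as a raw
# buffer (here a NumPy array) and comes back as a fresh tensor.
buf = y.detach().numpy().copy()
z = torch.from_numpy(buf)

print(y.requires_grad, z.requires_grad)  # True False
print(z.grad_fn)  # None: z has no connection to x in the graph
```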

ClaudiaComito commented 1 year ago

Adding the "student project" label for potential thesis work.


Reviewed within #1109

mrfh92 commented 1 year ago

Suggestion for a "prototype"

A good start could be to implement something like diffable_Allreduce_SUM, i.e., an AD-compatible counterpart of ht.comm.Allreduce( ... , op=MPI.SUM). This is perhaps the simplest case since, at least in my understanding, the structure of this function does not cause problems when the direction of the DAG is reversed, problems that would otherwise have to be caught with dummy constructions etc.
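A short derivation of why the sum-Allreduce is the easy case (the notation is ours, not from the issue): with $p$ processes, process $r$ contributes $x_r$ and every process $s$ receives $y_s = \sum_{r=0}^{p-1} x_r$. Assuming the total loss is the sum of the per-process losses, $L = \sum_{s} L_s(y_s)$, and using $\partial y_s / \partial x_r = I$, the chain rule gives

$$
\nabla_{x_r} L = \sum_{s=0}^{p-1} \left(\frac{\partial y_s}{\partial x_r}\right)^{\top} \nabla_{y_s} L_s = \sum_{s=0}^{p-1} \nabla_{y_s} L_s .
$$

So the backward pass of Allreduce with MPI.SUM is again an Allreduce with MPI.SUM, this time over the incoming gradients. Every process takes part in the collective in both directions, so reversing the DAG does not leave any communication unmatched.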

I would suggest trying to define a custom autograd-compatible function diffable_Allreduce_SUM with corresponding forward(), backward(), and possibly setup_context() methods, as described here:

https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html and in particular: https://pytorch.org/docs/stable/notes/extending.html

Regarding the derivatives (i.e., backward()) for Allreduce with MPI.SUM, see the great work of @d1saster in https://github.com/helmholtz-analytics/mpi4torch/blob/master/csrc/extension.cpp. In fact, my suggestion for this issue is to proceed in Python similarly to what has been done in C++ for mpi4torch.
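A minimal sketch of how diffable_Allreduce_SUM could look as a torch.autograd.Function, assuming a communicator object that exposes a method allreduce_sum(tensor) -> tensor. The names SimulatedComm, Mpi4pyComm, DiffableAllreduceSum, diffable_allreduce_sum, and allreduce_sum are hypothetical and are not heat's ht.comm API. SimulatedComm fakes several ranks inside one process (row r of the tensor is rank r's data) so the gradient can be checked without MPI; Mpi4pyComm shows how the same interface might sit on top of mpi4py for CPU float32/float64 tensors. The sketch uses the combined forward(ctx, ...) style; PyTorch 2.0+ also accepts a separate setup_context(), which is optional.

```python
import torch


class SimulatedComm:
    """Hypothetical single-process stand-in for an MPI communicator.

    Row r of the input plays the role of rank r's local tensor; allreduce_sum
    sums over the rows and hands the same result back to every "rank".
    """

    def allreduce_sum(self, t: torch.Tensor) -> torch.Tensor:
        return t.sum(dim=0, keepdim=True).expand_as(t).clone()


class Mpi4pyComm:
    """Hypothetical mpi4py-backed communicator (CPU float32/float64 tensors only)."""

    def __init__(self):
        from mpi4py import MPI  # lazy import: the rest of the sketch runs without MPI

        self._MPI = MPI
        self._comm = MPI.COMM_WORLD

    def allreduce_sum(self, t: torch.Tensor) -> torch.Tensor:
        out = t.detach().clone().contiguous()
        # out.numpy() shares memory with `out`, so the in-place Allreduce fills `out`.
        # Autograd cannot see this call, which is why it is wrapped in a Function below.
        self._comm.Allreduce(self._MPI.IN_PLACE, out.numpy(), op=self._MPI.SUM)
        return out


class DiffableAllreduceSum(torch.autograd.Function):
    """Sum-Allreduce whose backward pass is again a sum-Allreduce."""

    @staticmethod
    def forward(ctx, x: torch.Tensor, comm) -> torch.Tensor:
        ctx.comm = comm  # non-tensor state needed in backward
        return comm.allreduce_sum(x)

    @staticmethod
    def backward(ctx, grad_output: torch.Tensor):
        # Every rank's input contributed to every rank's output, so each input
        # gradient is the sum of all ranks' output gradients.
        # Return one gradient per forward input: x gets it, comm gets None.
        return ctx.comm.allreduce_sum(grad_output), None


def diffable_allreduce_sum(x: torch.Tensor, comm) -> torch.Tensor:
    return DiffableAllreduceSum.apply(x, comm)


if __name__ == "__main__":
    comm = SimulatedComm()
    x = torch.tensor([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], requires_grad=True)
    y = diffable_allreduce_sum(x, comm)  # every row becomes [9., 12.]
    weights = torch.tensor([[1.0], [10.0], [100.0]])  # a different loss per "rank"
    (y * weights).sum().backward()
    print(x.grad)  # every entry is 1 + 10 + 100 = 111
```

Under real MPI one would construct Mpi4pyComm() on every rank and launch the script with, e.g., mpirun -n 4. Because backward() itself performs an Allreduce, every rank must also call .backward(); if only some ranks do, the collective call blocks indefinitely.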

Next Steps after Prototype

Where it gets tricky

In principle, the tricky cases are discussed extensively in the mpi4torch documentation and in @d1saster's comment above.