d1saster opened this issue 4 years ago
Introducing "student project" label for potential thesis work.
Reviewed within #1109
Suggestion for a "prototype"
A good start could be to implement something like `diffable_Allreduce_SUM`, i.e. an AD-compatible counterpart of `ht.comm.Allreduce(..., op=MPI.SUM)`. This is maybe the simplest case since (at least in my understanding) the structure of this function does not cause problems when reversing the direction of the DAG that would need to be caught with dummy constructions etc.
I would suggest trying to define a custom autograd-capable function `diffable_Allreduce_SUM` with corresponding `forward()`, `backward()` and `setup_context()` (?) as described here: https://pytorch.org/tutorials/beginner/examples_autograd/two_layer_net_custom_function.html and in particular: https://pytorch.org/docs/stable/notes/extending.html
Regarding the derivatives (i.e. `backward()`) for `Allreduce` (with `MPI.SUM`), see the great work of @d1saster: https://github.com/helmholtz-analytics/mpi4torch/blob/master/csrc/extension.cpp. In fact, my suggestion for this issue is to proceed similarly in Python to what has been done in C++ for `mpi4torch`.
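To make the suggestion concrete, here is a minimal sketch of such a function, assuming plain contiguous CPU tensors, `mpi4py` for the actual communication, and the newer `setup_context()` style of `torch.autograd.Function` (PyTorch >= 2.0); the class name `DiffableAllreduceSUM` and the direct use of `MPI.COMM_WORLD` are illustrative choices, not part of Heat's API:

```python
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

class DiffableAllreduceSUM(torch.autograd.Function):
    """Sketch of an AD-compatible counterpart of Allreduce(..., op=MPI.SUM)."""

    @staticmethod
    def forward(sendbuf):
        recvbuf = torch.empty_like(sendbuf, memory_format=torch.contiguous_format)
        comm.Allreduce(sendbuf.detach().contiguous().numpy(), recvbuf.numpy(), op=MPI.SUM)
        return recvbuf

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved for the SUM case

    @staticmethod
    def backward(ctx, grad_output):
        # The adjoint of a SUM-Allreduce is again a SUM-Allreduce: every rank's
        # input contributes with factor 1 to the output on every rank, so the
        # incoming gradients just have to be summed over all ranks.
        grad_input = torch.empty_like(grad_output, memory_format=torch.contiguous_format)
        comm.Allreduce(grad_output.detach().contiguous().numpy(), grad_input.numpy(), op=MPI.SUM)
        return grad_input
```

On every rank one would then call `y = DiffableAllreduceSUM.apply(x)`; gradients flowing back into `y` get summed over all ranks before they reach `x`.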
Next Steps after Prototype
- `Allgather` and `Alltoall` should not cause problems with the computational graph either (at least in my opinion) and might therefore be handled in a similar way as in `mpi4torch` (see the sketch below this list).
- `Allreduce` with `MPI.PROD` or `MPI.MAX/MIN` (?)
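For illustration, a sketch of how `Allgather` could be treated, under the same assumptions as above and additionally assuming equally sized, contiguous 1-D CPU tensors on every rank; the class name `DiffableAllgather` is again an illustrative choice. The key observation is that the adjoint of an `Allgather` is a `Reduce_scatter` with `MPI.SUM`:

```python
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

class DiffableAllgather(torch.autograd.Function):
    """Sketch of an AD-compatible Allgather for equally sized 1-D CPU tensors."""

    @staticmethod
    def forward(sendbuf):
        send = sendbuf.detach().contiguous()
        recvbuf = torch.empty(comm.Get_size() * send.numel(), dtype=send.dtype)
        comm.Allgather(send.numpy(), recvbuf.numpy())
        return recvbuf

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing needs to be saved

    @staticmethod
    def backward(ctx, grad_output):
        # Rank r's input shows up as block r of the output on every rank, so its
        # gradient is the sum over all ranks of that block: a Reduce_scatter with SUM.
        grad_input = torch.empty(grad_output.numel() // comm.Get_size(),
                                 dtype=grad_output.dtype)
        comm.Reduce_scatter_block(grad_output.detach().contiguous().numpy(),
                                  grad_input.numpy(), op=MPI.SUM)
        return grad_input
```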
Where it gets tricky
In principle this is extensively discussed in the docs of `mpi4torch` and in the above comment of @d1saster.
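To illustrate one of the tricky cases, here is a rough sketch of a possible `backward()` for `Allreduce` with `MPI.PROD`, under the same assumptions as the sketches above and additionally assuming that no input entry is zero; the final division by the local input is exactly what breaks down for zero entries and is one reason why these reductions need extra care:

```python
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

class DiffableAllreducePROD(torch.autograd.Function):
    """Rough sketch of an AD-compatible Allreduce with MPI.PROD (nonzero inputs only)."""

    @staticmethod
    def forward(sendbuf):
        recvbuf = torch.empty_like(sendbuf, memory_format=torch.contiguous_format)
        comm.Allreduce(sendbuf.detach().contiguous().numpy(), recvbuf.numpy(), op=MPI.PROD)
        return recvbuf

    @staticmethod
    def setup_context(ctx, inputs, output):
        # Both the local input and the global product are needed in backward.
        ctx.save_for_backward(inputs[0], output)

    @staticmethod
    def backward(ctx, grad_output):
        sendbuf, prod = ctx.saved_tensors
        # First sum the incoming gradients over all ranks ...
        summed_grad = torch.empty_like(grad_output, memory_format=torch.contiguous_format)
        comm.Allreduce(grad_output.detach().contiguous().numpy(), summed_grad.numpy(), op=MPI.SUM)
        # ... then scale by d(prod)/d(local input) = prod / local input.
        # This division is what fails whenever a local entry is zero.
        return summed_grad * prod / sendbuf
```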
Feature functionality
It is planned to implement a thin provisioning layer above the existing MPI communication API that allows for nearly seamless integration of the MPI calls with the PyTorch AD system.
Additional context
The rationale to implement AD is that it is one of the distinguishing features of every ML library in comparison to a linear algebra library.
There are several issues that need to be resolved in order to provide an AD-capable MPI-like API:
- The PyTorch AD machinery, or AD itself, has strong assumptions on the functions one can calculate a derivative of, most notably that they do not mutate their arguments; the existing `recvbuf` arguments violate the latter part.
- `MPI.IN_PLACE` functionality may also be difficult to support.
- Asynchronous MPI calls return `MPIRequest` instances. As such, to make all asynchronous MPI calls AD-capable, and to properly reflect the causal dependency of the async MPI call and the corresponding `Wait` call in the DAG, one probably needs to encapsulate the `MPIRequest`s in a `torch.tensor`, as it has been done in the prototype (see the sketch below this list).
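A very rough sketch of this idea, again assuming `mpi4py` and contiguous CPU tensors; the class names `IallreduceSUM` and `WaitRequest`, and the module-level request table keyed by an integer id, are illustrative constructions and not necessarily how the prototype does it:

```python
import torch
from mpi4py import MPI

comm = MPI.COMM_WORLD

# Module-level table of pending requests; the integer key travels through the
# autograd graph inside a small int64 tensor (an illustrative construction only).
_pending = {}
_next_id = 0

class IallreduceSUM(torch.autograd.Function):
    """Start a non-blocking SUM-Allreduce and return (recvbuf, handle)."""

    @staticmethod
    def forward(sendbuf):
        global _next_id
        send = sendbuf.detach().contiguous()
        recvbuf = torch.empty_like(send)
        req = comm.Iallreduce(send.numpy(), recvbuf.numpy(), op=MPI.SUM)
        handle = torch.tensor([_next_id], dtype=torch.int64)
        # Keep the request and both buffers alive until the matching Wait.
        _pending[_next_id] = (req, send, recvbuf)
        _next_id += 1
        # recvbuf is only valid after Wait; the handle tensor is what makes this
        # causal dependency visible to the autograd DAG.
        return recvbuf, handle

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass  # nothing to save for the SUM case

    @staticmethod
    def backward(ctx, grad_recv, grad_handle):
        # Adjoint of a SUM-Allreduce is again a SUM-Allreduce (blocking here for brevity).
        grad_input = torch.empty_like(grad_recv, memory_format=torch.contiguous_format)
        comm.Allreduce(grad_recv.detach().contiguous().numpy(), grad_input.numpy(), op=MPI.SUM)
        return grad_input

class WaitRequest(torch.autograd.Function):
    """Consume the handle, wait for completion, and hand back the now-valid buffer."""

    @staticmethod
    def forward(recvbuf, handle):
        req, _, _ = _pending.pop(int(handle.item()))
        req.Wait()
        return recvbuf

    @staticmethod
    def setup_context(ctx, inputs, output):
        pass

    @staticmethod
    def backward(ctx, grad_output):
        # Identity w.r.t. the buffer, no gradient for the handle.
        return grad_output, None
```

Usage on every rank would then look roughly like `buf, handle = IallreduceSUM.apply(x)` followed by `y = WaitRequest.apply(buf, handle)`, so that anything computed from `y` is ordered after the `Wait` in the DAG.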