kakaobrain / torchgpipe

A GPipe implementation in PyTorch
https://torchgpipe.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

A more memory-efficient implementation #19

Closed: anxuthu closed this issue 4 years ago

anxuthu commented 4 years ago

Hi,

I came across a paper called DAPPLE (url) that describes a more memory-efficient implementation of GPipe-style pipelining, like this:

[Figure: pipeline schedule comparison from the DAPPLE paper]

There seems to be a typo in the legend of (c). However, I see that you have implemented GPipe by building dependencies in the forward pass of the micro-batches and then letting PyTorch autograd schedule the backward pass automatically. So is it impossible to use PyTorch to implement DAPPLE, where the forward and backward passes are interleaved?

Thank you for your help!

chiheonk commented 4 years ago

Hi @anxuthu ,

You are correct. It is impossible to implement DAPPLE in the way torchgpipe does, which relies entirely on PyTorch's autograd engine to compute the backward pass. We introduced some dependency tricks to control the backward-pass timeline, but they assume that backward is only invoked through a tensor connected to the result of the forward pass by the autograd graph (hence the forward pass must be completed beforehand). This gives users the flexibility to do whatever they want between the forward and backward passes, as in usual PyTorch training code.
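For reference, the pattern torchgpipe assumes looks like ordinary PyTorch training: the forward pass over all micro-batches finishes and returns a tensor, and only then is backward called on it. A minimal sketch (the model, shapes, balance, and chunk counts are illustrative, not taken from this issue):

```python
import torch
from torch import nn
import torch.nn.functional as F
from torchgpipe import GPipe

# Toy sequential model split into two partitions; torchgpipe expects an
# nn.Sequential and places the output on the last partition's device.
model = nn.Sequential(nn.Linear(16, 16), nn.ReLU(), nn.Linear(16, 1))
model = GPipe(model, balance=[2, 1], chunks=4)  # 4 micro-batches

x = torch.randn(32, 16, device=model.devices[0])
target = torch.randn(32, 1, device=model.devices[-1])

output = model(x)                  # forward pass over all micro-batches completes here
loss = F.mse_loss(output, target)  # user code may run freely between forward and backward
loss.backward()                    # autograd replays the pipelined backward schedule
```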

On the other hand, to implement complex pipeline algorithms such as DAPPLE, one may need to regard the forward-backward pass as one big computation graph. I believe one could completely isolate the graph of each computation task (corresponding to a microbatch-partition pair) and design a manual pipelining mechanism that computes each task individually (say, using torch.autograd.backward(tensors=..., grad_tensors=...) for the backward) to implement algorithms such as DAPPLE. I expect this would require extra care to avoid unwanted slowdowns caused by synchronization or the Python GIL.
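As a rough sketch of that second approach, under simplifying assumptions (a single partition, and a dummy upstream gradient standing in for the gradient that would arrive from the next stage), each micro-batch's graph can be kept isolated and its backward driven manually, so a custom scheduler could interleave forward and backward steps:

```python
import torch
from torch import nn

# One pipeline stage; in a real setup each device would hold its own partition.
partition = nn.Linear(16, 16)
micro_batches = [torch.randn(8, 16) for _ in range(4)]

tasks = []  # each entry holds an isolated autograd graph for one micro-batch
for x in micro_batches:
    x = x.requires_grad_()  # keep a handle to the input gradient
    out = partition(x)
    tasks.append((x, out))
    # a scheduler could run the backward of an *earlier* task right here,
    # before the next forward (1F1B-style interleaving)

for x, out in reversed(tasks):
    # In a multi-stage pipeline, grad_tensors would be received from the next
    # stage; here a dummy gradient of ones stands in for it.
    torch.autograd.backward(tensors=(out,), grad_tensors=(torch.ones_like(out),))
    # x.grad is now the gradient to send back to the previous stage;
    # parameter gradients accumulate in partition.parameters() across tasks.
```

Interleaving the two loops into a single schedule is what would give the DAPPLE-style memory savings, but the cross-device communication and synchronization it requires are exactly where the slowdowns mentioned above could creep in.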