kakaobrain / torchgpipe

A GPipe implementation in PyTorch
https://torchgpipe.readthedocs.io/
BSD 3-Clause "New" or "Revised" License

Using GPipe for Hessian Computation #13

Closed singhsarvagya closed 4 years ago

singhsarvagya commented 4 years ago

Firstly thank you for this amazing library!!

I was trying to use this library for one of my projects, where I would like to do a Hessian analysis of the model in a distributed setting. I understand that you were able to resolve the issue with backpropagation in GPipe by establishing a relationship between the micro-batches. Since calculating any second-order derivative is the same as doing a second backpropagation through the computational graph, do you think it will be safe to do this operation using GPipe?
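For reference, this is roughly the computation I have in mind, written with plain PyTorch autograd (a minimal sketch; the model, input, and vector below are just placeholders):

import torch
from torch import nn

# Toy stand-ins for the real model and data.
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
x = torch.randn(4, 1)

# First backward pass: keep the graph so the gradients themselves can be differentiated.
loss = model(x).norm()
grads = torch.autograd.grad(loss, model.parameters(), create_graph=True)

# Second backward pass: a Hessian-vector product with a random vector v.
v = [torch.randn_like(p) for p in model.parameters()]
hvp = torch.autograd.grad(grads, model.parameters(), grad_outputs=v)
print(hvp)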

Thank you.

sublee commented 4 years ago

Unfortunately, we've never considered second-order derivatives. In particular, checkpointing would not create a graph for them correctly, so the implementation may not behave correctly in this case. Hence, we do not recommend using torchgpipe for second-order derivatives.

singhsarvagya commented 4 years ago

Thank you!

singhsarvagya commented 4 years ago

I just wanted to let you know that I was able to verify that torchgpipe calculates higher-order derivatives correctly despite the checkpointing. It only requires retaining the graph during the backward pass.
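Concretely, what I mean is passing create_graph=True to the backward call; that also retains the graph, since retain_graph defaults to the value of create_graph. A minimal sketch with a placeholder model:

import torch
from torch import nn

model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
loss = model(torch.randn(4, 1)).norm()

# create_graph=True records a graph of the backward pass itself (and implies
# retain_graph=True), so the gradients can be differentiated a second time.
loss.backward(create_graph=True)
print(all(p.grad.requires_grad for p in model.parameters()))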

sublee commented 4 years ago

Oh, that's interesting. Did you compare the calculation with and without GPipe?

singhsarvagya commented 4 years ago

Yes! And they give the same results. There were minor variations, which could be due to floating-point operations.

sublee commented 4 years ago

That's weird. When I tried with checkpoint='always', the second-order computation raised a RuntimeError:

import torch
from torch import nn
from torchgpipe import GPipe

torch.manual_seed(0)

x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='always')

gpipe(x).norm().backward(create_graph=True)
print(torch.autograd.grad([p.grad for p in gpipe.parameters()], gpipe.parameters(), grad_outputs=[torch.randn_like(p) for p in gpipe.parameters()]))

# Traceback (most recent call last):
#   ...
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn

I expect that the second-order graph cannot be created through checkpoints during backpropagation. I suspect the variations you observed were not actually minor. How did you set the chunks and checkpoint options? Note that the default checkpoint option is except_last, which turns on checkpointing for every chunk except the last one.

singhsarvagya commented 4 years ago

I made a minor change, putting the model in eval mode, and it works, but I don't understand why.

import torch
from torch import nn
from torchgpipe import GPipe

torch.manual_seed(0)

x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='always')

gpipe.eval()
gpipe(x).norm().backward(create_graph=True)
print(torch.autograd.grad([p.grad for p in gpipe.parameters()], gpipe.parameters(), grad_outputs=[torch.randn_like(p) for p in gpipe.parameters()]))

I got the following result: (tensor([[0.1589]]), tensor([1.4791]), tensor([[-1.3312]]), tensor([-0.1750]))

sublee commented 4 years ago

Currently, torchgpipe.GPipe does not apply checkpointing in eval mode, regardless of the checkpoint option. I think that's why you don't run into the problem.

# The micro-batch index where the checkpointing stops.
if self.training:
    checkpoint_stop = {
        'always': self.chunks,
        'except_last': self.chunks-1,
        'never': 0,
    }[self.checkpoint]
else:
    checkpoint_stop = 0

However, this behavior seems not to be valid. We had assumed there would be no backpropagation in eval mode, but as you have shown, that assumption is wrong.
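In the meantime, based on the branch above, passing checkpoint='never' should reach the same checkpoint_stop = 0 path even in training mode, so you would not need eval() just to avoid checkpointing. A small sketch (untested on our side):

import torch
from torch import nn
from torchgpipe import GPipe

torch.manual_seed(0)

x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))

# 'never' maps to checkpoint_stop = 0 in the branch above, so no micro-batch
# is checkpointed even while gpipe.training is True.
gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='never')

gpipe(x).norm().backward(create_graph=True)
print(torch.autograd.grad([p.grad for p in gpipe.parameters()], gpipe.parameters(),
                          grad_outputs=[torch.randn_like(p) for p in gpipe.parameters()]))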

singhsarvagya commented 4 years ago

As far as my understanding of PyTorch goes, eval mode only changes the behavior of certain layers such as batch norm. One can still do a backward pass and run an optimizer.
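For example (a minimal sketch, unrelated to torchgpipe, just to illustrate the point):

import torch
from torch import nn

model = nn.Sequential(nn.Linear(1, 1), nn.BatchNorm1d(1), nn.Linear(1, 1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)

# eval() only switches layer behavior (e.g. BatchNorm uses its running statistics);
# autograd is unaffected, so the backward pass and optimizer step still work.
model.eval()
loss = model(torch.randn(4, 1)).norm()
loss.backward()
optimizer.step()
print([p.grad is not None for p in model.parameters()])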

So are you suggesting that in eval mode there is no micro-batch-level parallelism in the backward propagation?

sublee commented 4 years ago

Without checkpointing, torchgpipe.GPipe works like typical pipeline parallelism (i.e. micro-batch-level parallelism). So even in eval mode, there should still be parallelism in the backpropagation.
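If you want to double-check the agreement you observed without relying on eval mode, something along these lines should work (a rough sketch using checkpoint='never'; the settings are illustrative and we have not verified this ourselves):

import torch
from torch import nn
from torchgpipe import GPipe

torch.manual_seed(0)

x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
v = [torch.randn_like(p) for p in model.parameters()]

def hvp(net):
    # First-order gradients with a graph, then a Hessian-vector product with v.
    grads = torch.autograd.grad(net(x).norm(), net.parameters(), create_graph=True)
    return torch.autograd.grad(grads, net.parameters(), grad_outputs=v)

reference = hvp(model)

gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='never')
pipelined = hvp(gpipe)

for r, p in zip(reference, pipelined):
    print(torch.allclose(r, p, atol=1e-6))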