Unfortunately, we've never considered second-order derivatives. In particular, checkpointing would not build the graph for them correctly, so the implementation may not behave correctly in this case. Hence, we do not recommend using torchgpipe with second-order derivatives.
Thank you!
I just wanted to let you know that I was able to verify that torchgpipe calculates higher-order derivatives correctly despite the checkpointing. It only requires retaining the graph during the backward pass.
Oh, that's interesting. Did you compare the calculation with and without GPipe?
Yes! And they give the same results. There were minor variations, which can be attributed to floating-point arithmetic.
That's weird. When I tried with checkpoint='always', the second-order derivative computation raises a RuntimeError:
import torch
from torch import nn
from torchgpipe import GPipe
torch.manual_seed(0)
x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='always')
gpipe(x).norm().backward(create_graph=True)
print(torch.autograd.grad([p.grad for p in gpipe.parameters()], gpipe.parameters(), grad_outputs=[torch.randn_like(p) for p in gpipe.parameters()]))
# Traceback (most recent call last):
# ...
# RuntimeError: element 0 of tensors does not require grad and does not have a grad_fn
I expect that the second-order graph cannot be created through checkpoints during backpropagation. I guess the variations you observed were not actually minor. How did you set the chunks and checkpoint options? Note that the default checkpoint option is except_last, which turns on checkpointing for every chunk except the last one.
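For instance, turning checkpointing off entirely would look roughly like this (a sketch based on the snippet above; the balance, chunk, and device values are just placeholders):
import torch
from torch import nn
from torchgpipe import GPipe
torch.manual_seed(0)
x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
# checkpoint='never' keeps the full autograd graph of the forward pass,
# so backward(create_graph=True) can build the second-order graph.
gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='never')
gpipe(x).norm().backward(create_graph=True)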
I made a minor change: I put the model in eval mode, and it works, but I don't understand why...
import torch
from torch import nn
from torchgpipe import GPipe
torch.manual_seed(0)
x = torch.randn(4, 1, requires_grad=True)
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
gpipe = GPipe(model, [2], chunks=2, devices=['cpu'], checkpoint='always')
gpipe.eval()
gpipe(x).norm().backward(create_graph=True)
print(torch.autograd.grad([p.grad for p in gpipe.parameters()], gpipe.parameters(), grad_outputs=[torch.randn_like(p) for p in gpipe.parameters()]))
I got the following result:
(tensor([[0.1589]]), tensor([1.4791]), tensor([[-1.3312]]), tensor([-0.1750]))
Currently, torchgpipe.GPipe does not apply checkpointing in eval mode regardless of the checkpoint option. I think that's why you don't run into the problem.
# The micro-batch index where the checkpointing stops.
if self.training:
    checkpoint_stop = {
        'always': self.chunks,
        'except_last': self.chunks-1,
        'never': 0,
    }[self.checkpoint]
else:
    checkpoint_stop = 0
However, this behavior does not seem to be valid. We assumed there would be no backpropagation in eval mode, but as you have shown, that assumption is wrong.
As far as my understanding of PyTorch goes, eval mode only changes the behavior of certain layers such as batch norm. One can still do a backward pass and run optimizers.
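For example, in plain PyTorch (no GPipe involved; this snippet is only meant to illustrate the point), backward and optimizer steps still work after calling eval():
import torch
from torch import nn
model = nn.Sequential(nn.Linear(1, 1), nn.BatchNorm1d(1))
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
model.eval()  # only switches layers like batch norm to inference behavior
loss = model(torch.randn(4, 1)).norm()
loss.backward()   # gradients are still computed in eval mode
optimizer.step()  # and the optimizer still updates parameters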
So are you suggesting that in eval mode there is no micro-batch level parallelism in the backward propagation?
Without checkpointing, torchgpipe.GPipe works like typical pipeline parallelism (i.e., micro-batch level parallelism). So even in eval mode, there should be parallelism in the backpropagation.
First of all, thank you for this amazing library!!
I was trying to use this library for one of my projects where I would like to do a Hessian analysis of the model in a distributed setting. I understand that you were able to resolve the issue with backpropagation in GPipe by establishing a relationship between the micro-batches. Since calculating any second-order derivative is the same as doing a second backpropagation through the computational graph, do you think it will be safe to do this operation using GPipe?
Thank you.
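For context, the second backpropagation mentioned above is what a Hessian-vector product boils down to. A minimal sketch in plain PyTorch (no GPipe here; the model and the vector v are illustrative placeholders):
import torch
from torch import nn
model = nn.Sequential(nn.Linear(1, 1), nn.Linear(1, 1))
x = torch.randn(4, 1)
params = list(model.parameters())
loss = model(x).norm()
# First backward pass: keep the graph so it can be differentiated again.
grads = torch.autograd.grad(loss, params, create_graph=True)
# Second backward pass: Hessian-vector product with a random vector v.
v = [torch.randn_like(p) for p in params]
hvp = torch.autograd.grad(grads, params, grad_outputs=v)
print(hvp)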