remyhuang closed this issue 3 years ago
Hi, could you briefly describe your network architecture?
The model I use is quite large: it contains 5 sub-networks (5 different inputs, co-trained), each with 3 tasks.
Brief diagram:
input_1 -> net_1 -> task_1/task_2/task_3
input_2 -> net_2 -> task_1/task_2/task_3
input_3 -> net_3 -> task_1/task_2/task_3
input_4 -> net_4 -> task_1/task_2/task_3
input_5 -> net_5 -> task_1/task_2/task_3
So it has 3*5 = 15 losses in total. Even without your code, training already occupies nearly 30 GB of GPU memory on a single V100.
I see. One possible solution is to run the forward and backward passes for each task separately.
To be more specific, in my implementation the forward pass is executed jointly for all tasks, and the backward pass is then computed separately for each task. This is computationally efficient, but the network may occupy more memory because the shared graph has to stay alive until every task has been back-propagated.
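To make the trade-off concrete, here is a minimal sketch of the "joint forward, separate backward" pattern described above. The layer names (`shared`, `head_1`, `head_2`) and the toy shapes are assumptions for illustration only, not the repository's code. The point is that every backward pass except the last has to keep the graph alive, which is where the extra memory goes.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model: one shared layer feeding two task heads.
shared = nn.Linear(3, 4)
head_1 = nn.Linear(4, 5)
head_2 = nn.Linear(4, 5)

x = torch.rand(16, 3)
target = torch.rand(16, 5)
loss_fn = nn.MSELoss()

# Joint forward pass: both losses depend on the same `features` tensor.
features = shared(x)
losses = [loss_fn(head_1(features), target),
          loss_fn(head_2(features), target)]

per_task_grads = []
for i, loss in enumerate(losses):
    shared.zero_grad(set_to_none=True)
    # The graph must be retained for every loss except the last one,
    # because the remaining losses still need the shared activations.
    keep_graph = i < len(losses) - 1
    loss.backward(retain_graph=keep_graph)
    per_task_grads.append(shared.weight.grad.clone())

print([g.shape for g in per_task_grads])
```

With 15 losses on a large model, the retained graph covers every sub-network at once, so peak memory is roughly that of the full joint forward pass plus the stored per-task gradients.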
You can change the implementation if you would like to save some GPU memory.
You can first compute the objective and gradient for one task and save the gradient to a dictionary. Then you free the computation graph and repeat the same procedure for the next task. Because the graph is freed after each backward pass (retain_graph=False), GPU memory can be saved.
Brief diagram for this implementation:
input_1 -> net_1 -> task_1 ==> calculate grad
input_1 -> net_1 -> task_2 ==> calculate grad
input_1 -> net_1 -> task_3 ==> calculate grad
input_2 -> net_2 -> task_1 ==> calculate grad
input_2 -> net_2 -> task_2 ==> calculate grad
input_2 -> net_2 -> task_3 ==> calculate grad
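The procedure in the diagram could be sketched roughly as follows. This is only an illustration under assumptions: `TwoTaskNet`, its `forward(x, task)` signature, and the `task_grads` dictionary are hypothetical names, and a real model would loop over sub-networks as well as tasks.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)


class TwoTaskNet(nn.Module):
    """Hypothetical model: a shared layer plus one head per task."""

    def __init__(self):
        super().__init__()
        self.shared = nn.Linear(3, 4)
        self.heads = nn.ModuleDict({
            "task_1": nn.Linear(4, 5),
            "task_2": nn.Linear(4, 5),
        })

    def forward(self, x, task):
        return self.heads[task](self.shared(x))


model = TwoTaskNet()
loss_fn = nn.MSELoss()
x = torch.rand(16, 3)
target = torch.rand(16, 5)

task_grads = {}
for task in model.heads:  # iterating a ModuleDict yields its keys
    model.zero_grad(set_to_none=True)
    loss = loss_fn(model(x, task), target)  # forward pass for this task only
    loss.backward()  # retain_graph defaults to False, so the graph is freed
    # Store detached copies so no reference to the graph survives.
    task_grads[task] = {
        name: p.grad.detach().clone()
        for name, p in model.named_parameters()
        if p.grad is not None
    }
    del loss  # drop the last reference to this task's outputs

print({task: sorted(grads) for task, grads in task_grads.items()})
```

The cost is one extra forward pass through the shared layers per task. In exchange, only one task's graph is alive at any time.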
I hope this information can help you.
Thanks for the hint. I have revised your toy example, but I still run out of memory on my model. It would be a great help if you could confirm that my revised code is correct!
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import copy
import random


class PCGrad():
    def __init__(self, optimizer):
        self._optim = optimizer
        self.grads = []
        self.shapes = []
        self.has_grads = []

    @property
    def optimizer(self):
        return self._optim

    def zero_grad(self):
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        return self._optim.step()

    def _pc_backward(self, objective):
        self._optim.zero_grad(set_to_none=True)
        objective.backward(retain_graph=False)
        grad, shape, has_grad = self._retrieve_grad()
        self.grads.append(self._flatten_grad(grad, shape))
        self.has_grads.append(self._flatten_grad(has_grad, shape))
        self.shapes.append(shape)

    def pc_backward(self):
        pc_grad = self._project_conflicting(self.grads, self.has_grads)
        pc_grad = self._unflatten_grad(pc_grad, self.shapes[0])
        self._set_grad(pc_grad)

    def _project_conflicting(self, grads, has_grads, shapes=None):
        shared = torch.stack(has_grads).prod(0).bool()
        pc_grad, num_task = copy.deepcopy(grads), len(grads)
        for g_i in pc_grad:
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    g_i -= (g_i_g_j) * g_j / (g_j.norm()**2)
        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
        merged_grad[shared] = torch.stack([g[shared]
                                           for g in pc_grad]).mean(dim=0)
        merged_grad[~shared] = torch.stack([g[~shared]
                                            for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                # if p.grad is None: continue
                p.grad = grads[idx]
                idx += 1
        return

    def _unflatten_grad(self, grads, shapes):
        unflatten_grad, idx = [], 0
        for shape in shapes:
            length = np.prod(shape)
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        flatten_grad = torch.cat([g.flatten() for g in grads])
        return flatten_grad

    def _retrieve_grad(self):
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                # if p.grad is None: continue
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p).to(p.device))
                    has_grad.append(torch.zeros_like(p).to(p.device))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))
        return grad, shape, has_grad


class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 4)
        self.linear_1 = nn.Linear(4, 5)
        self.linear_2 = nn.Linear(4, 5)

    def task_1(self, x):
        x = self.linear(x)
        x = self.linear_1(x)
        return x

    def task_2(self, x):
        x = self.linear(x)
        x = self.linear_2(x)
        return x


if __name__ == '__main__':
    a = torch.rand(16, 3)
    b = torch.rand(16, 3)
    c = torch.rand(16, 5)
    model = TestNet()
    loss_fn = nn.MSELoss()
    pc_adam = PCGrad(optim.Adam(model.parameters(), 0.01))
    pc_adam.zero_grad()
    # task 1
    out_1 = model.task_1(a)
    loss_1 = loss_fn(out_1, c)
    pc_adam._pc_backward(loss_1)
    # task 2
    out_2 = model.task_2(b)
    loss_2 = loss_fn(out_2, c)
    pc_adam._pc_backward(loss_2)
    # backward
    pc_adam.pc_backward()
    pc_adam.step()
The script seems correct at first glance.
I am curious which part causes the OOM during training. Does it happen in the backward phase or in the conflicting-gradient calculation? A fuller traceback would help us figure it out.
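One way to narrow this down, besides the traceback, is to log allocated memory between phases. The helper below is a small sketch: `log_memory` is a hypothetical name, and it relies only on `torch.cuda.memory_allocated` and `torch.cuda.max_memory_allocated`. On a machine without CUDA it reports nothing.

```python
import torch


def log_memory(tag):
    """Print and return (allocated, peak) CUDA memory in MiB, or None without CUDA."""
    if not torch.cuda.is_available():
        print(f"{tag}: CUDA not available, nothing to report")
        return None
    allocated = torch.cuda.memory_allocated() / 2**20
    peak = torch.cuda.max_memory_allocated() / 2**20
    print(f"{tag}: allocated={allocated:.1f} MiB, peak={peak:.1f} MiB")
    return allocated, peak


# Example placement inside a training step (the phases are placeholders):
log_memory("before forward")
# ... forward pass for one task ...
log_memory("after forward")
# ... backward pass for that task ...
log_memory("after backward")
```

If the "allocated" figure after each backward pass keeps climbing from task to task, something is still holding references to tensors from earlier passes.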
It happens during the forward pass. (Taking the example above, the model runs out of memory when computing input_5 -> net_5 -> task_2.) It feels like the grads keep being saved.
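One thing that may be worth checking here: in the revised script, `_pc_backward` appends to `self.grads`, `self.has_grads`, and `self.shapes`, and nothing in the posted code empties those lists. If the same `PCGrad` object is reused across training iterations (an assumption, since the toy script runs a single step), the lists would grow on every iteration. The sketch below shows the reset pattern with a hypothetical `GradBuffer` helper. It is not the repository's class and may not be the cause of this particular OOM.

```python
import torch
import torch.nn as nn


class GradBuffer:
    """Hypothetical helper that stores one flattened gradient per task."""

    def __init__(self, params):
        self.params = list(params)
        self.grads = []

    def collect(self, loss):
        for p in self.params:
            p.grad = None
        loss.backward()  # graph is freed after this call
        flat = [
            (p.grad if p.grad is not None else torch.zeros_like(p)).flatten()
            for p in self.params
        ]
        self.grads.append(torch.cat(flat))

    def pop_all(self):
        # Hand the stored gradients to the caller and start from an empty
        # list, so nothing accumulates across training iterations.
        grads, self.grads = self.grads, []
        return grads


model = nn.Linear(3, 2)
buffer = GradBuffer(model.parameters())
loss_fn = nn.MSELoss()

sizes = []
for step in range(3):
    for _ in range(2):  # two tasks per step
        x, y = torch.rand(8, 3), torch.rand(8, 2)
        buffer.collect(loss_fn(model(x), y))
    grads = buffer.pop_all()
    sizes.append(len(grads))

print(sizes)
```

In the posted class, the equivalent change would be to reset the three lists at the end of `pc_backward`.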
Cool! If you don't apply PCGrad and jointly train the network with all the losses, how much memory does training occupy?
I just found out that it runs out of memory regardless of whether I use the normal loss, PCGrad, or no loss at all. I guess it may be caused by the change to the model's forward method?
input_1 -> net_1 -> task_1 --> saved memory?
input_1 -> net_1 -> task_2 --> saved memory?
input_1 -> net_1 -> task_3 --> saved memory?
I will keep checking the code and try to find the problem. Thanks for your help!
Hi, did you solve the problem? I have the same problem and would really appreciate it if you could share any tricks for solving it.
Hi,
thanks for your great implementation! I was trying to apply your code to my own model, but I ran into a CUDA out-of-memory error. It seems to be due to retain_graph=True. How can I fix it?
Thanks a lot again.