WeiChengTseng / Pytorch-PCGrad

Pytorch reimplementation for "Gradient Surgery for Multi-Task Learning"
BSD 3-Clause "New" or "Revised" License

CUDA out of memory #2

Closed remyhuang closed 3 years ago

remyhuang commented 3 years ago

Hi,

thanks for your great implementation! I was trying to apply your code to my own model, but I ran into a CUDA out-of-memory error. It seems to be caused by "retain_graph=True". How can I fix it?

Thanks a lot again.

WeiChengTseng commented 3 years ago

Hi, could you briefly describe your network architecture?

remyhuang commented 3 years ago

The model I use is quite large: it contains 5 sub-networks (5 different inputs, co-trained), each feeding 3 task heads.

brief diagram:
input_1 -> net_1 -> task_1/task_2/task_3
input_2 -> net_2 -> task_1/task_2/task_3
input_3 -> net_3 -> task_1/task_2/task_3
input_4 -> net_4 -> task_1/task_2/task_3
input_5 -> net_5 -> task_1/task_2/task_3

So it contains 3*5 = 15 losses in total. Even without your code, training already occupies nearly 30 GB on one V100.
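
Roughly, in code (an illustrative skeleton, not my actual model):

import torch.nn as nn

class MultiTaskNet(nn.Module):
    # illustrative: 5 sub-networks, each feeding 3 task heads
    def __init__(self, in_dim=128, hid_dim=256, out_dim=10):
        super().__init__()
        self.nets = nn.ModuleList(
            [nn.Linear(in_dim, hid_dim) for _ in range(5)])
        self.heads = nn.ModuleList([
            nn.ModuleList([nn.Linear(hid_dim, out_dim) for _ in range(3)])
            for _ in range(5)])

    def forward(self, inputs):
        # inputs: a list of 5 tensors; returns 5 * 3 = 15 task outputs
        outputs = []
        for x, net, heads in zip(inputs, self.nets, self.heads):
            h = net(x)
            outputs.extend(head(h) for head in heads)
        return outputs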

WeiChengTseng commented 3 years ago

I see. One possible solution is to execute the forward and backward passes for each task separately.

To be more specific, in my implementation the forward pass is executed jointly for all tasks, while the backward pass for each task is computed separately. The benefit of this design is computational efficiency, but the network may occupy more memory.

You can change the implementation if you would like to save some GPU memory: first compute the objective and gradient for one task and save the gradient to a dictionary, then free the graph and apply the same procedure to the next task. Since each backward call frees its graph (retain_graph=False), GPU memory is saved.

brief diagram for this implementation:
input_1 -> net_1 -> task_1 ==> calculate grad
input_1 -> net_1 -> task_2 ==> calculate grad
input_1 -> net_1 -> task_3 ==> calculate grad
input_2 -> net_2 -> task_1 ==> calculate grad
input_2 -> net_2 -> task_2 ==> calculate grad
input_2 -> net_2 -> task_3 ==> calculate grad
...
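
Here is a rough sketch of that loop (a toy model; the names are illustrative, not the API of this repo):

import torch
import torch.nn as nn

# a single toy model standing in for one (net, task) pair per pass
model = nn.Linear(3, 5)
loss_fn = nn.MSELoss()
x, y = torch.rand(16, 3), torch.rand(16, 5)

task_grads = {}
for task in ('task_1', 'task_2'):
    model.zero_grad()
    loss = loss_fn(model(x), y)
    loss.backward()  # retain_graph defaults to False, so the graph is freed
    # cache a copy of this task's gradient before moving on to the next one
    task_grads[task] = [p.grad.clone() for p in model.parameters()]
# afterwards, resolve conflicts among task_grads and write the merged
# gradient back to the parameters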

I hope this information can help you.

remyhuang commented 3 years ago

Thanks for your hint. I have revised your toy example, but I still get the out-of-memory problem on my model. It would be a great help if you could confirm that my revised code is correct!

import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import pdb
import numpy as np
import copy
import random

class PCGrad():
    def __init__(self, optimizer):
        self._optim = optimizer
        self.grads = []
        self.shapes = []
        self.has_grads = []

    @property
    def optimizer(self):
        return self._optim

    def zero_grad(self):
        return self._optim.zero_grad(set_to_none=True)

    def step(self):
        return self._optim.step()

    def _pc_backward(self, objective):
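        # run backward for one task with retain_graph=False, so its graph
        # is freed immediately; the flattened task gradient is cached below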
        self._optim.zero_grad(set_to_none=True)
        objective.backward(retain_graph=False)
        grad, shape, has_grad = self._retrieve_grad()
        self.grads.append(self._flatten_grad(grad, shape))
        self.has_grads.append(self._flatten_grad(has_grad, shape))
        self.shapes.append(shape)

    def pc_backward(self):
        pc_grad = self._project_conflicting(self.grads, self.has_grads)
        pc_grad = self._unflatten_grad(pc_grad, self.shapes[0])
        self._set_grad(pc_grad)
        # reset the caches so gradients do not accumulate across iterations
        self.grads, self.shapes, self.has_grads = [], [], []

    def _project_conflicting(self, grads, has_grads, shapes=None):
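        # PCGrad projection: whenever two task gradients conflict (negative
        # dot product), project one onto the normal plane of the other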
        shared = torch.stack(has_grads).prod(0).bool()
        pc_grad, num_task = copy.deepcopy(grads), len(grads)
        for g_i in pc_grad:
            random.shuffle(grads)
            for g_j in grads:
                g_i_g_j = torch.dot(g_i, g_j)
                if g_i_g_j < 0:
                    g_i -= (g_i_g_j) * g_j / (g_j.norm()**2)
        merged_grad = torch.zeros_like(grads[0]).to(grads[0].device)
        merged_grad[shared] = torch.stack([g[shared]
                                           for g in pc_grad]).mean(dim=0)
        merged_grad[~shared] = torch.stack([g[~shared]
                                            for g in pc_grad]).sum(dim=0)
        return merged_grad

    def _set_grad(self, grads):
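        # write the merged gradient back to the parameters, one by one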
        idx = 0
        for group in self._optim.param_groups:
            for p in group['params']:
                # if p.grad is None: continue
                p.grad = grads[idx]
                idx += 1
        return

    def _unflatten_grad(self, grads, shapes):
        unflatten_grad, idx = [], 0
        for shape in shapes:
            length = np.prod(shape)
            unflatten_grad.append(grads[idx:idx + length].view(shape).clone())
            idx += length
        return unflatten_grad

    def _flatten_grad(self, grads, shapes):
        flatten_grad = torch.cat([g.flatten() for g in grads])
        return flatten_grad

    def _retrieve_grad(self):
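        # collect each parameter's gradient; parameters without a gradient
        # for this task get zeros so every task shares the same flat layout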
        grad, shape, has_grad = [], [], []
        for group in self._optim.param_groups:
            for p in group['params']:
                # if p.grad is None: continue
                if p.grad is None:
                    shape.append(p.shape)
                    grad.append(torch.zeros_like(p).to(p.device))
                    has_grad.append(torch.zeros_like(p).to(p.device))
                    continue
                shape.append(p.grad.shape)
                grad.append(p.grad.clone())
                has_grad.append(torch.ones_like(p).to(p.device))
        return grad, shape, has_grad

class TestNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(3, 4)
        self.linear_1 = nn.Linear(4, 5)
        self.linear_2 = nn.Linear(4, 5)

    def task_1(self, x):
        x = self.linear(x)
        x = self.linear_1(x)
        return x

    def task_2(self, x):
        x = self.linear(x)
        x = self.linear_2(x)
        return x

if __name__ == '__main__':
    a = torch.rand(16, 3)
    b = torch.rand(16, 3)
    c = torch.rand(16, 5)
    model = TestNet()
    loss_fn = nn.MSELoss()
    pc_adam = PCGrad(optim.Adam(model.parameters(), 0.01))
    pc_adam.zero_grad()
    # task 1
    out_1 = model.task_1(a)
    loss_1 = loss_fn(out_1, c)
    pc_adam._pc_backward(loss_1)
    # task 2
    out_2 = model.task_2(b)
    loss_2 = loss_fn(out_2, c)
    pc_adam._pc_backward(loss_2)
    # backward
    pc_adam.pc_backward()
    pc_adam.step()

WeiChengTseng commented 3 years ago

The script seems correct at first glance.

WeiChengTseng commented 3 years ago

I am curious which part causes the OOM during training. Does it happen in the backward pass or in the conflicting-gradient calculation? A fuller traceback would help us figure it out.

remyhuang commented 3 years ago

It happens during the forward pass (taking the example above, the model hits OOM when computing input_5 -> net_5 -> task_2). It feels like the gradients keep being saved.

WeiChengTseng commented 3 years ago

Cool! If you don't apply PCGrad and jointly train the network with all the losses, how much memory does the training process occupy?

remyhuang commented 3 years ago

I just found out that it causes OOM whether I use the normal loss, PCGrad, or no loss at all. I guess it may be caused by the change to the model's forward method? Each partial forward seems to hold on to memory:

input_1 -> net_1 -> task_1 --> saved memory?
input_1 -> net_1 -> task_2 --> saved memory?
input_1 -> net_1 -> task_3 --> saved memory?
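
A toy illustration of what I suspect (the model here is made up, just to show the effect):

import torch
import torch.nn as nn

net = nn.Linear(3, 4)
head_1, head_2 = nn.Linear(4, 5), nn.Linear(4, 5)
x = torch.rand(16, 3)

out_1 = head_1(net(x))  # graph and activations for task_1 are now held
out_2 = head_2(net(x))  # a second graph is held at the same time
# out_1's graph is released only once nothing references it anymore,
# e.g. after calling backward() on a loss built from it, or explicitly:
del out_1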

I will keep checking the code and try to find the problem. Thanks for your help!

sydat2701 commented 9 months ago

Hi, did you solve the problem? I ran into the same issue and would really appreciate it if you could share any tricks to solve it.