shivamsaboo17 / Overcoming-Catastrophic-forgetting-in-Neural-Networks

Elastic weight consolidation technique for incremental learning.

Improve the efficiency of GPU memory utilization #7

Open ThomasAtlantis opened 10 months ago

ThomasAtlantis commented 10 months ago

When we use a larger model (e.g., VGG19) and a larger batch size (e.g., 256), the original version of _update_fisher_params easily depletes GPU memory (over 24 GB). Here I propose a practical improved version.

def _update_fisher_params(self, dataloader):
    """
    Diagonal Fisher Information Matrix: I = E_x[ (d log p(x|w) / dw)^2 ]

    Requires: import functools, import torch, import torch.nn.functional as F,
    from torch import autograd, and from torch.func import functional_call
    (on PyTorch 1.12 use torch.nn.utils.stateless.functional_call instead).
    """
    self.model.train()

    names, params_actual = zip(*self.model.named_parameters())

    def func(*params_formal, _input, _label):
        # Run the model with the formal parameters so that jacobian() can
        # differentiate through them; returns per-sample log-likelihoods of shape (B, 1).
        _output = functional_call(
            self.model, {n: p for n, p in zip(names, params_formal)}, _input)
        _output = F.log_softmax(_output, dim=1)
        _output = torch.gather(_output, dim=1, index=_label.unsqueeze(-1))
        return _output

    total = 0
    grad_log_likelihood = [torch.zeros_like(param) for param in self.model.parameters()]
    for inputs, targets in dataloader:
        inputs, targets = inputs.cuda(), targets.cuda()
        # One Jacobian entry per parameter tensor, each of shape (B, 1, *param.shape).
        jacobi = autograd.functional.jacobian(
            functools.partial(func, _input=inputs, _label=targets), params_actual, vectorize=True)
        for i, grad in enumerate(jacobi):
            # Square the per-sample gradients first, then sum over the batch and the
            # singleton output dimension (dims 0 and 1) so the shape matches the parameter.
            grad_log_likelihood[i] += torch.sum(grad.detach() ** 2, dim=(0, 1))
        total += targets.shape[0]

    for name, param in zip(names, grad_log_likelihood):
        self.model.register_buffer(name.replace('.', '__') + '_estimated_fisher', param / total)
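For context, a hedged sketch of how these buffers might then feed the standard EWC quadratic penalty. The _estimated_mean buffer name, the _compute_consolidation_loss method name, and the weight coefficient are assumptions about the surrounding class, not part of this patch.

def _compute_consolidation_loss(self, weight):
    # sum_i F_i * (theta_i - theta*_i)^2, using the buffers registered above.
    losses = []
    for name, param in self.model.named_parameters():
        buffer_name = name.replace('.', '__')
        mean = getattr(self.model, buffer_name + '_estimated_mean')      # theta* from the previous task (assumed buffer)
        fisher = getattr(self.model, buffer_name + '_estimated_fisher')  # diagonal Fisher estimated above
        losses.append((fisher * (param - mean) ** 2).sum())
    return (weight / 2) * sum(losses)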
ThomasAtlantis commented 10 months ago

To obtain the sum of squared per-sample gradients instead of the square of the summed gradient, we need to compute the gradient of the log-likelihood with respect to each parameter separately for every data sample. Done naively, this requires sequential execution over the samples, since we cannot use the reduce_mean or reduce_sum reductions commonly adopted in loss design (they would sum the gradients before squaring). A minimal sketch of this naive loop is shown below.
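Here is the sequential approach on a toy linear classifier; the model, batch, and variable names are illustrative, not taken from the repo. Each sample gets its own backward pass, and squaring happens before accumulation.

import torch
import torch.nn.functional as F

model = torch.nn.Linear(4, 3)                 # toy classifier: 4 features -> 3 classes
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

fisher = [torch.zeros_like(p) for p in model.parameters()]
for xi, yi in zip(x, y):                      # one backward pass per sample
    logp = F.log_softmax(model(xi.unsqueeze(0)), dim=1)[0, yi]
    grads = torch.autograd.grad(logp, list(model.parameters()))
    for f, g in zip(fisher, grads):
        f += g ** 2                           # square first, then accumulate
fisher = [f / len(x) for f in fisher]         # empirical diagonal Fisher estimate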

Here, I adopt autograd.functional.jacobian() to compute the Jacobian matrix in parallel over the batch. Since this function is usually used to compute gradients with respect to the inputs rather than the network parameters, we need to construct a forward function that takes the parameters as pseudo inputs (see the sketch below). I learned about this trick from here.
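To make the pseudo-input trick concrete, a minimal standalone example, assuming PyTorch >= 1.13 where functional_call lives in torch.func (on 1.12 it is torch.nn.utils.stateless.functional_call); the toy model and shapes are illustrative.

import torch
import torch.nn.functional as F
from torch.func import functional_call

model = torch.nn.Linear(4, 3)
names, params = zip(*model.named_parameters())
x = torch.randn(8, 4)
y = torch.randint(0, 3, (8,))

def log_likelihood(*flat_params):
    # Re-bind the parameters so jacobian() differentiates through them,
    # then return per-sample log-likelihoods of shape (8, 1).
    out = functional_call(model, dict(zip(names, flat_params)), x)
    return torch.gather(F.log_softmax(out, dim=1), dim=1, index=y.unsqueeze(-1))

# One Jacobian entry per parameter tensor, shape (8, 1, *param.shape):
# the gradient of every sample's log-likelihood w.r.t. that parameter.
jacobian = torch.autograd.functional.jacobian(log_likelihood, params, vectorize=True)
for name, j in zip(names, jacobian):
    print(name, tuple(j.shape))               # weight -> (8, 1, 3, 4), bias -> (8, 1, 3)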