Open vadimkantorov opened 2 years ago
Should something like below work for wrapping ResNet's last layer (Neck)? (https://gist.github.com/vadimkantorov/67fe785ed0bf31727af29a3584b87be1)
```python
import torch
import torch.nn as nn

class SequentialBackprop(nn.Module):
    def __init__(self, module, batch_size=1):
        super().__init__()
        self.module = module
        self.batch_size = batch_size

    def forward(self, x):
        # run the wrapped module without tracking the input, then reattach via the custom Function
        y = self.module(x.detach())
        return self.Function.apply(x, y, self.batch_size, self.module)

    class Function(torch.autograd.Function):
        @staticmethod
        def forward(ctx, x, y, batch_size, module):
            ctx.save_for_backward(x)
            ctx.batch_size = batch_size
            ctx.module = module
            return y

        @staticmethod
        def backward(ctx, grad_output):
            (x,) = ctx.saved_tensors
            grads = []
            # recompute the forward pass mini-batch by mini-batch and backprop each chunk;
            # parameter gradients of ctx.module accumulate across the chunks
            for x_mini, g_mini in zip(x.split(ctx.batch_size), grad_output.split(ctx.batch_size)):
                with torch.enable_grad():
                    x_mini = x_mini.detach().requires_grad_()
                    x_mini.retain_grad()
                    y_mini = ctx.module(x_mini)
                    torch.autograd.backward(y_mini, g_mini)
                grads.append(x_mini.grad)
            return torch.cat(grads), None, None, None

if __name__ == '__main__':
    backbone = nn.Linear(3, 6)
    neck = nn.Linear(6, 12)
    head = nn.Linear(12, 1)
    model = nn.Sequential(backbone, SequentialBackprop(neck, batch_size=16), head)
    print('before', neck.weight.grad)
    x = torch.rand(512, 3)
    model(x).sum().backward()
    print('after', neck.weight.grad)
```
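One way to sanity-check the wrapper (assuming the `SequentialBackprop` class above is defined in the same file) is to compare the neck's gradients against a plain, unwrapped forward/backward starting from identical weights; a minimal sketch:

```python
import copy

import torch
import torch.nn as nn

# Sanity-check sketch: gradients of the wrapped neck should match a plain
# forward/backward up to floating-point noise from the mini-batch splitting.
torch.manual_seed(0)
backbone = nn.Linear(3, 6)
neck_plain = nn.Linear(6, 12)
neck_wrapped = copy.deepcopy(neck_plain)  # identical initial weights for a fair comparison
head = nn.Linear(12, 1)

model_plain = nn.Sequential(backbone, neck_plain, head)
model_wrapped = nn.Sequential(backbone, SequentialBackprop(neck_wrapped, batch_size=16), head)

x = torch.rand(512, 3)
model_plain(x).sum().backward()
model_wrapped(x).sum().backward()

# Only the neck gradients are compared; the shared backbone/head accumulate over both runs.
print(torch.allclose(neck_plain.weight.grad, neck_wrapped.weight.grad, atol=1e-6),
      torch.allclose(neck_plain.bias.grad, neck_wrapped.bias.grad, atol=1e-6))
```

If both checks print `True`, the wrapper reproduces the plain gradients for the neck while only ever backpropagating through `batch_size`-sized chunks at a time.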
Hello vadimkantorov! I've recently been trying to implement this module and I'm wondering whether your SBP code works as-is, or whether it needs further modification. I'd be grateful for any help!