SeldonIO / alibi

Algorithms for explaining machine learning models
https://docs.seldon.io/projects/alibi/en/stable/

`GradSim` Improvements #837

Open mauicv opened 1 year ago

mauicv commented 1 year ago

I think the dash onsite demo showed that the GradSim method is slow for large models. This is because PyTorch and TensorFlow currently don't let you compute gradients per instance in a batch, which gradient similarity requires. We can do this ahead of time by storing the per-instance gradients, but that becomes infeasible for large models. Note that partial solutions include: a) using a subset of the model weights, such as the final layer, to decrease the memory overhead, or b) reducing the dataset you're comparing against using something like ProtoSelect. Both of these are user-level interventions; I think our focus should be on figuring out how to batch the gradient computations.
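For reference, here is a minimal sketch of the naive per-instance approach (not the current alibi implementation; the model and data below are made up) showing why caching per-instance gradients scales badly:

import torch
import torch.nn as nn

# Hypothetical small model and data, purely for illustration.
model = nn.Linear(10, 2)
loss_fn = nn.CrossEntropyLoss()
X = torch.randn(100, 10)
y = torch.randint(0, 2, (100,))

# Naive per-instance gradients: one backward pass per instance.
per_instance_grads = []
for i in range(len(X)):
    model.zero_grad()
    loss = loss_fn(model(X[i:i + 1]), y[i:i + 1])
    loss.backward()
    # Each stored gradient is as large as the model itself, so caching them
    # ahead of time costs O(n_instances * n_params) memory.
    per_instance_grads.append(
        torch.cat([p.grad.flatten() for p in model.parameters()])
    )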

mauicv commented 1 year ago

I think differential privacy libraries such as Opacus may do what we need. This blog post details how they perform efficient per-sample gradient computation.

mauicv commented 1 year ago

Example of using Opacus to compute per-sample gradients:

!pip install -q opacus

from opacus.grad_sample import GradSampleModule

import torch
import torch.nn as nn
import torch.nn.functional as F

class MNISTConvNet(nn.Module):
    def __init__(self):
        super(MNISTConvNet, self).__init__()
        self.conv1 = nn.Conv2d(1, 10, 5)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.conv2 = nn.Conv2d(10, 20, 5)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.fc1 = nn.Linear(320, 50)
        self.fc2 = nn.Linear(50, 10)

    def forward(self, input):
        x = self.pool1(F.relu(self.conv1(input)))
        x = self.pool2(F.relu(self.conv2(x)))
        x = x.view(x.size(0), -1)
        x = F.relu(self.fc1(x))
        x = self.fc2(x)
        return x  # raw logits for nn.CrossEntropyLoss

# Wrapping the model makes Opacus attach a .grad_sample attribute
# (per-instance gradients) to every parameter after backward().
net = GradSampleModule(MNISTConvNet())

input = torch.randn((2, 1, 28, 28))
target = torch.randint(0, 10, (2,))  # valid class indices for CrossEntropyLoss
loss_fn = nn.CrossEntropyLoss()

out = net(input)
err = loss_fn(out, target)
err.backward()

# Each parameter now carries per-sample gradients with a leading batch dimension of 2.
for p in net.parameters():
    print(p.grad_sample.shape)

Note that a limitation of this approach is that Opacus only implements per-sample gradients for a finite set of torch module types, so some architectures may contain layers it can't handle.
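If helpful, Opacus ships a ModuleValidator utility (assuming an Opacus 1.x-style API) that can flag unsupported layers up front. It is aimed at DP-training compatibility rather than grad-sample support specifically, so treat it as a rough proxy; a minimal sketch:

from opacus.validators import ModuleValidator

# Rough sketch (assumes Opacus 1.x): list the layers Opacus objects to.
# An empty list suggests the model is compatible; ModuleValidator.fix can
# swap some unsupported layers (e.g. BatchNorm -> GroupNorm) automatically.
errors = ModuleValidator.validate(MNISTConvNet(), strict=False)
print(errors)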

mauicv commented 1 year ago

@jklaise & @RobertSamoilescu, tagging for discussion, any thoughts?