pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

Implementing captum with pytorch-lightning #726

Closed: ik362 closed this issue 2 years ago

ik362 commented 3 years ago

❓ Questions and Help

Hi there, I am a new user, and I am trying to use LayerGradCam in Captum to interpret a particular layer in my model.

Part of the complication seems to be that my model and its forward method are defined in a pytorch-lightning module.

My pytorch-lightning module is:

import torch
import torch.nn as nn
import pytorch_lightning as pl
from captum.attr import LayerGradCam

class model(pl.LightningModule):
    def __init__(self, learning_rate: float):
        super().__init__()
        self.learning_rate = learning_rate
        self.criterion = nn.BCEWithLogitsLoss()
        # cnnBlock1..cnnBlock3 and linearBlock are my own modules (not shown)
        self.model = nn.Sequential(cnnBlock1(), cnnBlock2(), cnnBlock3(), linearBlock())
        self.cam = LayerGradCam(self.forward, 'model.5')

    def forward(self, x):
        return self.model(x)

    def training_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        train_loss = self.criterion(y_hat, y)
        self.log('train_loss', train_loss)
        return train_loss

    def validation_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        val_loss = self.criterion(y_hat, y)
        self.log('val_loss', val_loss)
        return val_loss

    def test_step(self, batch, batch_idx):
        x, y = batch
        y_hat = self.forward(x)
        # Grad-CAM attributions for the chosen layer
        attr = self.cam.attribute(x)
        return attr

    def configure_optimizers(self):
        optimizer = torch.optim.Adam(self.parameters(), lr=self.learning_rate)
        return optimizer

However, when I run the test step, I get the following error:

AttributeError: 'str' object has no attribute 'register_forward_hook'
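To see what this error refers to: Captum's layer attribution methods attach a forward hook to the layer they are given in order to record that layer's output, and register_forward_hook is a method of torch.nn.Module. The sketch below uses a small hypothetical network (not the model from this issue), registers such a hook by hand, and shows that a string has no such method.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in network; nn.Sequential names its children "0", "1", "2".
net = nn.Sequential(nn.Conv2d(1, 4, kernel_size=3, padding=1), nn.ReLU(), nn.Flatten())
captured = {}

def save_output(module, inputs, output):
    # A forward hook runs right after module.forward and receives its output.
    captured["out"] = output.detach()

# Any nn.Module provides register_forward_hook...
handle = net[0].register_forward_hook(save_output)
net(torch.randn(2, 1, 8, 8))
handle.remove()
print(captured["out"].shape)  # torch.Size([2, 4, 8, 8])

# ...but a plain string does not, which is what the AttributeError reports.
print(hasattr("model.5", "register_forward_hook"))  # False
```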

I have two questions then:

  1. What does this error mean, and how do I fix it?
  2. What is the best practice for using Captum with pytorch-lightning?

Thanks for your help!

NarineK commented 3 years ago

Cross-linking the same question on the PyTorch forums: https://discuss.pytorch.org/t/implementing-captum-with-pytorch-lightning/129292

aobo-y commented 3 years ago

Hi @ik362, sorry for my late reply.

I don't think pytorch-lightning is the cause here. Captum works as long as your model exposes a forward-like callable that you can pass to it.

The issue is caused by the second argument in the following line:

self.cam = LayerGradCam(self.forward, 'model.5')

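What is the string model.5? A named module? Is it defined in one of your blocks, e.g., linearBlock? In any case, the second argument, layer, must be the module itself, not its name: Captum registers a forward hook on that argument by calling its register_forward_hook method, which is why passing a string raises AttributeError: 'str' object has no attribute 'register_forward_hook'. You can refer to our documentation for details: https://github.com/pytorch/captum/blob/4faf1ea49fbff90af92b759c1f763dda1d8be705/captum/attr/_core/layer/grad_cam.py#L64-L67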