pytorch / captum

Model interpretability and understanding for PyTorch
https://captum.ai
BSD 3-Clause "New" or "Revised" License

captum.concept.TCAV() custom classifier #1309

Open FilipUniBe opened 4 months ago

FilipUniBe commented 4 months ago

I originally posted this on Stack Overflow, but I think it fits better here, since it concerns the Captum setup itself.

I want to replace the default classifier of the Captum TCAV class with a custom one that runs on the GPU and trains in batches, because the default is slow. With the unchanged Captum setup one gets two warnings:

"UserWarning: Using default classifier for TCAV which keeps input both train and test datasets in the memory. Consider defining your own classifier that doesn't rely heavily on memory, for large number of concepts, by extending Classifer abstract class"

"UserWarning: Creating a tensor from a list of numpy.ndarrays is extremly slow. Please consider converting the list to a single numpy.ndarray with numpy.array() before converting to a tensor."

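The second warning is raised by PyTorch when a tensor is built from a Python list of NumPy arrays. A minimal sketch of the usual workaround, assuming the activations arrive as a list of equally shaped arrays (the name `activations` and the shapes are made up for illustration):

```python
import numpy as np
import torch

# Hypothetical stand-in for per-sample activations: 4 arrays of 8 features each.
activations = [np.random.rand(8).astype(np.float32) for _ in range(4)]

# Stack into one contiguous ndarray first, then wrap it as a tensor.
# torch.from_numpy reuses the array's memory instead of converting item by item.
stacked = np.stack(activations)        # shape (4, 8)
features = torch.from_numpy(stacked)   # shape (4, 8), dtype float32
```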
According to the Captum documentation and the corresponding GitHub issue, the classifier should be exchangeable with a custom one: https://captum.ai/api/concept.html#captum.concept.Classifier https://github.com/pytorch/captum/issues/930

ChatGPT recommends the following:

from typing import Any, Dict, List, Union

import torch
import torch.nn as nn
import torch.optim as optim
from captum.concept import Classifier
from torch.utils.data import DataLoader


class LinearClassifier(nn.Module, Classifier):
    def __init__(self, input_dim: int, output_dim: int, device: str = 'cpu') -> None:
        super().__init__()
        self.device = device
        self.linear = nn.Linear(input_dim, output_dim).to(self.device)

    def forward(self, x):
        return self.linear(x)

    def train_and_eval(
            self, dataloader: DataLoader, **kwargs: Any
    ) -> Union[Dict, None]:

        optimizer = optim.SGD(self.parameters(), lr=0.01)
        criterion = nn.CrossEntropyLoss()
        num_epochs = 10

        for epoch in range(num_epochs):
            self.train()
            for x_batch, y_batch in dataloader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                optimizer.zero_grad()
                outputs = self(x_batch)
                loss = criterion(outputs, y_batch)
                loss.backward()
                optimizer.step()

        # Evaluation after training
        self.eval()
        correct = 0
        total = 0
        with torch.no_grad():
            for x_batch, y_batch in dataloader:
                x_batch, y_batch = x_batch.to(self.device), y_batch.to(self.device)
                outputs = self(x_batch)
                _, predicted = torch.max(outputs.data, 1)
                total += y_batch.size(0)
                correct += (predicted == y_batch).sum().item()

        accuracy = correct / total
        print(f'Accuracy: {accuracy:.2f}')

        return {'accs': accuracy}

    def weights(self) -> torch.Tensor:
        return self.linear.weight.t()  # Transpose to match C x F format

    def classes(self) -> List[int]:
        return list(range(self.linear.out_features))
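
One detail worth checking in `weights()`: `nn.Linear` stores its weight as `(out_features, in_features)`, which is already the C x F layout that the code comment aims for, so the extra `.t()` would produce F x C instead. A quick shape check, with arbitrary sizes chosen for illustration:

```python
import torch.nn as nn

# F = 8 input features, C = 3 output classes (arbitrary example sizes).
layer = nn.Linear(in_features=8, out_features=3)

print(tuple(layer.weight.shape))      # (3, 8): already C x F
print(tuple(layer.weight.t().shape))  # (8, 3): transposed to F x C
```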

The first issue is that an nn.Linear, unlike the linear SkLearnSGDClassifier, needs input and output dimensions, which presumably come from the CAVs.

All right, so I override the Captum method with a custom_train_cav() in which I instantiate the custom classifier with the derived dimensions:
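
One way to avoid passing dimensions to the constructor is to build the layer once a dataloader is available. The sketch below is an illustration under stated assumptions, not Captum's implementation: `LazyLinearProbe` and its `build` method are hypothetical names, the class does not subclass `captum.concept.Classifier` so that it runs with PyTorch alone, and each batch is assumed to be an `(inputs, labels)` pair with integer labels.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset


class LazyLinearProbe(nn.Module):
    """Hypothetical probe that sizes its linear layer from the data."""

    def __init__(self, device: str = "cpu") -> None:
        super().__init__()
        self.device = device
        self.linear = None  # built later, once the data has been seen

    def build(self, dataloader: DataLoader) -> None:
        # One pass over the labels to collect the distinct class ids (sorted).
        labels = torch.cat([y.flatten() for _, y in dataloader])
        self.class_ids = torch.unique(labels).tolist()
        # Feature size comes from the first batch, flattened per sample.
        x, _ = next(iter(dataloader))
        input_dim = x.flatten(start_dim=1).size(1)
        self.linear = nn.Linear(input_dim, len(self.class_ids)).to(self.device)


# Toy data: 6 samples, 5 features, labels drawn from two made-up ids.
x = torch.randn(6, 5)
y = torch.tensor([7, 44, 7, 44, 44, 7])
probe = LazyLinearProbe()
probe.build(DataLoader(TensorDataset(x, y), batch_size=2))
print(probe.class_ids, probe.linear)
```

In a real classifier, `build` would be called at the start of `train_and_eval`, so the constructor needs no dimensions at all.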

dataloader = DataLoader(labelled_dataset, collate_fn=batch_collate)
example_input, _ = next(iter(dataloader))
input_dim = example_input.size(1)  # Assuming the second dimension is the feature size
print("input_dim", input_dim)
output_dim = len(concepts)
print("output_dim", output_dim)

classifier = classifier(input_dim=input_dim, output_dim=output_dim, **classifier_kwargs)

But now it crashes at

loss = criterion(outputs, y_batch)

Error message: "IndexError: Target 44 is out of bounds."
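
For context, `nn.CrossEntropyLoss` expects integer targets in the range `[0, C-1]`, where C is the number of output units. One plausible reading of the error, offered as an assumption rather than a confirmed diagnosis, is that the labels in the dataloader are concept ids (such as 44) rather than contiguous class indices. The sketch below remaps such labels before computing the loss; `remap_labels` is a hypothetical helper and the ids are made up.

```python
import torch
import torch.nn as nn


def remap_labels(y: torch.Tensor, class_ids: list) -> torch.Tensor:
    """Map arbitrary integer labels to indices 0..C-1, in class_ids order."""
    lookup = {cid: idx for idx, cid in enumerate(class_ids)}
    return torch.tensor([lookup[int(v)] for v in y], device=y.device)


class_ids = [7, 44]                        # made-up concept ids
y_batch = torch.tensor([44, 7, 44])        # raw labels as they might arrive
outputs = torch.randn(3, len(class_ids))   # logits for C = 2 classes

criterion = nn.CrossEntropyLoss()
# Passing y_batch directly would fail: 44 is not a valid index for 2 classes.
targets = remap_labels(y_batch, class_ids)  # tensor([1, 0, 1])
loss = criterion(outputs, targets)
```

With this kind of mapping, `classes()` would presumably return the original ids in the same order, so that row `i` of the weight matrix corresponds to `class_ids[i]`.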

For some reason the DefaultClassifier can handle the very same target (see my other question: Captum TCAV train.py sgd_train_linear_model() Where do the weights come from?).

Please tell me I'm overcomplicating the setup. I don't understand how no one else has this problem, when even Captum recommends adapting the classifier.

I tried implementing the LinearClassifier class shown above. I looked through the Captum TCAV forks to see what other people did, but they don't seem to stray far from the boilerplate framework. I also looked through the GitHub issues.

Edited to clarify more details.