Open Mishalfatima opened 1 month ago
Hi!
In my experiments, I use pytorch-lightning and wrap the backbones in a pl.LightningModule
class, so I can define the method directly on that class. I would suggest you try the following:
import torch as th
import lightning as pl
import torch.nn as nn

class Classifier(pl.LightningModule):
    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        pass  # your training code

    def validation_step(self, batch, batch_idx):
        pass  # your validation code

    def log_odds(self, X, *, c):
        logits = self.model(X)
        log_probs = logits.log_softmax(dim=1)
        mask = th.ones_like(log_probs[:1])
        mask[:, c] = 0
        log_odds = log_probs[:, c] - th.logsumexp(log_probs * mask, dim=1)
        return log_probs[:, c].exp(), log_odds
Alternatively, you can also subclass nn.Module directly:
class Classifier(nn.Module):
    # The rest is the same, except that you define the forward method
    # instead of training_step and validation_step.
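For reference, here is a minimal sketch of what the plain nn.Module variant could look like when written out in full. The stand-in backbone and the input shapes are hypothetical and only there to make the example runnable. One deliberate difference from the snippet above: this sketch excludes class c from the "all other classes" term by setting its log-probability to -inf before the logsumexp. Multiplying the log-probabilities by a 0/1 mask instead leaves a log-probability of 0 (that is, a probability of 1) for class c in the sum. That is my reading of the math and not something confirmed in this thread, so treat it as an assumption.

```python
import torch as th
import torch.nn as nn


class Classifier(nn.Module):
    """Wraps any backbone that maps a batch of inputs to (N, C) logits."""

    def __init__(self, model: nn.Module):
        super().__init__()
        self.model = model

    def forward(self, X):
        return self.model(X)

    def log_odds(self, X, *, c):
        log_probs = self.model(X).log_softmax(dim=1)
        # Assumption: drop class c from the "other classes" sum by giving it
        # log-probability -inf (probability 0) before the logsumexp.
        others = log_probs.clone()
        others[:, c] = float("-inf")
        log_odds = log_probs[:, c] - th.logsumexp(others, dim=1)
        # Returns the probability of class c and its log-odds, one per sample.
        return log_probs[:, c].exp(), log_odds


# Hypothetical stand-in backbone; replace it with your own model.
backbone = nn.Sequential(nn.Flatten(), nn.Linear(3 * 8 * 8, 5))
clf = Classifier(backbone)
prob, odds = clf.log_odds(th.randn(2, 3, 8, 8), c=3)
```

The wrapped object, not the raw backbone, is then what would be handed to fido.fit(...), since that is where log_odds is looked up.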
I hope this helps you.
Thanks for the response! After implementing the changes you suggested, the code seems to run fine, but I run into another error at the following line:
(
    loss +
    # regularization
    reg_scale(i, config.reg_warmup) * (config.l1 * l1_norm + config.tv * tvl)
).backward()
The error is:
Traceback (most recent call last):
File "/work/mfatima/pytorch-image-models/DFR_validate.py", line 616, in
The forward pass of the model happens under torch.no_grad().
You need gradients to estimate the FIDO maps, meaning you should call fido.fit(...) outside of the torch.no_grad() context.
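A small sketch of the effect, using a toy linear model as a stand-in (the model and tensor shapes are hypothetical): inside torch.no_grad() no graph is recorded, so there is nothing to call backward() on. If moving the call out of the no_grad block is inconvenient in an existing validation loop, wrapping just that call in torch.enable_grad() should also work.

```python
import torch as th
import torch.nn as nn

model = nn.Linear(4, 2)  # toy stand-in for the classifier
x = th.randn(1, 4, requires_grad=True)

# Typical validation loop: everything runs without gradient tracking.
with th.no_grad():
    out = model(x).sum()
    # out.requires_grad is False here, so out.backward() would fail.

    # Re-enable gradient recording only for the part that needs backward(),
    # e.g. around the fido.fit(...) call.
    with th.enable_grad():
        loss = model(x).sum()
        loss.backward()
```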
Thanks for the answer! Lastly, update_callback is None everywhere, but it is called inside the fit function, which ends up giving the following error:
Traceback (most recent call last):
File "/work/mfatima/pytorch-image-models/DFR_validate.py", line 616, in
Oh, I forgot an "is not None" check right there... I've just updated the code, but I don't have the permissions to merge the pull request right now. You can check out the latest code in my fork of this repo: https://github.com/dikorsch/fido-pytorch/.
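The pattern behind that fix looks roughly like the following. This is a generic sketch and not the library's actual fit method; the function signature and the callback argument are made up for illustration. Calling a callback that is None raises a TypeError in Python, so an optional callback needs a guard.

```python
from typing import Callable, Optional


def fit(n_iters: int, update_callback: Optional[Callable[[int], None]] = None):
    """Hypothetical training loop with an optional per-iteration callback."""
    for i in range(n_iters):
        # ... one optimization step would go here ...
        if update_callback is not None:
            update_callback(i)


fit(3)  # works without a callback

seen = []
fit(3, update_callback=seen.append)  # the callback receives each iteration index
```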
Thanks!
After fitting FIDO, which function returns the joint mask as proposed in your work?
There are properties for the dropout rates: ssr_dropout_rate, sdr_dropout_rate, and joint_dropout_rate. But if you need the "keep mask", check out the fido.plot(...) method (https://github.com/cvjena/fido-pytorch/blob/main/src/fido/module.py#L188-L196). It's a bit complicated because we are dealing with counter-probabilities.
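To illustrate the counter-probability idea only: if a dropout rate is the probability that a pixel is dropped, the probability of keeping it is its complement. The tensor below is a made-up stand-in for one of the dropout-rate properties, and the 0.5 threshold is an arbitrary choice; how the library actually combines the SSR and SDR rates is in the linked lines of fido.plot(...), which this sketch does not reproduce.

```python
import torch as th

# Hypothetical stand-in for a fitted dropout-rate map with values in [0, 1].
dropout_rate = th.tensor([[0.9, 0.2],
                          [0.1, 0.8]])

# Assumption: the keep probability is the complement of the dropout rate.
keep_prob = 1.0 - dropout_rate

# Optional binarization; the 0.5 threshold is arbitrary.
keep_mask = keep_prob > 0.5
```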
In the paper, you mention that you use estimated attribution maps to enhance performance. So, basically, you use the predicted joint masks to perturb the input image and estimate the classification performance?
No, we used the attribution masks to estimate the regions that are most important to the classification and used these regions to extract an additional patch from the image, which is then fed into the classification model as an additional input. We then combined the predictions, and this enhanced the classification performance.
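One way the patch-extraction step could be sketched is to crop the bounding box of all pixels whose mask value exceeds a threshold. The function name, the threshold, and the bounding-box strategy are assumptions for illustration; the thread does not say how the regions were turned into a patch or how the predictions were combined.

```python
import torch as th


def crop_from_mask(image: th.Tensor, mask: th.Tensor, threshold: float = 0.5) -> th.Tensor:
    """Hypothetical helper: crop `image` (C, H, W) to the bounding box of
    all positions where `mask` (H, W) exceeds `threshold`."""
    ys, xs = th.nonzero(mask > threshold, as_tuple=True)
    if ys.numel() == 0:
        # Nothing is above the threshold: fall back to the full image.
        return image
    y0, y1 = ys.min().item(), ys.max().item() + 1
    x0, x1 = xs.min().item(), xs.max().item() + 1
    return image[:, y0:y1, x0:x1]


image = th.arange(3 * 6 * 6, dtype=th.float32).reshape(3, 6, 6)
mask = th.zeros(6, 6)
mask[2:4, 1:5] = 1.0  # pretend this region was marked as important
patch = crop_from_mask(image, mask)
```

The patch would then typically be resized to the model's input size before being classified alongside the full image.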
Hi! I am facing the following error while running your code. I am using Hugging Face PyTorch models, and the function "log_odds" is not recognized.
Traceback (most recent call last):
  File "/work/mfatima/pytorch-image-models/DFR_validate.py", line 619, in
    main()
  File "/work/mfatima/pytorch-image-models/DFR_validate.py", line 592, in main
    results = validate(args)
  File "/work/mfatima/pytorch-image-models/DFR_validate.py", line 427, in validate
    fido.fit(im, predicted_class, model, config=fido_config)
  File "/work/mfatima/pytorch-image-models/fido_pytorch/src/fido/module.py", line 144, in fit
    masks, probs, loss, l1_norm, tvl = self.objective(im, y, clf, batch_size=config.batch_size)
  File "/work/mfatima/pytorch-image-models/fido_pytorch/src/fido/module.py", line 127, in objective
    prob, odds = clf.log_odds(_x, c=y)
  File "/work/mfatima/miniconda3/envs/TvS/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1688, in __getattr__
    raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
AttributeError: 'ConvNeXt' object has no attribute 'log_odds'