godaup closed this issue 3 years ago

Suppose my model outputs a Batch x Class x Height x Width tensor for a multi-class image segmentation task (as in here) and I want to compute the Monte Carlo FIM. Shouldn't it be possible to interpret this as a (Height x Width)-fold classification problem and easily adapt `fim_function`?
This is how I would go:
```python
def fim_function(*d):
    # `function` and `trials` are assumed to be in scope, as inside the
    # FIM_MonteCarlo helper: `function` maps a minibatch to the model's
    # raw logits, `trials` is the number of Monte Carlo samples.
    # Per-pixel class log-probabilities: (batch, classes, height, width)
    log_softmax = torch.log_softmax(function(*d), dim=1)
    s_mb, s_c, s_h, s_w = log_softmax.size()
    # Treat every pixel as its own classification problem:
    # (batch * height * width, classes)
    log_softmax = log_softmax.permute(0, 2, 3, 1).contiguous() \
                             .view(s_mb * s_h * s_w, s_c)
    probabilities = torch.exp(log_softmax)
    # Draw `trials` labels per pixel from the predictive distribution
    sampled_indices = torch.multinomial(probabilities, trials,
                                        replacement=True)
    sampled_targets = torch.gather(log_softmax, 1, sampled_indices)
    # Sum log-probabilities over pixels for each trial: (batch, trials)
    sampled_targets = sampled_targets.view(s_mb, s_h * s_w, trials).sum(dim=1)
    return trials ** -.5 * sampled_targets
```
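For context (my gloss, not part of the original exchange): this is the standard Monte Carlo estimate of the Fisher Information Matrix, where the expectation over the model's predictive distribution is replaced by `trials` sampled label maps per input, and the `trials ** -.5` factor makes the sum of squared per-trial gradients equal the Monte Carlo average:

```latex
% Monte Carlo FIM with T = trials samples; for segmentation the joint
% log-likelihood decomposes as a sum over pixels, each pixel being
% treated as an independent classification.
F(\theta) = \mathbb{E}_{y \sim p_\theta(\cdot \mid x)}\!\left[
    \nabla_\theta \log p_\theta(y \mid x)\,
    \nabla_\theta \log p_\theta(y \mid x)^\top \right]
\approx \frac{1}{T} \sum_{t=1}^{T}
    \nabla_\theta \log p_\theta(y_t \mid x)\,
    \nabla_\theta \log p_\theta(y_t \mid x)^\top,
\quad
\log p_\theta(y \mid x) = \sum_{i \in \mathrm{pixels}} \log p_\theta(y_i \mid x).
```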
But I don't have a simple setup to test it here. If you can come up with the simplest model to test it, like a 2-layer MLP on MNIST, I will be able to actually test it and add it to NNGeometry.
Below you'll find a minimal setup. Hope this fits, and thanks for the fast reply! :)
```python
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.utils.data
import torchvision
from tqdm import tqdm

data_path = 'path/to/mnist/data'
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

# model
class MLP(nn.Module):
    def __init__(self):
        super(MLP, self).__init__()
        self.hidden = nn.Linear(in_features=28 * 28, out_features=28 * 28 * 3)
        self.final = nn.Linear(in_features=28 * 28 * 3, out_features=28 * 28 * 10)

    def forward(self, x):
        x = x.reshape(-1, 28 * 28)
        x = F.relu(self.hidden(x))
        # no activation on the output: BCEWithLogitsLoss expects raw logits
        x = self.final(x)
        x = x.reshape(-1, 10, 28, 28)
        return x

model = MLP()
criterion = torch.nn.BCEWithLogitsLoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.001)

# data
class MNIST_Seg(torch.utils.data.Dataset):
    """MNIST as a toy segmentation task: the channel of the true digit
    class is set to the binarized image, all other channels stay zero."""

    def __init__(self, train=True):
        super(MNIST_Seg, self).__init__()
        transforms = torchvision.transforms.Compose([
            torchvision.transforms.ToTensor(),
            torchvision.transforms.Normalize(0.5, 0.5)
        ])
        self.mnist = torchvision.datasets.MNIST(root=data_path, train=train,
                                                transform=transforms, download=True)

    def __len__(self):
        return len(self.mnist)

    def __getitem__(self, item):
        image, label = self.mnist[item]
        mask = torch.zeros(10, 28, 28)
        # drop the channel dimension so the (28, 28) assignment broadcasts
        mask[label] = image.squeeze(0).gt(0).to(torch.float32)
        return image, mask

train_loader = torch.utils.data.DataLoader(MNIST_Seg(True), batch_size=20, shuffle=True)
val_loader = torch.utils.data.DataLoader(MNIST_Seg(False), batch_size=20, shuffle=False)
```
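A quick shape check (my addition, not part of the original snippet) confirms that the loader and model fit together: each batch holds images of shape (20, 1, 28, 28) and masks of shape (20, 10, 28, 28), and the model maps the images to logits of the same shape as the masks.

```python
# sanity check: loader and model shapes line up
images, masks = next(iter(train_loader))
print(images.shape)         # torch.Size([20, 1, 28, 28])
print(masks.shape)          # torch.Size([20, 10, 28, 28])
print(model(images).shape)  # torch.Size([20, 10, 28, 28])
```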
```python
# training
model.to(device)
model.train()
for ep in range(10):
    with tqdm(total=len(train_loader), desc=f'epoch {ep + 1}') as t:
        for images, masks in train_loader:
            images, masks = images.to(device), masks.to(device)
            optimizer.zero_grad()
            loss = criterion(model(images), masks)
            loss.backward()
            optimizer.step()
            t.set_postfix(loss=loss.item())
            t.update()
```
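The setup defines val_loader but never uses it; a minimal evaluation pass could look like this (a sketch, not part of the original snippet):

```python
# average validation loss after training
model.eval()
total_loss, n_batches = 0.0, 0
with torch.no_grad():
    for images, masks in val_loader:
        images, masks = images.to(device), masks.to(device)
        total_loss += criterion(model(images), masks).item()
        n_batches += 1
print(f'validation loss: {total_loss / n_batches:.4f}')
```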
It is now part of the FIM_MonteCarlo helper; a usage sketch follows below.
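For later readers, this is roughly how the helper can be called on the setup above (a sketch under assumptions: I am assuming the segmentation variant is exposed as `variant='segmentation_logits'` and using the diagonal representation; check the NNGeometry docs for the exact names):

```python
from nngeometry.metrics import FIM_MonteCarlo
from nngeometry.object import PMatDiag

# Monte Carlo FIM of the trained model, 5 sampled label maps per input,
# stored in a diagonal representation
F_fim = FIM_MonteCarlo(model=model,
                       loader=val_loader,
                       representation=PMatDiag,
                       variant='segmentation_logits',  # assumed variant name
                       trials=5,
                       device=device)
print(F_fim.trace())
```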
I am closing the issue.