aleximmer / Laplace

Laplace approximations for Deep Learning.
https://aleximmer.github.io/Laplace
MIT License

Help for Running Laplace on Image Segmentation Tasks #111

Open SouLeo opened 1 year ago

SouLeo commented 1 year ago

Hello,

I am using a U-Net variant (specifically https://github.com/juntang-zhuang/LadderNet) to perform hand segmentation. To be specific, I am classifying each pixel of an image into one of five classes (no hand, my right hand, my left hand, your right hand, your left hand).

This requires my probs shape (in fisher.py, line 446) to be [batch_size, img_h, img_w, n_classes] -> [8, 32, 32, 5] (snippet below):

def __fisher_exact(loss_and_backward, model, probs):
    _, n_classes = probs.shape  

Because of this dimensionality, this line of code fails. I assume it's because the code expects each dataset item to be a tuple of (img_as_tensor, label_as_int), as in the CIFAR example (https://github.com/AlexImmer/Laplace/blob/main/examples/calibration_example.py), where the CIFAR dataset object returns, for each of its examples, a tuple of (img_as_tensor: [3, 32, 32], label_as_int: 0, 1, 2, 3, etc.).
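For illustration, here is a standalone snippet (not library code) that reproduces the failure with my shapes:

import torch

probs = torch.rand(8, 32, 32, 5)  # [batch_size, img_h, img_w, n_classes]
_, n_classes = probs.shape        # ValueError: too many values to unpack (expected 2)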

I can always reshape my training data into a tuple of that format, but because I am classifying per pixel, my associated label would not be a single integer; it would have to be in the same shape as the image tensor, [32, 32].
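To make that reshaping concrete, here is a rough sketch with made-up tensors matching my shapes (not something the library provides):

import torch

B, H, W, C = 8, 32, 32, 5
probs = torch.rand(B, H, W, C)           # per-pixel class probabilities
labels = torch.randint(0, C, (B, H, W))  # one integer class per pixel, shaped like the image

# Treating every pixel as its own classification example:
probs_flat = probs.reshape(-1, C)   # [B*H*W, C] -> torch.Size([8192, 5])
labels_flat = labels.reshape(-1)    # [B*H*W]    -> torch.Size([8192])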

So I'm asking this question to the community to see if anyone has attempted this kind of segmentation task with the laplace-torch package before I try to force a solution inside the asdfghjkl/fisher.py file.

SouLeo commented 1 year ago

Update: I have gotten the following lines of code to run by modifying my LadderNet model to follow this architecture:

LadderNetv6(
  (initial_block): Initial_LadderBlock(
    (inconv): Conv2d(1, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (down_module_list): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (2): BasicBlock(
        (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (3): BasicBlock(
        (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
    )
    (down_conv_list): ModuleList(
      (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (2): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      (3): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
    )
    (bottom): BasicBlock(
      (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (relu): ReLU()
      (drop): Dropout2d(p=0.25, inplace=False)
    )
    (up_conv_list): ModuleList(
      (0): ConvTranspose2d(160, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (1): ConvTranspose2d(80, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (2): ConvTranspose2d(40, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      (3): ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
    )
    (up_dense_list): ModuleList(
      (0): BasicBlock(
        (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (1): BasicBlock(
        (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (2): BasicBlock(
        (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (3): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
    )
  )
  (final_block): Final_LadderBlock(
    (block): LadderBlock(
      (inconv): BasicBlock(
        (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (down_module_list): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (2): BasicBlock(
          (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (3): BasicBlock(
          (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
      )
      (down_conv_list): ModuleList(
        (0): Conv2d(10, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (1): Conv2d(20, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (2): Conv2d(40, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
        (3): Conv2d(80, 160, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1))
      )
      (bottom): BasicBlock(
        (conv1): Conv2d(160, 160, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
        (relu): ReLU()
        (drop): Dropout2d(p=0.25, inplace=False)
      )
      (up_conv_list): ModuleList(
        (0): ConvTranspose2d(160, 80, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (1): ConvTranspose2d(80, 40, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (2): ConvTranspose2d(40, 20, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
        (3): ConvTranspose2d(20, 10, kernel_size=(3, 3), stride=(2, 2), padding=(1, 1), output_padding=(1, 1))
      )
      (up_dense_list): ModuleList(
        (0): BasicBlock(
          (conv1): Conv2d(80, 80, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (1): BasicBlock(
          (conv1): Conv2d(40, 40, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (2): BasicBlock(
          (conv1): Conv2d(20, 20, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
        (3): BasicBlock(
          (conv1): Conv2d(10, 10, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
          (relu): ReLU()
          (drop): Dropout2d(p=0.25, inplace=False)
        )
      )
    )
  )
  (final_fc): Final_Layer(
    (layer): Linear(in_features=10, out_features=5, bias=False)
  )
)

la = Laplace(net, 'classification', subset_of_weights='last_layer', hessian_structure='kron', backend=AsdlGGN)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

However, despite successfully training my model, running predictions with it, and running .fit() for post-hoc Laplace, I cannot run the following code without error:


@torch.no_grad()
def predict(dataloader, model, laplace=False):
    """This code was taken from calibration_example.py."""
    py = []
    for x, _ in dataloader:
        if laplace:
            py.append(model(x.cuda()))
        else:
            py.append(torch.softmax(model(x.cuda()), dim=-1))
    return torch.cat(py).cpu()

probs_laplace = predict(test_loader, la, laplace=True)  # this line fails

The following is the trace when running the predict() code:

RuntimeError                              Traceback (most recent call last)
Input In [12], in <cell line: 65>()
     50 # # TODO: specify val_loader
     51 # # From API docs page
     52 # # post-hoc update:
   (...)
     61 
     62 # From GitHub CIFAR example:
     63 la.optimize_prior_precision(method='marglik') #, val_loader=test_loader_copy)
---> 65 probs_laplace = predict(test_loader_copy, la, laplace=True) # in future, replace w/test set: test_loader
     67 acc_laplace = (probs_laplace.argmax(-1) == targets).float().mean()
     69 # ece_laplace = ECE(bins=15).measure(probs_laplace.numpy(), targets.numpy())
     70 
     71 # nll_laplace = -dists.Categorical(probs_laplace).log_prob(targets).mean()
     72 
     73 
     74 # print(f'[Laplace] Acc.: {acc_laplace:.1%}; ECE: {ece_laplace:.1%}; NLL: {nll_laplace:.3}')')

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/torch/autograd/grad_mode.py:27, in _DecoratorContextManager.__call__.<locals>.decorate_context(*args, **kwargs)
     24 @functools.wraps(func)
     25 def decorate_context(*args, **kwargs):
     26     with self.clone():
---> 27         return func(*args, **kwargs)

Input In [12], in predict(dataloader, model, laplace)
     12 print(x.shape)
     13 if laplace:
---> 14     py.append(model(x.cuda()))
     15 else:
     16     py.append(torch.softmax(model(x.cuda()), dim=-1))

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/baselaplace.py:536, in ParametricLaplace.__call__(self, x, pred_type, link_approx, n_samples)
    533     raise ValueError(f'Unsupported link approximation {link_approx}.')
    535 if pred_type == 'glm':
--> 536     f_mu, f_var = self._glm_predictive_distribution(x)
    537     # regression
    538     if self.likelihood == 'regression':

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/lllaplace.py:124, in LLLaplace._glm_predictive_distribution(self, X)
    122 print(Js.shape)
    123 print(f_mu.shape)
--> 124 f_var = self.functional_variance(Js)
    125 print('shape of f_var, which is variance(Js)')
    126 print(f_var.shape)

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/baselaplace.py:841, in KronLaplace.functional_variance(self, Js)
    840 def functional_variance(self, Js):
--> 841     return self.posterior_precision.inv_square_form(Js)

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/utils/matrix.py:411, in KronDecomposed.inv_square_form(self, W)
    409 print('from laplace/utils inv_square_form')
    410 print(W.shape)
--> 411 SW = self._bmm(W, exponent=-1)
    412 return torch.bmm(W, SW.transpose(1, 2))

File ~/.conda/envs/hrc-laddernet/lib/python3.10/site-packages/laplace/utils/matrix.py:404, in KronDecomposed._bmm(self, W, exponent)
    402 print('length of SW')
    403 print(len(SW))
--> 404 SW = torch.cat(SW, dim=1).reshape(B, K, P)
    405 return SW

RuntimeError: shape '[1024, 32, 320]' is invalid for input of size 1638400

I'm somewhat at a loss because I did not expect prediction to fail if .fit() completed without error. Any help would be greatly appreciated.

wiseodd commented 1 year ago

Hi @SouLeo, multi-output models are indeed still on our backlog. For now, I think this paper, https://arxiv.org/abs/2206.15078, along with its code, https://github.com/FrederikWarburg/LaplaceAE, could be very useful for you.

SouLeo commented 1 year ago

Oh, I see. So this framework is not suited for per-pixel multiclass labels on a single image?

I'll review the items you have linked. Thank you very much!

I am still somewhat confused that I was able to perform the following lines of code without error:

la = Laplace(net, 'classification', subset_of_weights='last_layer', hessian_structure='kron', backend=AsdlGGN)
la.fit(train_loader)
la.optimize_prior_precision(method='marglik')

but cannot run the model prediction. Do you have any thoughts on this?

jerofad commented 1 year ago

@SouLeo Were you able to use this library successfully for image segmentation?