nasa / pretrained-microscopy-models


How to segment an unlabeled image, colab example #3

Closed. sgbaird closed this issue 2 years ago.

sgbaird commented 2 years ago

Could you help me fill in the blank here?

https://colab.research.google.com/drive/1eGvA7_q1dgIEZjHJvugZ0aa-koY3Vg_c?usp=sharing

JStuckner commented 2 years ago

You should follow this example:

The basic things you will need to do are:

  1. Label a few images or image patches. I use GIMP: make a second layer for the label and draw in red over the grain boundaries.
  2. Then you'll need to upload the images into train, train_annot, val, and val_annot folders. You'll need at least 1 training and 1 validation image, but I'd guess you'll want at least a couple more training images for decent results. If you name the images "image1.tif" and "image1_mask.tif", you can just point to the directories when making the Datasets (see example). Otherwise you'll need to provide a list of the file paths for the train and validation datasets. Don't worry about a test set.
  3. Make sure you set class_values = {'grain_boundary': [255,0,0]} (if you drew the boundaries in red).
  4. From there you should be able to follow along with the example. You'll probably need a GPU, but maybe not.

Hope this helps!
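
A minimal sketch of steps 2 and 3, assuming the folder layout and naming convention described above. The root directory name and the glob-based file listing here are illustrative only, not the repository's exact API:

import glob
import os

# Illustrative layout (assumption, matching the naming convention above):
# data/train/image1.tif         data/train_annot/image1_mask.tif
# data/val/image2.tif           data/val_annot/image2_mask.tif
data_dir = 'data'  # hypothetical root directory

train_images = sorted(glob.glob(os.path.join(data_dir, 'train', '*.tif')))
train_masks = sorted(glob.glob(os.path.join(data_dir, 'train_annot', '*_mask.tif')))
val_images = sorted(glob.glob(os.path.join(data_dir, 'val', '*.tif')))
val_masks = sorted(glob.glob(os.path.join(data_dir, 'val_annot', '*_mask.tif')))

# Red pixels in the mask layer mark grain boundaries (step 3).
class_values = {'grain_boundary': [255, 0, 0]}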

sgbaird commented 2 years ago

Thanks! What about for prediction of an unlabeled image?

In https://colab.research.google.com/drive/1eGvA7_q1dgIEZjHJvugZ0aa-koY3Vg_c?usp=sharing, I'm getting:

import numpy as np
import torch
arr = torch.from_numpy(np.asarray(img).transpose(2, 0, 1)).unsqueeze(0).float().cpu()
print(arr.shape)
model.predict(arr)
torch.Size([1, 3, 224, 300])
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
[<ipython-input-20-ba3484a48b33>](https://localhost:8080/#) in <module>()
      3 arr = torch.from_numpy(np.asarray(img).transpose(2, 0, 1)).unsqueeze(0).float().cpu()
      4 print(arr.shape)
----> 5 model.predict(arr)

5 frames
[/usr/local/lib/python3.7/dist-packages/segmentation_models_pytorch/base/model.py](https://localhost:8080/#) in predict(self, x)
     38 
     39         with torch.no_grad():
---> 40             x = self.forward(x)
     41 
     42         return x

[/usr/local/lib/python3.7/dist-packages/segmentation_models_pytorch/base/model.py](https://localhost:8080/#) in forward(self, x)
     14         """Sequentially pass `x` trough model`s encoder, decoder and heads"""
     15         features = self.encoder(x)
---> 16         decoder_output = self.decoder(*features)
     17 
     18         masks = self.segmentation_head(decoder_output)

[/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

[/usr/local/lib/python3.7/dist-packages/segmentation_models_pytorch/unet/decoder.py](https://localhost:8080/#) in forward(self, *features)
    117         for i, decoder_block in enumerate(self.blocks):
    118             skip = skips[i] if i < len(skips) else None
--> 119             x = decoder_block(x, skip)
    120 
    121         return x

[/usr/local/lib/python3.7/dist-packages/torch/nn/modules/module.py](https://localhost:8080/#) in _call_impl(self, *input, **kwargs)
   1108         if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks or _global_backward_hooks
   1109                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1110             return forward_call(*input, **kwargs)
   1111         # Do not call functions when jit is used
   1112         full_backward_hooks, non_full_backward_hooks = [], []

[/usr/local/lib/python3.7/dist-packages/segmentation_models_pytorch/unet/decoder.py](https://localhost:8080/#) in forward(self, x, skip)
     36         x = F.interpolate(x, scale_factor=2, mode="nearest")
     37         if skip is not None:
---> 38             x = torch.cat([x, skip], dim=1)
     39             x = self.attention1(x)
     40         x = self.conv1(x)

RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 20 but got size 19 for tensor number 1 in the list.
JStuckner commented 2 years ago

I think the image size has to be divisible by 32 when using UNet with 5 blocks. Try the built-in inference function and set the patch size to 224. I was able to get it to work in your notebook with the code below. You'll want to train the model on labeled data for it to be accurate, though. Then you can get the preprocessing_fn when you load the model, like in the example.

import numpy as np
import matplotlib.pyplot as plt
import segmentation_models_pytorch as smp
import pretrained_microscopy_models as pmm

# Normalization matching the encoder's pretraining (ImageNet statistics for resnet50).
preprocessing_fn = smp.encoders.get_preprocessing_fn('resnet50', 'imagenet')

# Patch-wise inference: each 224 x 224 patch is divisible by 32, so the UNet shapes line up.
img = np.asarray(img)
pred = pmm.segmentation_training.segmentation_models_inference(
    img, model, preprocessing_fn, batch_size=4, patch_size=224,
    device='cpu', probabilities=None)

# Show the input image and the predicted mask side by side.
fig, (ax0, ax1) = plt.subplots(1, 2, figsize=(16, 16))
ax0.imshow(img)
ax0.set_title('Original Image')
ax1.imshow(pred[..., 0])
ax1.set_title('Output')
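
For reference, the error above comes from the decoder's skip connections: with a 300-pixel width, the ResNet-50 encoder downsamples 300 → 150 → 75 → 38 → 19 → 10, the decoder then upsamples 10 back to 20, and the concatenation with the 19-wide skip fails. A minimal alternative sketch (not the repository's helper), assuming you want to call model.predict directly, is to pad the image so height and width are multiples of 32 and crop the prediction back afterwards. Note this skips the preprocessing_fn normalization, so it only fixes the shape problem:

import numpy as np
import torch
import torch.nn.functional as F

arr = torch.from_numpy(np.asarray(img).transpose(2, 0, 1)).unsqueeze(0).float()

# Pad H and W up to the next multiple of 32 so every encoder/decoder stage lines up.
h, w = arr.shape[-2:]
pad_h = (32 - h % 32) % 32
pad_w = (32 - w % 32) % 32
arr = F.pad(arr, (0, pad_w, 0, pad_h), mode='reflect')

pred = model.predict(arr)

# Crop the padding back off to recover the original image size.
pred = pred[..., :h, :w]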
sgbaird commented 2 years ago

@JStuckner thanks! I added this and was able to reproduce as you mentioned: [attached screenshot: original micrograph (left) and model output (right)]

(Left-hand image by Edward Pleshakov, own work, CC BY 3.0, https://commons.wikimedia.org/w/index.php?curid=3912586)

In terms of training and validation data specific to grain boundaries, https://www.doitpoms.ac.uk/miclib/browse.php?cat=1&list=mic&page=1 plus some manual segmentation might be a good source.

Related article on grain segmentation: https://doi.org/10.1016/j.matchar.2021.110978 (code @ https://data.mendeley.com/datasets/t4wvpy29fz/4)

A related classification task, since I noticed there was a notebook about classification: https://dx.doi.org/10.1016/j.isci.2022.103774

linjiangya commented 9 months ago

> You should follow this example: [JStuckner's instructions quoted from above]

Dear author, thank you for such detailed instructions.

1) Does this mean that we need at least one labeled image to train the segmentation model, and then use the model trained on those labeled images for inference?

2) Am I understanding correctly, or am I missing something? It seems this repo only released the encoder weights for MicroNet, but not the weights for the full segmentation model. If that is correct, do you plan to release the pre-trained segmentation models from your paper?

Best regards,
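
For context on question 2, a minimal sketch of how released encoder weights could be combined with segmentation_models_pytorch to build a full segmentation model. The weight file path here is hypothetical, and the decoder would still need to be trained on labeled data:

import torch
import segmentation_models_pytorch as smp

# Build a U-Net; encoder_weights=None leaves the encoder randomly initialized.
model = smp.Unet(encoder_name='resnet50', encoder_weights=None, classes=1)

# Hypothetical path: load pretrained encoder weights into the encoder only.
state_dict = torch.load('resnet50_micronet_encoder.pth', map_location='cpu')
model.encoder.load_state_dict(state_dict)

# The decoder and segmentation head remain untrained, so the model still needs
# fine-tuning on labeled images before its predictions are meaningful.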