mberkay0 / pretrained-backbones-unet

A PyTorch-based Python library providing a UNet architecture with multiple pretrained backbones for image semantic segmentation.
MIT License

Dimension issues with predict method #8

Open julesmorata opened 8 months ago

julesmorata commented 8 months ago

Hello,

I am trying to use your framework to perform mask extraction on my images with the following code:

import os

import torchvision.transforms as transforms
from backbones_unet.model.unet import Unet
from PIL import Image

PATH="path_to_folder" # Replaced for confidentiality reasons 
CELLS_PATH = PATH + "cells/"

# Model init
model = Unet(backbone='xception71', in_channels=3, num_classes=1)
transform = transforms.Compose([transforms.ToTensor()])

# Data processing
for filename in os.listdir(CELLS_PATH):
    image = Image.open(CELLS_PATH + filename)
    tensor = transform(image).unsqueeze(0)
    print(tensor.shape)
    mask = model.predict(tensor)
    print(mask)

For the moment I just print the masks to check what they look like before saving them, but I get the following error when running the code:

Traceback (most recent call last):
  File "hidden/mask.py", line 19, in <module>
    mask = model.predict(tensor)
  File "env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 160, in predict
    x = self.forward(x)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 142, in forward
    x = self.decoder(x)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 304, in forward
    x = b(x, skip)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 246, in forward
    x = torch.cat([x, skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 14 but got size 13 for tensor number 1 in the list.

And when printing the dimensions of my input tensor, as well as of x and skip (the two tensors being concatenated), I get, in that order:

torch.Size([1, 3, 207, 204])   # input tensor
torch.Size([1, 2048, 14, 14])  # x
torch.Size([1, 1024, 13, 13])  # skip

Would you know where this comes from / how to fix it?
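
In the meantime, a common user-side workaround for this class of UNet shape error is to pad the input so its height and width are multiples of the backbone's total stride (typically 32). A minimal sketch under that assumption (pad_to_multiple is an illustrative helper, not part of this library):

import torch.nn.functional as F

def pad_to_multiple(x, multiple=32):
    # Zero-pad the bottom/right of an (N, C, H, W) tensor so that H and W
    # are divisible by `multiple`; halving/doubling then round-trips cleanly.
    h, w = x.shape[-2:]
    pad_h = (multiple - h % multiple) % multiple
    pad_w = (multiple - w % multiple) % multiple
    return F.pad(x, (0, pad_w, 0, pad_h))

tensor = pad_to_multiple(tensor)  # e.g. (1, 3, 207, 204) -> (1, 3, 224, 224)
mask = model.predict(tensor)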

physgorg commented 6 months ago

I also encountered this issue. It arises from the rescaling line in the forward method of the DecoderBlock class in unet.py. I modified the code to interpolate directly to the correct shape:

if self.scale_factor != 1.0:
    if skip is not None:
        target_size = (skip.shape[2], skip.shape[3])
        x = F.interpolate(x, size=target_size, mode='nearest')
    else:
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')

This ensures that the concatenation operation will proceed as desired. Not pushing this until I've tested that it doesn't mess anything up.
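
For context on why the shapes diverge in the first place (my reading of it, not verified against the library internals): with a stride-32 encoder, an odd input size gets rounded at each stride-2 stage, so doubling the deepest feature map no longer matches the skip connection's size. A quick sketch of that arithmetic for the 207-pixel side from the report above, assuming each stage rounds up:

size = 207
sizes = [size]
for _ in range(5):          # five stride-2 stages = total stride 32
    size = (size + 1) // 2  # ceil division at each halving
    sizes.append(size)

print(sizes)          # [207, 104, 52, 26, 13, 7]
print(2 * sizes[-1])  # 14 <- decoder upsamples the deepest (7x7) map
print(sizes[-2])      # 13 <- encoder skip at the same stage: the mismatch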

physgorg commented 6 months ago

One should probably modify how the target_size tuple is defined by indexing from the back of skip.shape.
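
Something like the following, so the spatial dimensions are picked up correctly regardless of how many leading dimensions the tensor has (a sketch of that suggestion, untested):

target_size = skip.shape[-2:]  # index from the back: always (H, W)
x = F.interpolate(x, size=target_size, mode='nearest')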