Status: Open. julesmorata opened this issue 11 months ago.
I also encountered this issue. It comes from the rescaling line in the `forward` method of the `DecoderBlock` class in `unet.py`. I modified the code to interpolate directly to the spatial size of the skip connection:
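```python
# Inside DecoderBlock.forward (unet.py); F is torch.nn.functional.
if self.scale_factor != 1.0:
    if skip is not None:
        # Resize x to exactly the skip's (H, W) instead of multiplying by scale_factor
        target_size = (skip.shape[2], skip.shape[3])
        x = F.interpolate(x, size=target_size, mode='nearest')
    else:
        x = F.interpolate(x, scale_factor=self.scale_factor, mode='nearest')
```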
This ensures that x and skip have the same height and width, so the concatenation along the channel dimension succeeds. I'm not pushing this until I've tested that it doesn't break anything else.
It would probably be cleaner to define target_size by indexing from the end of skip.shape (e.g. skip.shape[-2:]) rather than hard-coding indices 2 and 3.
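A self-contained sketch of the same fix, using `skip.shape[-2:]` as suggested. `upsample_to_skip` is a hypothetical helper for illustration, not part of backbones_unet. The 7x7 input assumes that the 14x14 `x` in the error report is the deepest encoder feature map after ×2 upsampling.

```python
from typing import Optional

import torch
import torch.nn.functional as F


def upsample_to_skip(x: torch.Tensor, skip: Optional[torch.Tensor],
                     scale_factor: float = 2.0) -> torch.Tensor:
    """Hypothetical helper: upsample decoder features so they can be concatenated with skip."""
    if skip is not None:
        # Index from the end: works for any tensor whose last two dims are (H, W).
        return F.interpolate(x, size=tuple(skip.shape[-2:]), mode="nearest")
    # No skip connection: fall back to plain scale-factor upsampling.
    return F.interpolate(x, scale_factor=scale_factor, mode="nearest")


# Assumed shapes: deepest encoder map 7x7, skip connection 13x13 (as in the report).
x = torch.randn(1, 2048, 7, 7)
skip = torch.randn(1, 1024, 13, 13)

x_up = upsample_to_skip(x, skip)
merged = torch.cat([x_up, skip], dim=1)  # channels: 2048 + 1024
print(merged.shape)  # torch.Size([1, 3072, 13, 13])
```

Plain ×2 upsampling of the 7x7 map would give 14x14 and reproduce the `torch.cat` error; resizing to the skip's size sidesteps the off-by-one regardless of the input resolution.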
Hello,
I am trying to use your framework to extract masks from my images with the following code:
For now I just print the masks to check what they look like before saving them, but running the code raises the following error:
```
Traceback (most recent call last):
  File "hidden/mask.py", line 19, in
    mask = model.predict(tensor)
  File "env/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 160, in predict
    x = self.forward(x)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 142, in forward
    x = self.decoder(x)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 304, in forward
    x = b(x, skip)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "env/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "env/lib/python3.10/site-packages/backbones_unet/model/unet.py", line 246, in forward
    x = torch.cat([x, skip], dim=1)
RuntimeError: Sizes of tensors must match except in dimension 1. Expected size 14 but got size 13 for tensor number 1 in the list.
```
When I print the shapes of my input tensor and of the two tensors being concatenated (x and skip), I get, in that order:

```
torch.Size([1, 3, 207, 204])
torch.Size([1, 2048, 14, 14])
torch.Size([1, 1024, 13, 13])
```
Do you know where this comes from and how to fix it?
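Where the 13 vs 14 comes from can be checked with plain arithmetic. This sketch assumes a ResNet-50-style encoder (suggested by the 2048 and 1024 channel counts) in which the stem convolution, the max-pool, and the first strided convolution of layer2, layer3 and layer4 each use stride 2; `conv_out` and `resnet_feature_sizes` are illustrative names, not library functions.

```python
from typing import List


def conv_out(n: int, kernel: int, stride: int, padding: int) -> int:
    """Output length of a convolution or pooling layer along one spatial axis."""
    return (n + 2 * padding - kernel) // stride + 1


def resnet_feature_sizes(n: int) -> List[int]:
    """Spatial size after each stride-2 step of an assumed ResNet-style encoder."""
    sizes = []
    n = conv_out(n, kernel=7, stride=2, padding=3)  # stem conv (total stride 2)
    sizes.append(n)
    n = conv_out(n, kernel=3, stride=2, padding=1)  # max-pool (total stride 4)
    sizes.append(n)
    for _ in range(3):  # layer2, layer3, layer4 each halve the resolution again
        n = conv_out(n, kernel=3, stride=2, padding=1)
        sizes.append(n)
    return sizes


for side in (207, 204, 224):
    sizes = resnet_feature_sizes(side)
    deepest, skip = sizes[-1], sizes[-2]  # layer4 output and layer3 skip
    print(side, sizes, "upsampled x2:", deepest * 2, "skip:", skip)
```

Under these assumptions both 207 and 204 end in a 7x7 deepest map, so ×2 upsampling gives 14 while the layer3 skip is 13, matching the error. With a side divisible by 32, such as 224, every step halves exactly and the sizes line up. This is why interpolating to the skip's size fixes the mismatch, and why resizing or padding inputs to a multiple of 32 would also avoid it for such an encoder.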