fepegar / unet

"pip install unet": PyTorch implementation of 1D, 2D and 3D U-Net architectures.
MIT License

Not all input dimensions work out #29

Closed: prhbrt closed this issue 2 years ago

prhbrt commented 3 years ago

Please document how to train and predict with the model, in particular which input dimensions are supported.
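
One way to state such a constraint is as a divisibility rule. The sketch below assumes that every downsampling step is a 2 × 2 max pooling with stride 2 and that every convolution preserves spatial size (for example, 3 × 3 kernels with padding 1). Under those assumptions, a spatial size survives the encoder/decoder round trip only if it is divisible by 2 raised to the number of downsampling steps. The function names and the `num_downsamplings` parameter are hypothetical and not part of the `unet` package, and the value 4 is only an example. With unpadded convolutions each layer also trims the feature map, so the real rule would be stricter.

```python
def is_valid_size(size: int, num_downsamplings: int) -> bool:
    """True if `size` can be halved `num_downsamplings` times without flooring."""
    return size % (2 ** num_downsamplings) == 0


def smallest_valid_size(size: int, num_downsamplings: int) -> int:
    """Round `size` up to the next multiple of 2 ** num_downsamplings."""
    multiple = 2 ** num_downsamplings
    return -(-size // multiple) * multiple  # ceiling division, then scale back up


if __name__ == "__main__":
    # Hypothetical example with 4 downsampling steps (multiple of 16).
    for size in (511, 512, 500):
        print(size, is_valid_size(size, 4), smallest_valid_size(size, 4))
```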

Minimal failing example

import torch
from unet import UNet2D

device = 'cuda:0'
spatial_dimension = 511
n_samples = 2
n_channels = 1

X = torch.zeros((n_samples, n_channels, spatial_dimension, spatial_dimension),
                dtype=torch.float32, device=device)

model = UNet2D(in_channels=n_channels, residual=True).to(device)
model.forward(X)

Error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-27-cb3f6be760ae> in <module>
     11 
     12 model = UNet2D(in_channels=n_channels, residual=True).to(device)
---> 13 model.forward(X)

unet/unet.py in forward(self, x)
    120         skip_connections, encoding = self.encoder(x)
    121         encoding = self.bottom_block(encoding)
--> 122         x = self.decoder(skip_connections, encoding)
    123         if self.monte_carlo_layer is not None:
    124             x = self.monte_carlo_layer(x)

torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

unet/decoding.py in forward(self, skip_connections, x)
     59         zipped = zip(reversed(skip_connections), self.decoding_blocks)
     60         for skip_connection, decoding_block in zipped:
---> 61             x = decoding_block(skip_connection, x)
     62         return x
     63 

torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    725             result = self._slow_forward(*input, **kwargs)
    726         else:
--> 727             result = self.forward(*input, **kwargs)
    728         for hook in itertools.chain(
    729                 _global_forward_hooks.values(),

unet/decoding.py in forward(self, skip_connection, x)
    129         x = self.upsample(x)
    130         skip_connection = self.center_crop(skip_connection, x)
--> 131         x = torch.cat((skip_connection, x), dim=CHANNELS_DIMENSION)
    132         if self.residual:
    133             connection = self.conv_residual(x)

RuntimeError: Sizes of tensors must match except in dimension 2. Got 62 and 63 (The offending index is 0)
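
The 62 and 63 in this error are consistent with plain size arithmetic. This is a sketch that assumes 2 × 2, stride-2 pooling, size-preserving convolutions and 4 downsampling steps. That step count is an assumption and has not been checked against the library's defaults. On the way down, 511 is floored to 255, 127, 63 and 31. Upsampling 31 gives 62, which then has to be concatenated with the 63-pixel skip connection. The `center_crop` call shown in the traceback evidently does not absorb this one-pixel difference. `first_mismatch` is a hypothetical helper, not part of the package.

```python
from typing import Optional, Tuple


def first_mismatch(size: int, num_downsamplings: int) -> Optional[Tuple[int, int]]:
    """Return (skip_size, upsampled_size) at the first decoder level where they differ."""
    # Encoder: each 2 x 2, stride-2 max pooling floors odd sizes (511 -> 255).
    encoder_sizes = [size]
    for _ in range(num_downsamplings):
        encoder_sizes.append(encoder_sizes[-1] // 2)

    # Decoder: each upsampling doubles the size (31 -> 62), and the result is
    # concatenated with the skip connection from the matching encoder level.
    current = encoder_sizes[-1]
    for skip_size in reversed(encoder_sizes[:-1]):
        upsampled = current * 2
        if upsampled != skip_size:
            return skip_size, upsampled
        current = upsampled
    return None


if __name__ == "__main__":
    print(first_mismatch(511, 4))  # (63, 62)
    print(first_mismatch(512, 4))  # None
```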

fepegar commented 2 years ago

Hi, @prhbrt. Sorry I'm late. See #28 for more info about this issue.
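
A common workaround for the failing example above is to zero-pad the input up to a compatible size. This is a minimal sketch that assumes the spatial size must be divisible by `multiple`. Here `multiple` is 16, which corresponds to an assumed 4 downsampling steps. `pad_amounts` is a hypothetical helper. It only computes the padding tuple, in the order `torch.nn.functional.pad` expects: last dimension first.

```python
def pad_amounts(height: int, width: int, multiple: int) -> tuple:
    """Padding tuple for torch.nn.functional.pad on an (N, C, H, W) tensor.

    Ordered last dimension first, as F.pad expects:
    (width_left, width_right, height_top, height_bottom).
    """
    def split(size: int) -> tuple:
        target = -(-size // multiple) * multiple  # round up to a multiple
        total = target - size
        return total // 2, total - total // 2  # any odd pixel goes at the end

    return (*split(width), *split(height))


if __name__ == "__main__":
    print(pad_amounts(511, 511, 16))  # (0, 1, 0, 1)
    print(pad_amounts(500, 511, 16))  # (0, 1, 6, 6)
```

With PyTorch installed, `torch.nn.functional.pad(X, pad_amounts(511, 511, 16))` would zero-pad the 511 × 511 input from the example to 512 × 512. If the model's output keeps the padded size, you can crop the prediction back with `y[..., top:top + height, left:left + width]`. Here `left` and `top` are the first and third entries of the tuple.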