torch / nn

Inconsistency between SpatialMaxPooling and SpatialMaxUnpooling for torch.CudaTensor leads to an error #1287

Status: Open (opened by borismu 6 years ago)

Using 'torch.FloatTensor' as the dtype in the following snippet works fine: the pooling layer produces an indices array of size 1x5x5, which matches the size of the unpooling layer's input.

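require 'nn'

-- CPU version: runs without error
dtype = 'torch.FloatTensor'
model = nn.Sequential()
layer = nn.SpatialMaxPooling(2,2,2,2)     -- 2x2 kernel, stride 2
model:add(layer)
model:add(nn.SpatialMaxUnpooling(layer))  -- reuses the pooling layer's indices
model:type(dtype)

x = torch.randn(1,10,10):type(dtype)      -- non-batch input: 1 plane of 10x10
model:forward(x)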

However, using 'torch.CudaTensor' produces an error when executing model:forward(x):

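require 'nn'
require 'cunn'  -- CUDA backend, needed for torch.CudaTensor modules

-- CUDA version: identical model, but model:forward(x) raises an error
dtype = 'torch.CudaTensor'
model = nn.Sequential()
layer = nn.SpatialMaxPooling(2,2,2,2)
model:add(layer)
model:add(nn.SpatialMaxUnpooling(layer))
model:type(dtype)

x = torch.randn(1,10,10):type(dtype)
model:forward(x)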

Here the pooling layer instead produces indices of size 1x1x5x5, while its output (the unpooling layer's input) is 1x5x5. The unpooling layer rejects this mismatch with the following error:

torch/install/share/lua/5.1/nn/THNN.lua:110: indices and input shapes do not match: indices [1 x 1 x 5 x 5], input [1 x 5 x 5] at /tmp/luarocks_cunn-scm-1-3042/cunn/lib/THCUNN/generic/SpatialMaxUnpooling.cu:15
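A workaround one might try, assuming the mismatch only affects non-batch (3D) input, is to give the input an explicit batch dimension of size 1 so that the pooling output, the indices and the unpooling input are all 4D. This is a sketch that has not been verified against every cunn version; the variable names `pool`, `xb`, `yb` and `y` are illustrative only.

```lua
require 'nn'
require 'cunn'

local dtype = 'torch.CudaTensor'
local model = nn.Sequential()
local pool = nn.SpatialMaxPooling(2, 2, 2, 2)
model:add(pool)
model:add(nn.SpatialMaxUnpooling(pool))
model:type(dtype)

local x = torch.randn(1, 10, 10):type(dtype)

-- Assumption: with a 4D (batch) input, the CUDA pooling layer's indices
-- (1x1x5x5) have the same shape as its output (1x1x5x5), so the
-- unpooling layer's shape check passes.
local xb = x:view(1, 1, 10, 10)  -- add a batch dimension of size 1
local yb = model:forward(xb)     -- unpooled result: 1x1x10x10
local y = yb:view(1, 10, 10)     -- drop the batch dimension again
print(y:size())
```

The extra `view` calls only change tensor metadata, so the workaround adds no copies; it just sidesteps the 3D code path rather than fixing it.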