mattmacy / vnet.pytorch

A PyTorch implementation for V-Net: Fully Convolutional Neural Networks for Volumetric Medical Image Segmentation
https://mattmacy.github.io/vnet.pytorch
BSD 3-Clause "New" or "Revised" License
688 stars 201 forks

Concatenation operation in InputTransition #1

Closed RongzhaoZhang closed 7 years ago

RongzhaoZhang commented 7 years ago

If my understanding is right, the concatenation in the InputTransition block should be applied along dim=1 instead of dim=0, because the second dimension is the channel dimension, i.e.

# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x,
                 x, x, x, x, x, x, x, x), 0)

should be

# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x,
                 x, x, x, x, x, x, x, x), 1)

mattmacy commented 7 years ago

Sorry for the delay. The dimensions are (BatchSize, Channels, Z, Y, X). The point is to create 16 channels, not to increase the batch size by 16x.
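
For anyone who wants to check the shapes, here is a minimal sketch of the difference between the two dim values. The input size (2, 1, 4, 8, 8) is made up for illustration and is not taken from the repository.

```python
import torch

# Hypothetical single-channel volume laid out as (BatchSize, Channels, Z, Y, X)
x = torch.zeros(2, 1, 4, 8, 8)

# dim=0 stacks the 16 copies along the batch axis
cat_batch = torch.cat((x,) * 16, 0)

# dim=1 stacks the 16 copies along the channel axis
cat_channel = torch.cat((x,) * 16, 1)

print(cat_batch.shape)    # torch.Size([32, 1, 4, 8, 8])
print(cat_channel.shape)  # torch.Size([2, 16, 4, 8, 8])
```

Only the dim=1 result has 16 channels, so only it can be added element-wise to the output of a 16-channel convolution without relying on a coincidental batch size.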

Cassieyy commented 3 years ago

If my understanding is right, the concatenation in the InputTransition block should be applied along dim=1 instead of dim=0, because the second dimension is channel. i.e.

# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x,
                 x, x, x, x, x, x, x, x), 0)

should be

# split input in to 16 channels
x16 = torch.cat((x, x, x, x, x, x, x, x,
                 x, x, x, x, x, x, x, x), 1)

Meanwhile, the channel count of x is changed by conv1, so the original input (which has 1 channel) needs to be saved, namely

# split input in to 16 channels
x16 = torch.cat((input_x, input_x, input_x, input_x, input_x, input_x, input_x, input_x,
                 input_x, input_x, input_x, input_x, input_x, input_x, input_x, input_x), 1)
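
Put together, the idea might look roughly like the sketch below. This is not the repository's InputTransition: the class name InputTransitionSketch, the kernel size and padding, and the use of nn.BatchNorm3d and nn.ELU are all assumptions chosen so the example runs on its own.

```python
import torch
import torch.nn as nn


class InputTransitionSketch(nn.Module):
    """Hypothetical input block: 1-channel volume to 16 channels with a residual add."""

    def __init__(self, out_channels=16):
        super().__init__()
        self.out_channels = out_channels
        # Layer choices here are assumptions for illustration only
        self.conv1 = nn.Conv3d(1, out_channels, kernel_size=5, padding=2)
        self.bn1 = nn.BatchNorm3d(out_channels)
        self.relu1 = nn.ELU()

    def forward(self, x):
        # Keep a reference to the original 1-channel input
        input_x = x
        out = self.bn1(self.conv1(x))
        # Replicate the input along the channel axis (dim=1) to match out
        x16 = torch.cat((input_x,) * self.out_channels, 1)
        return self.relu1(torch.add(out, x16))


block = InputTransitionSketch()
y = block(torch.randn(2, 1, 4, 8, 8))
print(y.shape)  # torch.Size([2, 16, 4, 8, 8])
```

With padding=2 and kernel_size=5 the spatial size is preserved, so the replicated input and the convolution output have identical shapes and the addition needs no broadcasting.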

PussyCat0700 commented 3 years ago

Sorry for the delay. The dimensions are (BatchSize, Channels, Z, Y, X). The point is to create 16 channels, not to increase the batch size by 16x.

Thanks very much for the reply!