MrBlankness / LightM-UNet

Pytorch implementation of "LightM-UNet: Mamba Assists in Lightweight UNet for Medical Image Segmentation"
https://arxiv.org/abs/2403.05246
Apache License 2.0

Where is DWConv used in the LightM-UNet model? #6

Open HashmatShadab opened 6 months ago

HashmatShadab commented 6 months ago

https://github.com/MrBlankness/LightM-UNet/blob/b484335e6d76c3b6de3c1813185ee6182d633c73/lightm-unet/nnunetv2/nets/LightMUNet.py#L17-L24

Can you please point out where the DWConv is used in constructing the model? I wasn't able to find where it is used.

HashmatShadab commented 6 months ago

Running the model with

```python
model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=1,
    out_channels=14,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
)
```

gives around 6 million parameters.

eclipse0922 commented 6 months ago

ResUpBlock uses get_dwconv_layer, but LightMUNet does not use ResUpBlock anywhere in the code. Currently, LightMUNet uses ResBlock.
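
For context, `get_dwconv_layer` presumably builds a depthwise-separable convolution. Here is a minimal illustrative sketch of that pattern in PyTorch; it is an assumption about the general technique, not the repository's actual implementation:

```python
import torch.nn as nn

def dwconv_layer_sketch(in_channels: int, out_channels: int,
                        kernel_size: int = 3, stride: int = 1) -> nn.Module:
    # Depthwise conv: one filter per input channel (groups=in_channels),
    # followed by a pointwise 1x1x1 conv that mixes channels.
    depthwise = nn.Conv3d(in_channels, in_channels, kernel_size=kernel_size,
                          stride=stride, padding=kernel_size // 2,
                          groups=in_channels, bias=False)
    pointwise = nn.Conv3d(in_channels, out_channels, kernel_size=1, bias=False)
    return nn.Sequential(depthwise, pointwise)
```

This factorization replaces a standard Conv3d's in·out·k³ weights with in·k³ + in·out, which is why swapping DWConv for plain Conv noticeably changes the parameter count.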

HashmatShadab commented 6 months ago

So the current model does not reflect the one described in the paper? Even the initial and final layers don't have DWConv. And even with these changes, the model size is 2.8 million parameters for the 3D model.

eclipse0922 commented 6 months ago

Yes, it looks like the author probably uploaded code used for testing by mistake.

HashmatShadab commented 5 months ago

A response from the author regarding this issue would be appreciated.

MrBlankness commented 5 months ago

> https://github.com/MrBlankness/LightM-UNet/blob/b484335e6d76c3b6de3c1813185ee6182d633c73/lightm-unet/nnunetv2/nets/LightMUNet.py#L17-L24
>
> Can you please point out where the DWConv is used in constructing the model? I wasn't able to find where it is used.

Thank you very much for your attention to our work. We apologize for inadvertently uploading code from our ablation experiments in our previous submission, in which DWConv was replaced with Conv and ResUpBlock with ResBlock. We have now updated the code. We greatly appreciate the questions raised by eclipse0922 and HashmatShadab, and we will continue to address any issues with LightM-UNet and strive to improve and update our work.

HashmatShadab commented 5 months ago

So when will the correct code be uploaded?

MrBlankness commented 5 months ago

> So when will the correct code be uploaded?

We have updated our code.

MrBlankness commented 5 months ago

> Running the model with `model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=14, blocks_down=[1, 2, 2, 2], blocks_up=[1, 1, 1])` gives around 6 million parameters.

A kind reminder: the parameter and computation counts of the network vary with its hyperparameter settings. For example, changing the model's out_channels changes the number of convolutional kernels in the final layer, and therefore the parameter count.
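
To make the effect concrete, here is a small illustrative calculation, assuming the network ends in a 1×1×1 Conv3d from init_filters channels to out_channels (an assumption for illustration, not taken from the repository's code):

```python
import torch.nn as nn

def final_layer_params(init_filters: int, out_channels: int) -> int:
    # Hypothetical final segmentation head: a 1x1x1 convolution.
    conv = nn.Conv3d(init_filters, out_channels, kernel_size=1)
    return sum(p.numel() for p in conv.parameters())

print(final_layer_params(32, 14))  # 32*14 + 14 = 462
print(final_layer_params(32, 3))   # 32*3  + 3  = 99
```

Under that assumption the final layer alone accounts for only hundreds of parameters, so gaps of millions would have to come from other settings, such as blocks_down.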

HashmatShadab commented 5 months ago

```python
model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=1,
    out_channels=14,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
)
```

Are the above arguments correct for loading the 3D model discussed in the paper?

MrBlankness commented 5 months ago

> `model = LightMUNet(spatial_dims=3, init_filters=32, in_channels=1, out_channels=14, blocks_down=[1, 2, 2, 2], blocks_up=[1, 1, 1])`
>
> Are the above arguments correct for loading the 3D model discussed in the paper?

No, the relevant parameter settings for the LiTS dataset should be as follows:

```python
model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=1,
    out_channels=3,
    blocks_down=[1, 2, 2, 4],
    blocks_up=[1, 1, 1],
)
```

HashmatShadab commented 5 months ago

> So the current model does not reflect the one described in the paper? Even the initial and final layers don't have DWConv. And even with these changes, the model size is 2.8 million parameters for the 3D model.

Please respond to this issue as well.

MrBlankness commented 5 months ago

> So the current model does not reflect the one described in the paper? Even the initial and final layers don't have DWConv. And even with these changes, the model size is 2.8 million parameters for the 3D model.
>
> Please respond to this issue as well.

Apologies, but based solely on the information currently provided, I'm unable to analyze the cause. If possible, please provide more details to help me understand your issue better.

HashmatShadab commented 5 months ago

```python
model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=input_channels,
    out_channels=num_classes,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
)
```

Using the model class from your updated code, I am getting 2.9M parameters.

MrBlankness commented 5 months ago

How many input_channels and num_classes have you set?

HashmatShadab commented 5 months ago

Input channels: 1, output channels: 14.

MrBlankness commented 5 months ago

> Input channels: 1, output channels: 14.

We're sorry, but we are unable to reproduce your issue. 🤦

HashmatShadab commented 5 months ago

Can you please provide a short script for this then? How many total parameters are you getting?

MrBlankness commented 5 months ago

```python
import torch
from thop import profile
from nnunetv2.nets.LightMUNet import LightMUNet  # model definition from this repo

model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=1,
    out_channels=14,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
).cuda()

data = torch.rand(1, 1, 256, 256, 256).cuda()

_, params = profile(model, inputs=(data,))
print(params / 1e6)
```

I get a total parameter count of 0.4667M.

HashmatShadab commented 5 months ago

```python
import torch
from thop import profile
from nnunetv2.nets.LightMUNet import LightMUNet  # model definition from this repo

model = LightMUNet(
    spatial_dims=3,
    init_filters=32,
    in_channels=1,
    out_channels=14,
    blocks_down=[1, 2, 2, 2],
    blocks_up=[1, 1, 1],
).cuda()

data = torch.rand(1, 1, 256, 256, 256).cuda()

# Count all trainable parameters directly.
model_total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"Total Model Parameters = {model_total_params:,}\n")

# Count parameters via thop's profiler.
_, params = profile(model, inputs=(data,))
print(params)
```

Total Model Parameters = 2,997,821

```
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv3d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv1d'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pooling.MaxPool3d'>.
[INFO] Register count_upsample() for <class 'torch.nn.modules.upsampling.Upsample'>.
WARNING:root:mode trilinear is not implemented yet, take it a zero op
WARNING:root:mode trilinear is not implemented yet, take it a zero op
WARNING:root:mode trilinear is not implemented yet, take it a zero op
466729.0
```

MrBlankness commented 5 months ago

Perhaps the two methods count parameters on different principles? Although I'm not sure of the reason.
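
One plausible explanation, inferred from the [INFO] lines in the log above rather than confirmed in this thread: in recent thop versions, `profile` counts parameters only for module types with a registered counting rule (Conv3d, Linear, LayerNorm, ...). Parameters owned directly by unregistered custom modules, such as Mamba layers, are treated as zero. A minimal sketch of that behavior, using a hypothetical CustomLayer:

```python
import torch
import torch.nn as nn
from thop import profile

# A toy module that owns its parameter directly, the way custom layers
# (e.g. a Mamba block) often do. thop has no counting rule registered
# for this class, so it warns and counts zero ops and zero params for it.
class CustomLayer(nn.Module):
    def __init__(self, dim):
        super().__init__()
        self.weight = nn.Parameter(torch.randn(dim, dim))

    def forward(self, x):
        return x @ self.weight

model = nn.Sequential(nn.Linear(64, 64), CustomLayer(64))

_, thop_params = profile(model, inputs=(torch.rand(1, 64),))
numel_params = sum(p.numel() for p in model.parameters() if p.requires_grad)

print(thop_params)   # 4160.0 -- only the Linear layer (64*64 + 64)
print(numel_params)  # 8256   -- Linear (4160) + CustomLayer (64*64 = 4096)
```

Under this reading, summing p.numel() (2.9M) gives the complete count, while thop's 0.47M covers only the standard layers it recognizes.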

HashmatShadab commented 5 months ago

I haven't used the thop package before, so I am also a bit confused. Using summary from torchinfo to count parameters also gives 2.9M parameters.

MrBlankness commented 5 months ago

Anyway, thank you very much for your feedback, and we will follow up on this issue.