MASILab / 3DUX-Net


The model of params and flops #1

Closed Yffy123456 closed 1 year ago

Yffy123456 commented 1 year ago

Thank you for your paper; it was an honor to read it. I have some questions and hope you can answer them.

    if __name__ == '__main__':
        with torch.no_grad():
            import os
            os.environ['CUDA_VISIBLE_DEVICES'] = '0'
            x = torch.rand(4, 1, 96, 96, 96)
            model = UXNET()
            y = model(x)
            print(y.shape)

            # flops, params = profile(model, (x,))
            # print('flops: ', flops, 'params: ', params)
            device = torch.device('cpu')
            model.to(device)
            flops, params = profile(model.to(device), inputs=(x,))
            flops, params = clever_format([flops, params])
            print(flops, params)

I get params = 58.79M and FLOPs = 2.66T. The FLOPs are very different from yours (my input is 96x96x96; I am not sure whether my changes to the input are correct). Looking forward to your reply.

| Methods | Resolution | #Params | FLOPs | Mean Dice (AMOS2022) |
| --- | --- | --- | --- | --- |
| TransBTS | 96x96x96 | 31.6M | 110.4G | 0.792 |
| UNETR | 96x96x96 | 92.8M | 82.6G | 0.762 |
| nnFormer | 96x96x96 | 149.3M | 240.2G | 0.790 |
| SwinUNETR | 96x96x96 | 62.2M | 328.4M | 0.880 |
| 3D UX-Net | 96x96x96 | 53.0M | 639.4G | |

leeh43 commented 1 year ago

Hi, thank you very much for your interest in our network! It seems you used a batch size of 4, so the input you profiled FLOPs with is (4, 1, 96, 96, 96). The input we use to calculate FLOPs is just (1, 1, 96, 96, 96). Here is how I compute FLOPs; feel free to take a look:

from ptflops import get_model_complexity_info
with torch.cuda.device(0):
    macs, params = get_model_complexity_info(model, (1, 96, 96, 96), as_strings=True,
                                             print_per_layer_stat=True, verbose=True)

    print('{:<30}  {:<8}'.format('Computational complexity: ', macs))
    print('{:<30}  {:<8}'.format('Number of parameters: ', params))
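The batch-size point above can be sanity-checked by hand: FLOPs scale linearly with batch size, so a (4, 1, 96, 96, 96) input reports roughly 4x the FLOPs of a (1, 1, 96, 96, 96) input. A minimal sketch using an analytic Conv3d FLOPs count (the layer sizes here, 1 -> 48 channels with a 7x7x7 kernel, are made up for illustration and are not the actual 3D UX-Net stem):

```python
def conv3d_flops(batch, c_in, c_out, k, d, h, w):
    """Approximate FLOPs of a 3D convolution with 'same' padding.

    Each output voxel needs c_in * k^3 multiply-accumulates;
    counting one MAC as 2 FLOPs gives the leading factor of 2.
    """
    return 2 * batch * c_out * d * h * w * c_in * k ** 3

# Hypothetical layer: 1 -> 48 channels, 7x7x7 kernel, 96^3 volume
flops_b1 = conv3d_flops(1, 1, 48, 7, 96, 96, 96)
flops_b4 = conv3d_flops(4, 1, 48, 7, 96, 96, 96)
print(flops_b4 / flops_b1)  # batch 4 costs exactly 4x batch 1
```

The same linear scaling applies to every layer in the network, which is why profiling with batch size 4 inflates the reported total by a factor of 4.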

Also, I just found that the default kernel size for the depthwise convolution in the UX-Net block is 13 with padding 6. Please change it to a kernel size of 7 with padding 3; the model parameters will then be 53.0M.
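The parameter saving from that change can be checked by hand: a depthwise 3D convolution (groups equal to channels) holds one k x k x k filter per channel, so shrinking k from 13 to 7 cuts the per-channel weights from 2197 to 343, and padding (k - 1) // 2 keeps the spatial size unchanged in both cases. A small sketch (the channel count 48 is just an example, not the network's actual width):

```python
def depthwise_conv3d_params(channels, k, bias=True):
    # groups == channels: one k*k*k filter per channel, plus an optional bias per channel
    return channels * (k ** 3 + (1 if bias else 0))

print(depthwise_conv3d_params(48, 13))  # 13^3 = 2197 weights per channel
print(depthwise_conv3d_params(48, 7))   #  7^3 =  343 weights per channel
print((13 - 1) // 2, (7 - 1) // 2)      # matching 'same' paddings: 6 and 3
```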