zhengchen1999 / DAT

PyTorch code for our ICCV 2023 paper "Dual Aggregation Transformer for Image Super-Resolution"
Apache License 2.0
350 stars 27 forks

Question about the parameter count of DAT_light #9

Open Amazinghaodeihao666 opened 10 months ago

Amazinghaodeihao666 commented 10 months ago

Dear author: Your work is very inspiring to me, and it is an excellent paper. I have one question: using thop on the x4 model, the parameter count I get for DAT_light is 878K, while you report 573K; for DAT, DAT-2, and DAT-S, however, my computed counts all match yours. I hope to hear back from you, thank you!

zhengchen1999 commented 10 months ago

Hi. Thanks for your interest in our work.

I get the Params by directly counting all the parameters in the model using the following code:

def print_model_parm_nums(model):
    total = sum([param.nelement() for param in model.parameters()])
    print('Number of params: %.2fK' % (total / 1e3))

I haven't used thop, but I guess the difference may be due to parameter sharing in the model (I haven't analyzed it specifically).
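
For reference, a quick way to test the parameter-sharing guess: in recent PyTorch, model.parameters() yields each shared tensor only once, while named_parameters(remove_duplicate=False) yields every reference, so the two sums diverge only if weights are actually shared. A minimal sketch (not from the DAT repo):

def check_param_sharing(model):
    # parameters() deduplicates shared tensors (remove_duplicate=True by default)
    deduped = sum(p.nelement() for p in model.parameters())
    # remove_duplicate=False yields every reference, so a shared weight counts once per site
    raw = sum(p.nelement() for _, p in model.named_parameters(remove_duplicate=False))
    print(f'unique params: {deduped / 1e3:.2f}K, counted with duplicates: {raw / 1e3:.2f}K')

If the two numbers match, parameter sharing cannot explain a thop discrepancy.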

Amazinghaodeihao666 commented 10 months ago

Hi: Thank you for replying so quickly. I have seen your parameter-counting code; I used thop mainly to additionally obtain the model's FLOPs, since I couldn't find any statements in your code that output FLOPs. Could you tell me how to output the FLOPs of your model? Thanks again for your reply!


zhengchen1999 commented 10 months ago

I calculate the FLOPs with fvcore. The code is:

from fvcore.nn import FlopCountAnalysis
flops = FlopCountAnalysis(net, input)
print("FLOPs: ", flops.total())
Amazinghaodeihao666 commented 10 months ago

Hi, dear author. I used thop again to compute the parameters of the three larger models, DAT-2, DAT-S, and DAT, and they match the counts you report exactly; only DAT-light differs. This confuses me: if there were parameter sharing in the model, the larger models should in theory also show some discrepancy, yet the difference appears only in the lightweight model. I don't know whether this issue interests you, but if you have time, could you take a look at how the parameter count is computed? Thank you! (I also tried to study your parameter-counting code myself, but my ability is limited, so I'm hoping you can resolve my confusion.)


zhengchen1999 commented 10 months ago

Parameter sharing was just a guess; in fact, I don't intentionally share any parameters in DAT. I tried thop to calculate the parameters of DAT-light-x4. The code is as follows:

import thop

x = torch.randn(1, 3, 64, 64, device='cuda')
flops, params = thop.profile(net, inputs=(x,))
print(params)

The result:

[INFO] Register count_convNd() for <class 'torch.nn.modules.conv.Conv2d'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.normalization.LayerNorm'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.container.Sequential'>.
[INFO] Register count_linear() for <class 'torch.nn.modules.linear.Linear'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.dropout.Dropout'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.activation.ReLU'>.
[INFO] Register count_normalization() for <class 'torch.nn.modules.batchnorm.BatchNorm2d'>.
[INFO] Register count_adap_avgpool() for <class 'torch.nn.modules.pooling.AdaptiveAvgPool2d'>.
[INFO] Register count_relu() for <class 'torch.nn.modules.activation.LeakyReLU'>.
[INFO] Register zero_ops() for <class 'torch.nn.modules.pixelshuffle.PixelShuffle'>.
572712.0

This is consistent with the results in the paper. Could there be an error in the model settings you used when calculating the parameters of DAT-light?

zhengchen1999 commented 10 months ago

The script (FLOPs/Params) I use is:

import yaml
import torch
from basicsr.utils.registry import ARCH_REGISTRY
from fvcore.nn import FlopCountAnalysis

def print_model_parm_nums(model):
    total = sum([param.nelement() for param in model.parameters()])
    print('Number of params: %.2fK' % (total / 1e3))

opt_str = r"""
  type: DAT
  upscale: 4
  in_chans: 3
  img_size: 64
  img_range: 1.
  depth: [18]
  embed_dim: 60
  num_heads: [6]
  expansion_factor: 2
  resi_connection: '3conv'
  split_size: [8,32]
  upsampler: 'pixelshuffledirect'
"""

opt = yaml.safe_load(opt_str)

network_type = opt.pop('type')
net = ARCH_REGISTRY.get(network_type)(**opt)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
net.to(device).eval()
input = torch.rand(1, 3, 128, 128, device=device)  # keep the input on the same device as the model

flops = FlopCountAnalysis(net, input)
print("FLOPs: ", flops.total())

print_model_parm_nums(net)

Amazinghaodeihao666 commented 10 months ago

Hi, wow! I'm truly honored to receive such a detailed reply, thank you so much!!! Below are the parameters set according to your configuration file:

(screenshot not displayed)

Result:

(screenshot not displayed)

At this point I don't know where my problem lies; I hope you can point me in the right direction. Thanks!


zhengchen1999 commented 10 months ago

Sorry, I don't understand what you mean. Could you explain your problem again?

Amazinghaodeihao666 commented 10 months ago

Hi, sorry, my previous email contained screenshots, which may not have displayed correctly. Here is my problem:

model = DAT(
    upscale=4,
    in_chans=3,
    img_size=64,
    img_range=1.,
    depth=[18],
    embed_dim=60,
    num_heads=[6],
    expansion_factor=2,
    resi_connection='3conv',
    split_size=[8, 32],
).eval()

x = torch.randn(1, 3, 64, 64)

flops, params = profile(model, inputs=(x,))
flops, params = clever_format([flops, params], "%.3f")
print('flops:{}', flops)
print('params:{}', params)

These are the model parameters I set, kept consistent with what you provided. The result is:

flops:{} 5.445G
params:{} 878.523K

This differs considerably from the parameter count you measured. If possible, could you send me the dat_arch you used for testing? Thank you!


zhengchen1999 commented 10 months ago

You didn't set upsampler: 'pixelshuffledirect' in DAT. The upsampler in DAT-light ('pixelshuffledirect') is different from the one in DAT/DAT-S (the default, 'pixelshuffle').
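
For context, the two upsampler heads differ substantially in size, and the gap accounts for the whole discrepancy. A back-of-the-envelope sketch, assuming the SwinIR-style upsampler designs common in BasicSR-based SR models (num_feat=64 and the module structure are assumptions here, not taken from the DAT code):

embed_dim, scale, out_ch, num_feat = 60, 4, 3, 64

def conv3x3_params(c_in, c_out):
    # 3x3 convolution: weights plus bias
    return c_in * c_out * 9 + c_out

# 'pixelshuffledirect': a single conv to out_ch * scale^2 channels, then PixelShuffle
direct = conv3x3_params(embed_dim, out_ch * scale ** 2)    # 25,968

# 'pixelshuffle': conv_before_upsample + two x2 upsampling convs + conv_last
classic = (conv3x3_params(embed_dim, num_feat)             # 34,624
           + 2 * conv3x3_params(num_feat, num_feat * 4)    # 295,424
           + conv3x3_params(num_feat, out_ch))             # 1,731

print(direct, classic, classic - direct)                   # 25968 331779 305811

The difference, 305,811 parameters, exactly matches 878,523 - 572,712, consistent with the missing upsampler setting explaining the entire mismatch.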

Amazinghaodeihao666 commented 10 months ago

Hi, I got it. Thanks very very very much!!! Good night.
