iamhankai / ghostnet.pytorch

[CVPR2020] GhostNet: More Features from Cheap Operations
https://arxiv.org/abs/1911.11907
Hello! How do I move the GhostModule into my net to replace the normal Conv2d layer? #15

Open tianyuluan opened 4 years ago

iamhankai commented 4 years ago

Replace nn.Conv2d with GhostModule https://github.com/iamhankai/ghostnet.pytorch/blob/2c90e67d8c33c44ec1bad12c9686f645b0d4de08/ghost_net.py#L55
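
For a rough sense of what this swap changes, the sketch below counts convolution weights for a plain conv and for a Ghost-style module. It assumes the module at the linked line produces ceil(oup / ratio) channels with an ordinary convolution and the rest with a depthwise convolution of size dw_size; the helper names conv_params and ghost_params are hypothetical, and BatchNorm parameters are ignored.

```python
import math


def conv_params(in_ch, out_ch, kernel_size):
    """Weight count of a plain k x k convolution without bias."""
    return in_ch * out_ch * kernel_size * kernel_size


def ghost_params(in_ch, out_ch, kernel_size=1, ratio=2, dw_size=3):
    """Conv weight count of a Ghost-style module (assumed layout, no BN)."""
    init_ch = math.ceil(out_ch / ratio)   # channels from the primary conv
    new_ch = init_ch * (ratio - 1)        # channels from the cheap depthwise op
    primary = in_ch * init_ch * kernel_size * kernel_size
    cheap = new_ch * dw_size * dw_size    # depthwise: one input channel per filter
    return primary + cheap


plain = conv_params(64, 128, 3)
ghost = ghost_params(64, 128, 3)
print(plain, ghost, round(plain / ghost, 2))  # 73728 37440 1.97
```

With the default ratio of 2, the weight count is roughly halved for this layer shape.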

LexTran commented 4 years ago

@iamhankai Hi, thanks for the great work. I am trying to do the same thing as tianyuluan. You said to replace nn.Conv2d with GhostModule, but how should the parameters be set? Do I need to change any of them? If my code looks like this:

def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True):
    modules = [nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)

Do I just need to change it to this?

def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True):
    modules = [GhostModule(in_channels, out_channels, kernel_size, stride, padding, dilation, bias=bias)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)

Is that OK? I hope you can help. Thanks a lot!

Tangzhaotz commented 4 years ago

I have the same question. Did you solve the problem?

iamhankai commented 4 years ago

@LexTran Just change to:

def conv(in_channels, out_channels, kernel_size=3, padding=1, bn=True, dilation=1, stride=1, relu=True, bias=True):
    modules = [GhostModule(in_channels, out_channels, kernel_size, stride=stride)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)
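
To check that the swapped helper keeps the expected output shape, here is a self-contained sketch. The GhostModule class in it is a stand-in written for this example, not the repository's code: it assumes the signature (inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True), padding of kernel_size // 2, and no bias. Because that assumed signature has no padding, dilation, or bias arguments, the helper here drops them.

```python
import math

import torch
import torch.nn as nn


class GhostModule(nn.Module):
    """Stand-in for this example; layout and signature are assumptions."""

    def __init__(self, inp, oup, kernel_size=1, ratio=2, dw_size=3, stride=1, relu=True):
        super().__init__()
        self.oup = oup
        init_channels = math.ceil(oup / ratio)
        new_channels = init_channels * (ratio - 1)
        # Ordinary convolution producing the "intrinsic" feature maps.
        self.primary_conv = nn.Sequential(
            nn.Conv2d(inp, init_channels, kernel_size, stride, kernel_size // 2, bias=False),
            nn.BatchNorm2d(init_channels),
            nn.ReLU(inplace=True) if relu else nn.Identity(),
        )
        # Depthwise convolution producing the cheap "ghost" feature maps.
        self.cheap_operation = nn.Sequential(
            nn.Conv2d(init_channels, new_channels, dw_size, 1, dw_size // 2,
                      groups=init_channels, bias=False),
            nn.BatchNorm2d(new_channels),
            nn.ReLU(inplace=True) if relu else nn.Identity(),
        )

    def forward(self, x):
        x1 = self.primary_conv(x)
        x2 = self.cheap_operation(x1)
        # Concatenate and trim to exactly oup channels.
        return torch.cat([x1, x2], dim=1)[:, :self.oup]


def conv(in_channels, out_channels, kernel_size=3, stride=1, bn=True, relu=True):
    # Only the arguments the assumed GhostModule accepts are forwarded.
    modules = [GhostModule(in_channels, out_channels, kernel_size, stride=stride)]
    if bn:
        modules.append(nn.BatchNorm2d(out_channels))
    if relu:
        modules.append(nn.ReLU(inplace=True))
    return nn.Sequential(*modules)


layer = conv(16, 32, kernel_size=3, stride=2)
x = torch.randn(2, 16, 32, 32)
y = layer(x)
print(tuple(y.shape))  # (2, 32, 16, 16)
```

If the real module already applies BatchNorm and ReLU internally, as this stand-in does, calling the helper with bn=False and relu=False avoids applying them a second time.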