megvii-model / SinglePathOneShot

Channel search #1

Closed pawopawo closed 3 years ago

pawopawo commented 4 years ago

Thanks for your excellent work!

I can only find the block search in the code; there does not seem to be a channel search.

ZichaoGuo commented 4 years ago

> Thanks for your excellent work!
>
> I can only find the block search in the code; there does not seem to be a channel search.

Thanks for your attention. We will release the channel search later.

apxlwl commented 4 years ago

@githubgzc Hi, are there any updates on the channel search?

MARMOTatZJU commented 4 years ago

@wlguan @pawopawo If you just want to implement channel search in SPOS, the following code from Slimmable Networks may help: https://github.com/JiahuiYu/slimmable_networks/blob/master/models/slimmable_ops.py. After reading these papers, I found that the way channels are searched in SPOS is similar to the approach used in Slimmable Networks.

MARMOTatZJU commented 4 years ago

@wlguan @pawopawo A representative snippet from Slimmable Networks (the forward pass of a slimmable convolution layer) is as follows:

    def forward(self, input):
        # Look up the active width multiplier and its channel counts.
        idx = FLAGS.width_mult_list.index(self.width_mult)
        self.in_channels = self.in_channels_list[idx]
        self.out_channels = self.out_channels_list[idx]
        self.groups = self.groups_list[idx]
        # Reuse a prefix slice of the full-width weight (and bias) for the narrower setting.
        weight = self.weight[:self.out_channels, :self.in_channels, :, :]
        if self.bias is not None:
            bias = self.bias[:self.out_channels]
        else:
            bias = self.bias
        y = nn.functional.conv2d(
            input, weight, bias, self.stride, self.padding,
            self.dilation, self.groups)
        return y
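To apply the same idea in an SPOS-style supernet, one option is to keep a single full-width weight per layer and, at every training step, sample a channel scale and slice out the first k filters, just as the Slimmable snippet above does for a fixed width multiplier. Below is a minimal sketch under that assumption; the `ChannelSearchConv` class, the `channel_scales` list, and the per-step random sampling are illustrative and not code from this repository:

    import random
    import torch
    import torch.nn as nn
    import torch.nn.functional as F

    class ChannelSearchConv(nn.Module):
        # Hypothetical sketch: a conv whose output width is sampled from channel_scales.
        def __init__(self, in_channels, max_out_channels, kernel_size, stride=1,
                     channel_scales=(0.2, 0.4, 0.6, 0.8, 1.0)):
            super().__init__()
            self.channel_scales = channel_scales
            # Allocate weights at the maximum width; narrower choices reuse a prefix slice.
            self.weight = nn.Parameter(
                0.01 * torch.randn(max_out_channels, in_channels, kernel_size, kernel_size))
            self.stride = stride
            self.padding = kernel_size // 2

        def forward(self, x, choice=None):
            # During supernet training, sample a width uniformly per forward pass;
            # during search/evaluation, pass a fixed choice instead.
            if choice is None:
                choice = random.randrange(len(self.channel_scales))
            out_channels = max(1, int(round(self.weight.size(0) * self.channel_scales[choice])))
            weight = self.weight[:out_channels]  # slice the first k filters
            return F.conv2d(x, weight, None, self.stride, self.padding)

Note that the following layer (and any BatchNorm) would also need to slice its input channels and statistics to match the sampled width, exactly as in the Slimmable code above.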

zhangyuan1994511 commented 4 years ago

@githubgzc Thanks for your excellent work! I'm very interested in how to search channels. When will the channel search be released? Looking forward to your reply, thank you!

pprp commented 3 years ago

I implemented channel search with a slimmable network, but the model is difficult to converge on CIFAR-100. The model I used is ResNet20. I also trained the same model on MNIST and it took 300 epochs for the sub-networks to converge. I tried batch norm calibration, but the training loss keeps dropping while the validation accuracy also drops 😢
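For context, batch norm calibration in this setting usually means re-estimating the BN running statistics of each sampled sub-network on a few training batches (with gradients disabled) before measuring its validation accuracy, since the supernet's accumulated statistics do not match any single sub-network. A minimal sketch, assuming the supernet's forward accepts the sampled architecture as a `choice` argument; the `recalibrate_bn` helper and that argument are assumptions, not code from this repository:

    import torch

    def recalibrate_bn(model, loader, choice, num_batches=20, device='cuda'):
        # Reset running mean/var so stale supernet statistics are not used at evaluation.
        for m in model.modules():
            if isinstance(m, torch.nn.modules.batchnorm._BatchNorm):
                m.reset_running_stats()
                m.momentum = None  # None => cumulative moving average over calibration batches
        model.train()
        with torch.no_grad():
            for i, (images, _) in enumerate(loader):
                if i >= num_batches:
                    break
                model(images.to(device), choice)  # forward with the fixed sub-network
        model.eval()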