dragen1860 / DARTS-PyTorch

PyTorch 1.0 supported for CNN exp.
82 stars, 15 forks

alpha should not be optimized in updating weight #5

Open yangsenius opened 5 years ago

yangsenius commented 5 years ago

https://github.com/dragen1860/DARTS-PyTorch/blob/cfcdd02cea876ce85940aa01ee58664405390fa7/model_search.py#L217

Wrapping alpha and beta in nn.Parameter() registers them in model.parameters(), so your optimizer will also update alpha and beta when it optimizes the weights of the operations. So I think nn.Parameter() should not be used here; otherwise the code is not consistent with the paper or the original code.
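A minimal sketch of the problem described above, using a hypothetical `TinySearchNet` (not the repository's model) with one conv layer and one `alpha_normal` tensor of arbitrary shape. Anything assigned as an `nn.Parameter` attribute shows up in `named_parameters()`, so an optimizer built from `model.parameters()` also steps the alphas. The loss is a toy chosen only so that it depends on both the conv weights and the alphas.

```python
import torch
import torch.nn as nn


class TinySearchNet(nn.Module):
    """Hypothetical toy model; not the repository's Network class."""

    def __init__(self, k=4, num_ops=3):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # Assigning an nn.Parameter as an attribute registers it automatically.
        self.alpha_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops))

    def forward(self, x):
        # Toy loss that depends on both the conv weights and alpha_normal.
        return self.conv(x).mean() + self.alpha_normal.sum()


model = TinySearchNet()
param_names = [name for name, _ in model.named_parameters()]
print(param_names)  # 'alpha_normal' appears next to 'conv.weight' and 'conv.bias'

# An optimizer over model.parameters() therefore updates alpha_normal as well.
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
alpha_before = model.alpha_normal.detach().clone()
model(torch.randn(2, 3, 8, 8)).backward()
optimizer.step()
alpha_changed = not torch.equal(alpha_before, model.alpha_normal.detach())
print(alpha_changed)
```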

dragen1860 commented 5 years ago

@yangsenius Thanks for the reminder! :+1: Have you tried:

self.alpha_normal = torch.randn(k, num_ops)
self.alpha_reduce = torch.randn(k, num_ops)

What is the performance when you update the code with the statements above? Please let me know if you re-run the experiment.

yangsenius commented 5 years ago
self.alpha_normal = torch.randn(k, num_ops)
self.alpha_reduce = torch.randn(k, num_ops)

This will make self.alpha_normal and self.alpha_reduce always be CPU tensors (torch.FloatTensor), which sometimes causes errors after model.cuda(), because plain tensor attributes are not moved along with the module. That is a little troublesome. Maybe

self.alpha_normal = torch.randn(k, num_ops, device=self.your_conv.weight.device)
self.alpha_reduce = torch.randn(k, num_ops, device=self.your_conv.weight.device)

is OK?

Or just:

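import torch
import torch.nn as nn

# inside the model's __init__:
self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

def filter(model):
    # skip the architecture parameters, keep the operation weights
    for name, param in model.named_parameters():
        if 'alpha' in name:
            continue
        yield param

optimizer = torch.optim.Adam(filter(model))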

What do you think? Is there a better way to implement this?
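A quick way to see the device/dtype problem raised above: `Module.cuda()`, `.to()` and `.double()` only convert registered parameters and buffers, so a plain tensor stored as an attribute is left behind. To keep the sketch runnable without a GPU it uses `.double()` as a stand-in for `.cuda()` (both go through the same module conversion path); `PlainAlphaNet` is a hypothetical name.

```python
import torch
import torch.nn as nn


class PlainAlphaNet(nn.Module):
    """Hypothetical model holding one registered Parameter and one plain tensor."""

    def __init__(self, k=4, num_ops=3):
        super().__init__()
        self.alpha_param = nn.Parameter(torch.randn(k, num_ops))
        # A plain tensor attribute is NOT registered with the module.
        self.alpha_plain = torch.randn(k, num_ops)


model = PlainAlphaNet().double()
print(model.alpha_param.dtype)  # converted together with the module
print(model.alpha_plain.dtype)  # left untouched
```

The same thing happens with `.cuda()`: the registered parameter moves to the GPU while the plain tensor stays on the CPU.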

dragen1860 commented 5 years ago

Since we usually set the device to 'cuda:0',

self.alpha_reduce = torch.randn(k, num_ops, device=torch.device("cuda"))

would be an OK option. Let's see whether any problems come up.

@yangsenius
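If the alphas are kept as plain tensors, they also need `requires_grad=True` (otherwise autograd computes no gradient for them), and they must be handed to their own optimizer, since `model.parameters()` will not return them. A minimal sketch, assuming a hypothetical `SearchNet` with a hypothetical `arch_parameters()` helper, falling back to the CPU when CUDA is unavailable; the optimizer settings are placeholders, not values taken from this thread.

```python
import torch
import torch.nn as nn

# Fall back to the CPU so the sketch also runs on machines without a GPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")


class SearchNet(nn.Module):
    """Hypothetical search model whose alphas are plain tensors, not nn.Parameters."""

    def __init__(self, k=4, num_ops=3):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        # Leaf tensors created directly on the target device; requires_grad_()
        # is needed, otherwise no gradient is ever computed for them.
        self.alpha_normal = (1e-3 * torch.randn(k, num_ops, device=device)).requires_grad_()
        self.alpha_reduce = (1e-3 * torch.randn(k, num_ops, device=device)).requires_grad_()

    def arch_parameters(self):
        # Hypothetical helper: plain tensors are invisible to model.parameters().
        return [self.alpha_normal, self.alpha_reduce]


model = SearchNet().to(device)
weight_params = list(model.parameters())  # conv weight and bias only
arch_params = model.arch_parameters()

# One optimizer per parameter set; hyperparameters are placeholders.
w_optimizer = torch.optim.SGD(weight_params, lr=0.025, momentum=0.9)
a_optimizer = torch.optim.Adam(arch_params, lr=3e-4)
```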

zh583007354 commented 5 years ago

Hi, I also noticed this problem yesterday. I think splitting the parameters into two groups may be a good choice. When training a ConvNet (e.g. MobileNet), we usually apply a weight decay of 5e-4 to the conv weights but no decay to the BN parameters, so we define optimizer = SGD([{param group 1: conv, with decay}, {param group 2: BN, without decay}]).

I think we can separate alphas and weights in this way.

@dragen1860 @yangsenius
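A minimal sketch of the two-group setup described above, assuming a toy conv + BN model: parameters owned by `nn.BatchNorm2d` layers go into a group with `weight_decay=0.0`, everything else gets 5e-4. The learning rate and momentum are placeholders. Note that groups inside one optimizer are still stepped together by `optimizer.step()`; grouping changes per-group hyperparameters, not which parameters get updated.

```python
import torch
import torch.nn as nn

model = nn.Sequential(
    nn.Conv2d(3, 16, kernel_size=3, padding=1, bias=False),
    nn.BatchNorm2d(16),
    nn.ReLU(),
)

bn_params, other_params = [], []
for module in model.modules():
    # recurse=False: each parameter is collected exactly once, from the module that owns it.
    group = bn_params if isinstance(module, nn.BatchNorm2d) else other_params
    group.extend(module.parameters(recurse=False))

optimizer = torch.optim.SGD(
    [
        {"params": other_params, "weight_decay": 5e-4},  # conv weights: decayed
        {"params": bn_params, "weight_decay": 0.0},      # BN affine params: no decay
    ],
    lr=0.1,
    momentum=0.9,
)
```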

yangsenius commented 5 years ago

Yeah, you got it. @zh583007354

self.alpha_normal = nn.Parameter(torch.randn(k, num_ops))

def filter(model):
    for name, param in model.named_parameters():
        if 'alpha' in name:
            continue
        yield param

# two param groups; per-group options such as lr can differ
optimizer = torch.optim.Adam([{'params': filter(model)}, {'params': [model.alpha_normal]}])

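A self-contained check of the name-filtering idea, assuming a hypothetical toy model whose architecture weights have "alpha" in their names. This sketch goes one step further than a single Adam with two groups and gives each set its own optimizer, so that stepping the weight optimizer cannot move the alphas; the learning rates are placeholders.

```python
import torch
import torch.nn as nn


class ToySearchNet(nn.Module):
    """Hypothetical toy model with two registered alpha Parameters."""

    def __init__(self, k=4, num_ops=3):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.alpha_normal = nn.Parameter(1e-3 * torch.randn(k, num_ops))
        self.alpha_reduce = nn.Parameter(1e-3 * torch.randn(k, num_ops))

    def forward(self, x):
        # Toy loss that depends on the conv weights and on both alphas.
        return self.conv(x).mean() + self.alpha_normal.sum() + self.alpha_reduce.sum()


def weight_parameters(model):
    """Yield every parameter whose name does not contain 'alpha'."""
    for name, param in model.named_parameters():
        if "alpha" not in name:
            yield param


def alpha_parameters(model):
    """Yield every parameter whose name contains 'alpha'."""
    for name, param in model.named_parameters():
        if "alpha" in name:
            yield param


model = ToySearchNet()
w_optimizer = torch.optim.SGD(weight_parameters(model), lr=0.025, momentum=0.9)
a_optimizer = torch.optim.Adam(alpha_parameters(model), lr=3e-4)

alpha_before = model.alpha_normal.detach().clone()
model(torch.randn(2, 3, 8, 8)).backward()
w_optimizer.step()  # updates the conv weight and bias only
alpha_unchanged = torch.equal(alpha_before, model.alpha_normal.detach())
```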
skx6 commented 5 years ago

One epoch of the architecture search takes an hour. However, the paper says "a small network of 8 cells is trained using DARTS for 50 epochs. The search takes one day on a single GPU". If I train for 50 epochs, it will take more than two days.
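The estimate in numbers, taking the reported one hour per epoch at face value:

```latex
50~\text{epochs} \times 1~\tfrac{\text{h}}{\text{epoch}} = 50~\text{h} \approx 2.1~\text{days},
\qquad \text{versus about } 24~\text{h reported in the paper.}
```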

zh583007354 commented 5 years ago

@yangsenius Hi, I have another question.

I want to know whether the clip_grad_norm_() call in train_search.py is necessary:

loss.backward()
nn.utils.clip_grad_norm_(model.parameters(), args.grad_clip)
optimizer.step()

If clipping is necessary, should it be applied only to the weight parameters or to all parameters?

Thank you.

yangsenius commented 5 years ago

I think there is no way to know in advance whether this clip_grad_norm_() is necessary, because we don't know the range of the parameter gradients. It guards against exploding gradients (just in case), even though they may never occur, so this line may turn out to be useless. @zh583007354
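One way to check the clipping question empirically, sketched under the assumption that the weights and alphas are already in separate lists (the toy model, the stand-in `alpha` tensor and the `grad_clip` value are illustrative): clip only the weight gradients, and log the total norm that `clip_grad_norm_` returns, which is computed before clipping, to see whether the threshold is ever reached in practice.

```python
import torch
import torch.nn as nn

grad_clip = 5.0  # illustrative threshold

model = nn.Sequential(nn.Linear(10, 10), nn.ReLU(), nn.Linear(10, 1))
alpha = nn.Parameter(torch.zeros(3))  # stand-in for the architecture weights
weight_params = list(model.parameters())

# Scaled-up toy loss so that the weight gradients are likely to be large.
loss = model(torch.randn(4, 10)).pow(2).mean() * 1000.0 + alpha.sum()
loss.backward()

# Clip the weight gradients only; alpha.grad is left as it is.
total_norm = float(nn.utils.clip_grad_norm_(weight_params, grad_clip))
clipped_norm = torch.norm(torch.stack([p.grad.norm() for p in weight_params])).item()
print(f"norm before clipping: {total_norm:.3f}, after: {clipped_norm:.3f}")
```

If the logged pre-clipping norm stays well below `grad_clip` over a whole search, the clipping line has no effect on that run.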