yuhuixu1993 / PC-DARTS

PC-DARTS: Partial Channel Connections for Memory-Efficient Differentiable Architecture Search

How was PC_DARTS_cifar obtained? #26

Closed bolianchen closed 4 years ago

bolianchen commented 4 years ago

Hi @yuhuixu1993,

I would appreciate it if you could answer the following questions:

  1. Was PC_DARTS_cifar searched on CIFAR-10 or CIFAR-100?
  2. Is PC_DARTS_cifar the genotype generated in the last (50th) epoch? Thanks a lot!!

Best, Bolian

yuhuixu1993 commented 4 years ago

Hi @bolianchen:

  1. PC_DARTS_cifar was searched on CIFAR-10.
  2. Yes, it is the genotype from the last epoch.

whwu95 commented 4 years ago

Hi @yuhuixu1993, I want to reproduce your model (3.6M params), so I just ran python train_search.py to search on CIFAR-10. But in the last epoch, I got a model with 4.5M params. May I have your original training settings?

yuhuixu1993 commented 4 years ago

@whwu95, NAS does not find the same architecture in every search run; most runs turn out different. However, 4.5M is still strange. Would you please share the genotypes?

bolianchen commented 4 years ago

Hi @yuhuixu1993

Thanks very much for your immediate reply!!

whwu95 commented 4 years ago

> @whwu95, NAS does not find the same architecture in every search run; most runs turn out different. However, 4.5M is still strange. Would you please share the genotypes?

Genotype(normal=[('sep_conv_5x5', 1), ('sep_conv_5x5', 0), ('sep_conv_5x5', 2), ('sep_conv_5x5', 1), ('sep_conv_5x5', 3), ('sep_conv_5x5', 1), ('max_pool_3x3', 4), ('sep_conv_5x5', 2)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 1), ('sep_conv_5x5', 0), ('max_pool_3x3', 2), ('max_pool_3x3', 1), ('sep_conv_5x5', 3), ('sep_conv_3x3', 2), ('sep_conv_3x3', 3), ('dil_conv_3x3', 2)], reduce_concat=range(2, 6))
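
A quick way to check such a parameter count is to instantiate the evaluation network from the genotype. Below is only a rough sketch, assuming the DARTS-style model.py, utils.py, and genotypes.py that ship with this repository and the usual train.py settings (36 initial channels, 20 cells, auxiliary head):

    import utils
    from model import NetworkCIFAR
    import genotypes

    # Build the CIFAR evaluation network from a genotype and report its size the
    # way train.py does; substitute the genotype pasted above to compare it with
    # the released genotypes.PC_DARTS_cifar.
    net = NetworkCIFAR(36, 10, 20, True, genotypes.PC_DARTS_cifar)
    print("param size = %.2fMB" % utils.count_parameters_in_MB(net))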

yuhuixu1993 commented 4 years ago

@whwu95, the default settings are my training settings. Did you change any hyperparameters? Did you try more search runs? Does it happen every time?

whwu95 commented 4 years ago

@yuhuixu1993 Yes, I haven't changed anything and it happens all the time.

whwu95 commented 4 years ago

@yuhuixu1993 I noticed that you changed '--learning_rate_min' from 0.001 to 0, so I tried again and got a model with 4.81M params. In fact, I first tried to run python train_search.py on a 1080Ti and a 2080Ti and ran out of memory (default batch size 256). So I tried again on a TITAN X with batch size 256 (using 11.5GB of memory), and got the results above.

yuhuixu1993 commented 4 years ago

@whwu95, I ran the released code again on a 1080Ti without any OOM problem. What is your running environment? On a 1080Ti, it is better to use PyTorch 0.3. I will check the results again after the run finishes.

whwu95 commented 4 years ago

@yuhuixu1993 Thank you for your reply! My environment is PyTorch 1.0. Anyway, I am trying again on a TITAN X to avoid the OOM problem. Thank you again; I am waiting for your results.

yuhuixu1993 commented 4 years ago

@whwu95, hi, it works well in this environment; I recommend using PyTorch 0.3 and a 1080Ti. Besides, the epoch at which training of the arch_params starts is a hyperparameter that can somewhat control the parameter count of the searched architectures.

whwu95 commented 4 years ago

@yuhuixu1993 Hi, I switched my environment to PyTorch 0.3 & Python 2.7 and ran python train_search.py again. However, I still got a model with 4.6M params... It's really strange. Attached is my log file. Could you show me your log file? Thanks a lot. log.txt

yuhuixu1993 commented 4 years ago

@whwu95, from your log file I notice that your architecture began to change in the first epoch. Please note that we only start updating the architecture parameters from the 15th epoch: https://github.com/yuhuixu1993/PC-DARTS/blob/a2fa00ace53376ba31b59fe5de9028aa9d8be2f1/train_search.py#L156. I have uploaded the log from my run yesterday; you can see that in the first few epochs the genotypes did not change, since the arch_params were not updated. If you solve your problem, please tell me. log.txt
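
For reference, the gating around the 15th epoch in the released train_search.py looks roughly like this (a paraphrased sketch of the linked line, not a verbatim copy):

    # Inside the training loop: the architecture parameters stay frozen during the
    # first 15 warm-up epochs, and the Architect only starts updating them afterwards.
    if epoch >= 15:
        architect.step(input, target, input_search, target_search,
                       lr, optimizer, unrolled=args.unrolled)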

whwu95 commented 4 years ago

@yuhuixu1993 Hi, I have solved the problem. Yesterday I tried to port the code to PyTorch 1.0, so I replaced this code in model_search.py

    self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
    self.betas_normal = Variable(1e-3*torch.randn(k).cuda(), requires_grad=True)
    self.betas_reduce = Variable(1e-3*torch.randn(k).cuda(), requires_grad=True)

with the code

    self.alphas_normal = nn.Parameter(1e-3*torch.randn(k, num_ops))
    self.alphas_reduce = nn.Parameter(1e-3*torch.randn(k, num_ops))
    self.betas_normal = nn.Parameter(1e-3*torch.randn(k))
    self.betas_reduce = nn.Parameter(1e-3*torch.randn(k))

Because parameters defined with nn.Parameter are registered in model.parameters(), this change causes the weight optimizer to update the arch_params from the first epoch instead of the 15th.
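
If one prefers to keep the PyTorch 1.0 nn.Parameter version, a possible workaround (only a sketch, assuming the model.arch_parameters() helper and the optimizer settings of train_search.py) is to filter the architecture parameters out of the weight optimizer, so that they are again updated only by the Architect from the 15th epoch on:

    import torch

    # Sketch: exclude the architecture parameters from the weight optimizer so the
    # nn.Parameter version behaves like the original Variable version.
    arch_param_ids = {id(p) for p in model.arch_parameters()}
    weight_params = [p for p in model.parameters() if id(p) not in arch_param_ids]

    optimizer = torch.optim.SGD(
        weight_params,                 # network weights only, no alphas/betas
        args.learning_rate,
        momentum=args.momentum,
        weight_decay=args.weight_decay)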

OValery16 commented 4 years ago

I ran into the same issue for multi-GPU (Variable vs. nn.Parameter). The solution is the following:

  1. Don't call _initialize_alphas() inside your model's __init__() (so during initialization, your model doesn't have self.alphas_normal, etc. in model.parameters()).
  2. Create the optimizer for the main model in train_search.
  3. Use model._initialize_alphas() to declare self.alphas_normal, etc. (they are declared but not yet used by any optimizer).
  4. Declare Architect() and the optimizer specific to alphas_normal, alphas_reduce, etc. via self.model.arch_parameters(). Don't forget to set requires_grad = True on alphas_normal, etc. with the snippet below (a sketch putting these steps together follows the snippet):

        for p in self.model.arch_parameters():
            p.requires_grad = True
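
A minimal sketch of how these four steps might be laid out in train_search.py, assuming _initialize_alphas() has been removed from Network.__init__ (step 1) and reusing the Network, Architect, CIFAR_CLASSES, and args names from the released code:

    import torch
    import torch.nn as nn
    from model_search import Network
    from architect import Architect

    # Hypothetical rearrangement of train_search.py following the steps above.
    criterion = nn.CrossEntropyLoss().cuda()
    model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion).cuda()

    # 2) create the weight optimizer before the alphas/betas exist, so it only
    #    ever contains the network weights
    optimizer = torch.optim.SGD(model.parameters(), args.learning_rate,
                                momentum=args.momentum,
                                weight_decay=args.weight_decay)

    # 3) declare alphas_normal, alphas_reduce, betas_normal, betas_reduce now;
    #    the weight optimizer above does not include them
    model._initialize_alphas()

    # 4) make sure the architecture parameters require gradients, then let the
    #    Architect build its own optimizer over model.arch_parameters()
    for p in model.arch_parameters():
        p.requires_grad = True
    architect = Architect(model, args)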

@yuhuixu1993 if you don't know how to do it, feel free to contact me. Thanks a lot for this project.

When will your paper be published?

heartInsert commented 4 years ago

> @yuhuixu1993 Hi, I have solved the problem. Yesterday I tried to port the code to PyTorch 1.0, so I replaced this code in model_search.py
>
>     self.alphas_normal = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
>     self.alphas_reduce = Variable(1e-3*torch.randn(k, num_ops).cuda(), requires_grad=True)
>     self.betas_normal = Variable(1e-3*torch.randn(k).cuda(), requires_grad=True)
>     self.betas_reduce = Variable(1e-3*torch.randn(k).cuda(), requires_grad=True)
>
> with the code
>
>     self.alphas_normal = nn.Parameter(1e-3*torch.randn(k, num_ops))
>     self.alphas_reduce = nn.Parameter(1e-3*torch.randn(k, num_ops))
>     self.betas_normal = nn.Parameter(1e-3*torch.randn(k))
>     self.betas_reduce = nn.Parameter(1e-3*torch.randn(k))
>
> Because parameters defined with nn.Parameter are registered in model.parameters(), this change causes the weight optimizer to update the arch_params from the first epoch instead of the 15th.

So did you finally manage to reproduce the results in PyTorch 1.0?

yuhuixu1993 commented 4 years ago

> I ran into the same issue for multi-GPU (Variable vs. nn.Parameter). The solution is the following:
>
>   1. Don't call _initialize_alphas() inside your model's __init__() (so during initialization, your model doesn't have self.alphas_normal, etc. in model.parameters()).
>   2. Create the optimizer for the main model in train_search.
>   3. Use model._initialize_alphas() to declare self.alphas_normal, etc. (they are declared but not yet used by any optimizer).
>   4. Declare Architect() and the optimizer specific to alphas_normal, alphas_reduce, etc. via self.model.arch_parameters(). Don't forget to set requires_grad = True on alphas_normal, etc. with:
>
>         for p in self.model.arch_parameters():
>             p.requires_grad = True
>
> @yuhuixu1993 if you don't know how to do it, feel free to contact me. Thanks a lot for this project.
>
> When will your paper be published?

Hi @OValery16, thanks for your kind contribution; I will try it and see whether it works. Our paper has recently been accepted to ICLR 2020. Thanks!