quark0 / darts

Differentiable architecture search for convolutional and recurrent networks
https://arxiv.org/abs/1806.09055
Apache License 2.0

Training results of `train_search.py`? #7

Closed by bkj 6 years ago

bkj commented 6 years ago

What's the recommended way to train the results of train_search.py? This is the end of my log.txt:

...
2018-06-27 13:25:46,378 epoch 49 lr 1.023679e-03
2018-06-27 13:25:46,379 genotype = Genotype(normal=[('skip_connect', 0), ('sep_conv_3x3', 1), ('sep_conv_3x3', 0), ('sep_conv_3x3', 2), ('sep_conv_3x3', 1), ('skip_connect', 0), ('sep_conv_3x3', 0), ('skip_connect', 1)], normal_concat=range(2, 6), reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1), ('max_pool_3x3', 0), ('skip_connect', 2), ('max_pool_3x3', 0), ('dil_conv_5x5', 3), ('skip_connect', 2), ('skip_connect', 3)], reduce_concat=range(2, 6))
...
2018-06-27 13:54:15,198 train_acc 99.704000
...
2018-06-27 13:54:45,268 valid_acc 88.760000

Should I just copy and paste that Genotype into genotypes.py w/ a new name? Or is there some recommended way?

Thanks

quark0 commented 6 years ago

Yes, just copy & paste it into genotypes.py as you described : )

Then, run `python train.py --arch $NAME_OF_THE_ARCH --auxiliary --cutout` for evaluation.
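For reference, the copy-and-paste step might look like the sketch below. It assumes `genotypes.py` defines `Genotype` as a namedtuple with the four fields that appear in the log line; `MY_SEARCHED_ARCH` is a made-up name, and the definition of `Genotype` is repeated here only so the snippet runs on its own.

```python
from collections import namedtuple

# Assumed to mirror the Genotype definition already present in cnn/genotypes.py.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Hypothetical name; the value is pasted verbatim from the log.txt excerpt above.
MY_SEARCHED_ARCH = Genotype(
    normal=[('skip_connect', 0), ('sep_conv_3x3', 1),
            ('sep_conv_3x3', 0), ('sep_conv_3x3', 2),
            ('sep_conv_3x3', 1), ('skip_connect', 0),
            ('sep_conv_3x3', 0), ('skip_connect', 1)],
    normal_concat=range(2, 6),
    reduce=[('max_pool_3x3', 0), ('max_pool_3x3', 1),
            ('max_pool_3x3', 0), ('skip_connect', 2),
            ('max_pool_3x3', 0), ('dil_conv_5x5', 3),
            ('skip_connect', 2), ('skip_connect', 3)],
    reduce_concat=range(2, 6),
)

# Sanity check: 4 intermediate nodes, each with 2 incoming edges.
assert len(MY_SEARCHED_ARCH.normal) == 8
assert len(MY_SEARCHED_ARCH.reduce) == 8
```

With that entry added, the name passed to `--arch` would be `MY_SEARCHED_ARCH`.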

bkj commented 6 years ago

Cool thanks -- and how would you recommend sampling a random architecture? Something like this?

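# Assumes the default search space: 4 intermediate nodes per cell and 8 candidate ops.
k = sum(2 + i for i in range(4))  # 14 edges per cell
num_ops = 8
model = Network(args.init_channels, CIFAR_CLASSES, args.layers, criterion)
# Overwrite the architecture weights with random values before decoding.
model.alphas_normal = Variable(torch.randn(k, num_ops))
model.alphas_reduce = Variable(torch.randn(k, num_ops))
random_genotype = model.genotype()

quark0 commented 6 years ago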

Yep, that should do.
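If you would rather not build the search network at all, a random cell can also be drawn directly. This is only a sketch of an alternative, not what `model.genotype()` does internally: it assumes 4 intermediate nodes with 2 incoming edges each, and an op list matching `PRIMITIVES` in `genotypes.py` with `'none'` left out. `sample_cell` and `sample_genotype` are hypothetical helper names.

```python
import random
from collections import namedtuple

# Assumed to mirror the Genotype definition in cnn/genotypes.py.
Genotype = namedtuple('Genotype', 'normal normal_concat reduce reduce_concat')

# Candidate ops, assumed to match PRIMITIVES in cnn/genotypes.py minus 'none'.
OPS = [
    'max_pool_3x3', 'avg_pool_3x3', 'skip_connect',
    'sep_conv_3x3', 'sep_conv_5x5', 'dil_conv_3x3', 'dil_conv_5x5',
]


def sample_cell(rng, steps=4):
    """Pick 2 distinct predecessors and one random op per edge for each node."""
    cell = []
    for i in range(steps):
        # Node i can read from the 2 cell inputs plus the i earlier nodes.
        for src in sorted(rng.sample(range(i + 2), 2)):
            cell.append((rng.choice(OPS), src))
    return cell


def sample_genotype(seed=None, steps=4):
    """Draw independent random normal and reduction cells."""
    rng = random.Random(seed)
    concat = range(2, steps + 2)
    return Genotype(sample_cell(rng, steps), concat,
                    sample_cell(rng, steps), concat)


random_genotype = sample_genotype(seed=0)
print(random_genotype)
```

Passing a fixed `seed` makes the draw reproducible, which helps when comparing several random baselines.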

bkj commented 6 years ago

OK that works -- I can verify that running train_search.py and training the resulting genotype gives comparable results to the DARTS model in cnn/genotypes.py.

avn3r commented 6 years ago

I am also getting around 89% validation accuracy when running train_search.py. Are these the expected results for this step? I see no expected results for it in the documentation or the paper.

quark0 commented 6 years ago

It looks fine. The validation accuracy during architecture search does not tell you much because the weights are under-trained. You'll need to train the architecture from scratch in order to evaluate it and achieve ~2.83% test error.

twangnh commented 5 years ago

@quark0 hi, I'm wondering how many epochs you train the searched architecture for. I used the architecture found at epoch 49, and it only gets 95.0 valid acc after being retrained for 350 epochs.