automl / NASLib

NASLib is a Neural Architecture Search (NAS) library for facilitating NAS research for the community by providing interfaces to several state-of-the-art NAS search spaces and optimizers.
Apache License 2.0

Add a plotting feature for arch weights #147

Closed: Louquinze closed this pull request 1 year ago

Louquinze commented 1 year ago

This PR adds a new feature that traces the architecture weights (the alpha values) during training.
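In DARTS-style search, every edge of the cell holds one alpha value per candidate operation, and the operation mixing weights are the softmax of those alphas. Tracing them therefore amounts to storing one snapshot per epoch. The sketch below shows that idea in plain Python; `softmax`, `ArchWeightTracer`, the edge names and the per-epoch dict layout are hypothetical illustrations, not the code added by this PR or NASLib's API.

```python
import math
from typing import Dict, List


def softmax(alphas: List[float]) -> List[float]:
    """Numerically stable softmax over one edge's alpha values."""
    m = max(alphas)
    exps = [math.exp(a - m) for a in alphas]
    total = sum(exps)
    return [e / total for e in exps]


class ArchWeightTracer:
    """Hypothetical helper: keeps one softmax snapshot per edge per epoch."""

    def __init__(self) -> None:
        # history[edge_name] -> list of per-epoch softmax vectors
        self.history: Dict[str, List[List[float]]] = {}

    def record(self, alphas_per_edge: Dict[str, List[float]]) -> None:
        """Call once per epoch with the raw alphas of every edge."""
        for edge, alphas in alphas_per_edge.items():
            self.history.setdefault(edge, []).append(softmax(alphas))

    def trajectory(self, edge: str, op_index: int) -> List[float]:
        """Weight of one candidate operation on one edge, epoch by epoch."""
        return [snapshot[op_index] for snapshot in self.history[edge]]
```

A trajectory such as `tracer.trajectory("edge_0_1", 0)` is exactly the kind of curve a per-operation plot would show.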

Run the following script to try the new feature:

import os
import logging
from naslib.defaults.trainer import Trainer
from naslib.optimizers import DARTSOptimizer, GDASOptimizer, RandomSearch
from naslib.search_spaces import DartsSearchSpace, SimpleCellSearchSpace

from naslib.utils import set_seed, setup_logger, get_config_from_args

config = get_config_from_args()  # use --help to see the options
config.search.batch_size = 128
config.search.epochs = 1
config.save_arch_weights = True
config.save_arch_weights_path = f"{config.save}/save_arch"
set_seed(config.seed)

logger = setup_logger(config.save + "/log.log")
logger.setLevel(logging.INFO)  # default DEBUG is very verbose

search_space = SimpleCellSearchSpace()  # lighter search; swap in DartsSearchSpace() for the full DARTS space

optimizer = DARTSOptimizer(config)
optimizer.adapt_search_space(search_space)

trainer = Trainer(optimizer, config)
trainer.search()  # Search for an architecture
# trainer.evaluate()  # Evaluate the best architecture

This feature adds two new options to the config object:

config.save_arch_weights = True
config.save_arch_weights_path = f"{config.save}/save_arch"
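One way code could consume these two options is sketched below. `resolve_arch_weights_dir` is a hypothetical helper, a `SimpleNamespace` stands in for the real config object, and falling back to `<save>/save_arch` when no path is given is an assumption that mirrors the example above.

```python
import os
from types import SimpleNamespace  # stand-in for the real config object
from typing import Optional


def resolve_arch_weights_dir(config) -> Optional[str]:
    """Return the directory for arch-weight output, or None if disabled.

    Hypothetical helper. getattr with defaults keeps configs that predate
    the two new keys working: the feature is simply treated as off.
    """
    if not getattr(config, "save_arch_weights", False):
        return None
    path = getattr(config, "save_arch_weights_path", None)
    if path is None:
        # Assumed default, mirroring f"{config.save}/save_arch" above.
        path = os.path.join(config.save, "save_arch")
    os.makedirs(path, exist_ok=True)
    return path
```

Reading the flags with defaults means older configuration files need no changes to keep running.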

The resulting plots are saved in the run directory.
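To inspect or re-plot traced weights outside NASLib, a per-epoch history can be written to a flat CSV and drawn one edge at a time. This is a sketch under assumptions: the CSV layout, the `history` dict shape (edge name mapped to a list of per-epoch weight vectors) and all three function names are hypothetical, not the file format this PR writes. Only `plot_edge` needs matplotlib, imported lazily so the CSV helpers work without it.

```python
import csv
from typing import Dict, List

History = Dict[str, List[List[float]]]  # edge -> per-epoch weight vectors


def save_weight_history(history: History, path: str) -> None:
    """Write one CSV row per (epoch, edge, op_index, weight)."""
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["epoch", "edge", "op_index", "weight"])
        for edge, snapshots in history.items():
            for epoch, weights in enumerate(snapshots):
                for op_index, w in enumerate(weights):
                    writer.writerow([epoch, edge, op_index, w])


def load_weight_history(path: str) -> History:
    """Inverse of save_weight_history (rows must be in the order it writes)."""
    history: History = {}
    with open(path, newline="") as f:
        for row in csv.DictReader(f):
            snapshots = history.setdefault(row["edge"], [])
            epoch = int(row["epoch"])
            if epoch == len(snapshots):  # first row of a new epoch
                snapshots.append([])
            snapshots[epoch].append(float(row["weight"]))
    return history


def plot_edge(history: History, edge: str, out_png: str) -> None:
    """Draw one line per candidate operation on a single edge."""
    import matplotlib
    matplotlib.use("Agg")  # headless backend, no display needed
    import matplotlib.pyplot as plt

    snapshots = history[edge]
    epochs = list(range(len(snapshots)))
    fig, ax = plt.subplots()
    for op_index in range(len(snapshots[0])):
        ax.plot(epochs, [s[op_index] for s in snapshots], label=f"op {op_index}")
    ax.set_xlabel("epoch")
    ax.set_ylabel("softmax(alpha)")
    ax.set_title(edge)
    ax.legend()
    fig.savefig(out_png)
    plt.close(fig)
```

Keeping the raw numbers in a CSV next to the images makes it possible to regenerate or compare plots across runs without rerunning the search.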