automl / NASLib

NASLib is a Neural Architecture Search (NAS) library that facilitates NAS research in the community by providing interfaces to several state-of-the-art NAS search spaces and optimizers.
Apache License 2.0
512 stars, 117 forks

Querying ZeroCost predictor in Supported Optimizers #139

Open · jr2021 opened this issue 1 year ago

jr2021 commented 1 year ago

A problem occurs in the Bananas optimizer when an architecture outside of the search space is sampled and its zero_cost_scores must be calculated on the fly using the ZeroCost predictor. The query then fails with the following exception:

Exception has occurred: StopIteration
exception: no description
  File "/home/jhaa/NASLib/naslib/predictors/utils/pruners/measures/model_stats.py", line 7, in get_model_stats
    model_stats = tw.ModelStats(model, input_tensor_shape,
  File "/home/jhaa/NASLib/naslib/predictors/utils/pruners/predictive.py", line 141, in find_measures
    model_stats = get_model_stats(
  File "/home/jhaa/NASLib/naslib/predictors/zerocost.py", line 31, in query
    score = predictive.find_measures(
  File "/home/jhaa/NASLib/naslib/optimizers/discrete/bananas/optimizer.py", line 93, in query_zc_scores
    score = zc_method.query(arch, dataloader=zc_method.train_loader)
  File "/home/jhaa/NASLib/naslib/optimizers/discrete/bananas/optimizer.py", line 241, in _get_best_candidates
    model.zc_scores = self.query_zc_scores(model.arch)
  File "/home/jhaa/NASLib/naslib/optimizers/discrete/bananas/optimizer.py", line 232, in new_epoch
    self.next_batch = self._get_best_candidates(candidates, acq_fn)
  File "/home/jhaa/NASLib/naslib/defaults/trainer.py", line 112, in search
    self.optimizer.new_epoch(e)
  File "/home/jhaa/NASLib/naslib/runners/bbo/runner.py", line 75, in <module>
    trainer.search(resume_from="", summary_writer=writer, report_incumbent=False)

Currently, we pass model.arch directly into zc_method.query(). This differs slightly from the Fall School tutorial, in which a graph is initialized, an architecture is sampled, the graph is parsed, and the parsed graph object is then passed in.
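The innermost frame of the traceback is the tw.ModelStats call. One possible explanation, which is an assumption and not confirmed in this issue, is that the object reaching that call has no registered parameters because it was never parsed into a model. In Python, calling next() on an empty iterator raises a bare StopIteration, and PyTorch code often uses next(model.parameters()) to look up a model's device. The sketch below illustrates that failure mode and the tutorial's order of operations. ToyGraph, set_arch, first_parameter, and build_queryable_graph are hypothetical names for illustration only, not NASLib's API.

```python
from typing import Iterator, List, Optional


class ToyGraph:
    """Hypothetical stand-in for a search-space graph (NOT NASLib's API).

    Like a parsed model, it only owns parameters after parse() is called.
    """

    def __init__(self) -> None:
        self.arch: Optional[List[str]] = None
        self._params: List[str] = []

    def set_arch(self, arch: List[str]) -> None:
        # Hypothetical step: attach a sampled architecture to the graph.
        self.arch = list(arch)

    def parse(self) -> None:
        # Build the "model": each op in the architecture creates one weight.
        if self.arch is None:
            raise ValueError("no architecture set; call set_arch() first")
        self._params = [f"weight_{op}" for op in self.arch]

    def parameters(self) -> Iterator[str]:
        return iter(self._params)


def first_parameter(model: ToyGraph) -> str:
    # Mirrors the common next(model.parameters()) idiom; raises a bare
    # StopIteration when the model has no parameters yet.
    return next(model.parameters())


def build_queryable_graph(arch: List[str]) -> ToyGraph:
    # Hypothetical helper following the tutorial order:
    # initialize a graph, set the architecture, parse, then return the graph.
    graph = ToyGraph()
    graph.set_arch(arch)
    graph.parse()
    return graph


if __name__ == "__main__":
    arch = ["conv3x3", "skip", "avgpool"]

    # Passing an architecture that was never parsed: no parameters exist.
    unparsed = ToyGraph()
    unparsed.set_arch(arch)
    try:
        first_parameter(unparsed)
    except StopIteration:
        print("unparsed graph: StopIteration")

    # Tutorial-style flow: the parsed graph object has parameters.
    parsed = build_queryable_graph(arch)
    print("parsed graph:", first_parameter(parsed))
```

If this hypothesis holds, the corresponding change in Bananas would be to build and parse a graph from model.arch before calling zc_method.query(), rather than passing model.arch itself. Whether NASLib's ZeroCost.query() expects a parsed graph in this way should be checked against the tutorial code.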