microsoft / nni

An open source AutoML toolkit for automating the machine learning lifecycle, including feature engineering, neural architecture search, model compression and hyper-parameter tuning.
https://nni.readthedocs.io
MIT License

New strategy.Proxyless() has inconsistent logic compared to deprecated ProxylessTrainer() #5056

Open AL3708 opened 2 years ago

AL3708 commented 2 years ago

Describe the issue: In version 2.8, strategy.Proxyless() was added to replace the deprecated ProxylessTrainer(). However, in the new version it is not possible to use a target latency, which was part of the architecture loss in the old version. strategy.Proxyless() uses training_step() from the DartsLightningModule class. Inside that method, self.model.training_step() is called twice: once for the architecture-parameter update and once for the model-weight update. But model.training_step() should behave differently (in its loss function, to be precise) depending on whether the architecture or the weights are being updated: only the architecture update should include the loss from inference latency. Since no extra parameter is passed to model.training_step() and no internal state is changed (e.g. an arch/weights update flag), it is impossible to determine inside a custom evaluator whether the current call is an architecture update or a weight update. As a result, the latency loss, which is an important feature of ProxylessNAS, cannot be used. Is there any chance this will be fixed in the future, or do we have to use the old technique?
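To make the ambiguity concrete, here is a minimal sketch of a custom evaluator module (`MyEvaluatorModule` and its layout are illustrative, not code from NNI): `training_step` receives exactly the same arguments on both calls, so there is nowhere to branch between the architecture and weight updates.

```python
import pytorch_lightning as pl
import torch.nn.functional as F


class MyEvaluatorModule(pl.LightningModule):
    """Illustrative custom evaluator module (hypothetical names)."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x), y)
        # strategy.Proxyless() invokes this method twice per step, once for
        # the architecture update and once for the weight update, but both
        # calls receive only (batch, batch_idx). There is nothing here to
        # branch on, so a latency term cannot be added to the architecture
        # loss alone.
        return loss
```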

Environment:

ultmaster commented 2 years ago

Thanks for pointing this out. You are right that the latency loss is no longer supported in the current version of ProxylessNAS.

We are planning systematic support for latency-aware NAS, coming in the v3.0 release. For now, a possible workaround is to derive from ProxylessLightningModule and override training_step.
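A rough sketch of that workaround, with loud caveats: the import path, the manual-optimization layout, the optimizer ordering, and the `(weight_batch, arch_batch)` batch structure below are all assumptions about NNI 2.8 internals rather than confirmed API, and `estimate_latency` is a hypothetical placeholder for a real predictor (e.g. a per-op lookup table as in the ProxylessNAS paper).

```python
import torch
# Assumed import path; verify against your NNI version before relying on it.
from nni.retiarii.oneshot.pytorch.differentiable import DartsLightningModule


class LatencyAwareModule(DartsLightningModule):
    """Hypothetical subclass adding a latency term to the architecture loss."""

    def __init__(self, *args, target_latency=80.0, latency_weight=0.1, **kwargs):
        super().__init__(*args, **kwargs)
        self.target_latency = target_latency
        self.latency_weight = latency_weight

    def estimate_latency(self):
        # Placeholder: should return the expected latency of the current
        # architecture distribution as a differentiable tensor.
        return torch.tensor(0.0)

    def training_step(self, batch, batch_idx):
        # Assumed batch layout and optimizer ordering; check the NNI sources.
        weight_batch, arch_batch = batch
        arch_optim, weight_optim = self.optimizers()

        # Architecture step: task loss plus a penalty for exceeding the
        # latency target (a soft constraint, in the spirit of ProxylessNAS).
        arch_loss = self.model.training_step(arch_batch, batch_idx)
        overshoot = self.estimate_latency() - self.target_latency
        arch_loss = arch_loss + self.latency_weight * torch.relu(overshoot)
        arch_optim.zero_grad()
        self.manual_backward(arch_loss)
        arch_optim.step()

        # Weight step: plain task loss, no latency term.
        weight_loss = self.model.training_step(weight_batch, batch_idx)
        weight_optim.zero_grad()
        self.manual_backward(weight_loss)
        weight_optim.step()
        return weight_loss
```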

By the way, do you think it would be better if we set a flag on the evaluator indicating the current searching/training state (e.g., search_train, search_val, etc.)?
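For illustration, such a flag might be consumed in a custom evaluator like this (purely hypothetical: the `nas_phase` attribute and its values are the proposal above, not an existing NNI API):

```python
import torch
import pytorch_lightning as pl
import torch.nn.functional as F


class FlagAwareModule(pl.LightningModule):
    """Hypothetical: shows how the proposed state flag could be consumed."""

    def __init__(self, model, latency_weight=0.1):
        super().__init__()
        self.model = model
        self.latency_weight = latency_weight

    def estimate_latency(self):
        # Placeholder for a real latency predictor.
        return torch.tensor(0.0)

    def training_step(self, batch, batch_idx):
        x, y = batch
        loss = F.cross_entropy(self.model(x), y)
        # 'nas_phase' is the flag proposed above; neither the attribute nor
        # its values exist in NNI today.
        if getattr(self, 'nas_phase', None) == 'search_train':
            loss = loss + self.latency_weight * self.estimate_latency()
        return loss
```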

AL3708 commented 2 years ago

Thanks for the clarification :) Hmm, as a framework user I would first look for an extra state argument passed to training_step(), or for dedicated methods for the arch/weights steps; only then would I check evaluator flags.

ultmaster commented 2 years ago

Interesting. But unfortunately training_step() is a PyTorch Lightning API, and therefore there's not much we can do...