huawei-noah / vega

AutoML tools chain
http://www.noahlab.com.hk/opensource/vega/

PRUNE_EA parallel_search error #87

Closed. wnov closed this issue 3 years ago.

wnov commented 3 years ago

When I set parallel_search: True in prune.yml, I get this error:

Traceback (most recent call last):
  File "<string>", line 1, in <module>
  File "/wn/vega/zeus/trainer_base.py", line 153, in train_process
    self._train_loop()
  File "/wn/vega/zeus/trainer_base.py", line 279, in _train_loop
    self.callbacks.before_train()
  File "/wn/vega/zeus/trainer/callbacks/callback_list.py", line 139, in before_train
    callback.before_train(logs)
  File "/wn/vega/vega/algorithms/compression/prune_ea/prune_trainer_callback.py", line 61, in before_train
    self.latency_count = calc_forward_latency(self.trainer.model, count_input, sess_config)
  File "/wn/vega/zeus/metrics/forward_latency.py", line 30, in calc_forward_latency
    step_cfg = UserConfig().data.get("nas")
AttributeError: 'NoneType' object has no attribute 'get'

Everything works fine when I set parallel_search: False and run the prune algorithm demo, so what is wrong with parallel_search? For reference, the flag is set in the general section of prune.yml, roughly like the sketch below (backend value illustrative).
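
    general:
        backend: pytorch
        parallel_search: True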

zhangjiajin commented 3 years ago

The demo has a bug. Please comment out the following code in zeus/metrics/forward_latency.py:

def calc_forward_latency(model, input, sess_config=None, num=100):
    """Model forward latency calculation.

    :param model: network model
    :type model: torch or tf module
    :param input: input tensor
    :type input: Tensor of torch or tf
    :param sess_config: session configuration, used for tf models
    :type sess_config: session config or None
    :param num: number of forward passes to time
    :type num: int
    :return: forward latency
    :rtype: float
    """
    # UserConfig().data is None in parallel-search worker processes, so this
    # lookup raises AttributeError; keep the block commented out.
    # step_cfg = UserConfig().data.get("nas")
    # if hasattr(step_cfg, "evaluator"):
    #     evaluate_cfg = step_cfg.get("evaluator")
    #     if hasattr(evaluate_cfg, "davinci_mobile_evaluator"):
    #         evaluate_config = evaluate_cfg.get("davinci_mobile_evaluator")
    #         latency = _calc_forward_latency_davinci(model, input, sess_config, evaluate_config)
    # else:
    latency = _calc_forward_latency_gpu(model, input, sess_config, num)
    return latency
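
An alternative to commenting the lines out is to guard the lookup, so the Davinci path still works in runs where a config is loaded. This is only a sketch: it assumes UserConfig().data and the nested configs support dict-style .get (or are None), and it reuses the module's existing _calc_forward_latency_gpu and _calc_forward_latency_davinci helpers:

def calc_forward_latency(model, input, sess_config=None, num=100):
    """Forward-latency calculation that tolerates an unloaded UserConfig.

    In parallel-search worker processes UserConfig().data is None, so
    fall back to the plain GPU measurement instead of raising
    AttributeError.
    """
    data = UserConfig().data
    step_cfg = data.get("nas") if data else None
    evaluate_cfg = step_cfg.get("evaluator") if step_cfg else None
    davinci_cfg = evaluate_cfg.get("davinci_mobile_evaluator") if evaluate_cfg else None
    if davinci_cfg:
        # A Davinci mobile evaluator is configured: measure on that device.
        return _calc_forward_latency_davinci(model, input, sess_config, davinci_cfg)
    # Default path: time `num` forward passes on GPU.
    return _calc_forward_latency_gpu(model, input, sess_config, num)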

Note that the value of random_models must be greater than the number of GPUs, so that every parallel worker has at least one random model to evaluate.

    search_algorithm:
        type: PruneEA
        codec: PruneCodec
        policy:
            length: 464
            num_generation: 31
            num_individual: 32
            random_models: 64

@wnov