mit-han-lab / once-for-all

[ICLR 2020] Once for All: Train One Network and Specialize it for Efficient Deployment
https://ofa.mit.edu/
MIT License

KeyError from MBv3LatencyTable in EvolutionFinder #65

Status: open · opened by akvallapuram 2 years ago

akvallapuram commented 2 years ago

I am trying to obtain a specialised deployment of the MobileNetV3 architecture, but I run into a KeyError. Could you suggest what must be done?

The error originates from this code:

from ofa.nas.search_algorithm.evolution import EvolutionFinder
from ofa.nas.efficiency_predictor.latency_lookup_table import MBv3LatencyTable
from ofa.nas.accuracy_predictor.acc_predictor import AccuracyPredictor
from ofa.nas.accuracy_predictor.arch_encoder import MobileNetArchEncoder
from ofa.model_zoo import ofa_net

...
...
...

    self.ofa_network = ofa_net('ofa_mbv3_d234_e346_k357_w1.2', pretrained=True)
    self.arch_encoder = MobileNetArchEncoder()
    self.accuracy_predictor = AccuracyPredictor(self.arch_encoder, device=args.device)
    self.efficiency_predictor = MBv3LatencyTable()

    params = {
        'constraint_type': 'note10',
        'efficiency_constraint': 33,
        'network': self.ofa_network,
        'efficiency_predictor': self.efficiency_predictor,
        'accuracy_predictor': self.accuracy_predictor,
        'mutate_prob': 0.1,
        'mutation_ratio': 0.5,
        'population_size': 100,
        'max_time_budget': 500,
        'parent_ratio': 0.25,
    }

    finder = EvolutionFinder(**params)
    best_valids, best_info = finder.run_evolution_search(params['efficiency_constraint'])
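The `w1.2` suffix in `ofa_mbv3_d234_e346_k357_w1.2` is a width multiplier, which would account for the `x24` in the key reported in the error. Here is a minimal sketch. It assumes a 16-channel base stem and the common `make_divisible` rounding to multiples of 8. Both are assumptions and are not confirmed anywhere in this issue.

```python
def make_divisible(v, divisor=8, min_val=None):
    """Round v to the nearest multiple of divisor (at least min_val),
    stepping up one multiple if rounding down would lose more than 10% of v."""
    if min_val is None:
        min_val = divisor
    new_v = max(min_val, int(v + divisor / 2) // divisor * divisor)
    if new_v < 0.9 * v:  # rounded down too far; step up one multiple
        new_v += divisor
    return new_v


# Assumption: the MobileNetV3 stem has 16 channels before the width multiplier.
BASE_STEM_CHANNELS = 16

if __name__ == "__main__":
    for width_mult in (1.0, 1.2):
        stem = make_divisible(BASE_STEM_CHANNELS * width_mult)
        print(f"w{width_mult}: 224x224x3 -> 112x112x{stem}")
    # w1.0: 224x224x3 -> 112x112x16
    # w1.2: 224x224x3 -> 112x112x24
```

Under these assumptions, the 1.2x network has a 24-channel stem. The only `Conv` entry in the printed lookup table, by contrast, has 32 output channels.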

Running it fails with the following error:

  File "/home/users/akvallapuram/project/ofa.py", line 104, in __init__
    best_valids, best_info = finder.run_evolution_search(params['efficiency_constraint'])
  File "once_for_all/ofa/nas/search_algorithm/evolution.py", line 87, in run_evolution_search
    sample, efficiency = self.random_valid_sample(constraint) # returns 
  File "once_for_all/ofa/nas/search_algorithm/evolution.py", line 41, in random_valid_sample
    efficiency = self.efficiency_predictor.predict_network_latency_given_config(net_config)
  File "once_for_all/ofa/nas/efficiency_predictor/latency_lookup_table.py", line 248, in predict_network_latency_given_config
    predicted_latency += self.query(
  File "once_for_all/ofa/nas/efficiency_predictor/latency_lookup_table.py", line 195, in query
    return self.lut[key]['mean']
KeyError: 'Conv-input:224x224x3-output:112x112x24'
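The traceback ends in a plain dictionary lookup (`self.lut[key]['mean']`). A lookup-table latency predictor of this kind therefore fails as soon as any layer of the sampled network lacks a matching entry. The sketch below reproduces that behavior. The names `conv_key` and `predict_latency` and the two-entry table are hypothetical. The key format is copied from the printed keys, and the latency values are placeholders.

```python
# Hypothetical two-entry lookup table: key -> {"mean": latency}.
# Key format copied from the keys printed in this issue; latency values are placeholders.
LUT = {
    "Conv-input:224x224x3-output:112x112x32": {"mean": 1.0},
    "Conv_1-input:7x7x320-output:7x7x1280": {"mean": 0.5},
}


def conv_key(in_hw, in_c, out_hw, out_c, name="Conv"):
    """Build a key such as 'Conv-input:224x224x3-output:112x112x32'."""
    return f"{name}-input:{in_hw}x{in_hw}x{in_c}-output:{out_hw}x{out_hw}x{out_c}"


def predict_latency(layer_keys, lut):
    """Sum per-layer mean latencies; a layer missing from the table raises KeyError."""
    total = 0.0
    for key in layer_keys:
        total += lut[key]["mean"]  # same access pattern as self.lut[key]['mean']
    return total


if __name__ == "__main__":
    covered = [conv_key(224, 3, 112, 32), conv_key(7, 320, 7, 1280, name="Conv_1")]
    print(predict_latency(covered, LUT))  # 1.5

    # A 24-channel stem has no entry, so the lookup fails as in the traceback.
    try:
        predict_latency([conv_key(224, 3, 112, 24)], LUT)
    except KeyError as err:
        print("missing from table:", err)
```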

When I print the keys of self.lut, I get the following:

dict_keys(['Conv-input:224x224x3-output:112x112x32', 
'Conv_1-input:7x7x320-output:7x7x1280', 
'Logits-input:7x7x1280-output:1000', 
'expanded_conv-input:112x112x16-output:56x56x24-expand:3-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:112x112x16-output:56x56x24-expand:3-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:112x112x16-output:56x56x24-expand:3-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:112x112x16-output:56x56x24-expand:6-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:112x112x16-output:56x56x24-expand:6-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:112x112x16-output:56x56x24-expand:6-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:112x112x32-output:112x112x16-expand:1-kernel:3-stride:1-idskip:0', 
'expanded_conv-input:112x112x32-output:112x112x16-expand:1-kernel:5-stride:1-idskip:0', 
'expanded_conv-input:112x112x32-output:112x112x16-expand:1-kernel:7-stride:1-idskip:0', 
'expanded_conv-input:14x14x80-output:14x14x80-expand:3-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:14x14x80-output:14x14x80-expand:3-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:14x14x80-output:14x14x80-expand:3-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:14x14x80-output:14x14x80-expand:6-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:14x14x80-output:14x14x80-expand:6-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:14x14x80-output:14x14x80-expand:6-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:14x14x80-output:14x14x96-expand:3-kernel:3-stride:1-idskip:0', 
'expanded_conv-input:14x14x80-output:14x14x96-expand:3-kernel:5-stride:1-idskip:0', 
'expanded_conv-input:14x14x80-output:14x14x96-expand:3-kernel:7-stride:1-idskip:0', 
'expanded_conv-input:14x14x80-output:14x14x96-expand:6-kernel:3-stride:1-idskip:0', 
'expanded_conv-input:14x14x80-output:14x14x96-expand:6-kernel:5-stride:1-idskip:0', 
'expanded_conv-input:14x14x80-output:14x14x96-expand:6-kernel:7-stride:1-idskip:0', 
'expanded_conv-input:14x14x96-output:14x14x96-expand:3-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:14x14x96-output:14x14x96-expand:3-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:14x14x96-output:14x14x96-expand:3-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:14x14x96-output:14x14x96-expand:6-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:14x14x96-output:14x14x96-expand:6-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:14x14x96-output:14x14x96-expand:6-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:14x14x96-output:7x7x192-expand:3-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:14x14x96-output:7x7x192-expand:3-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:14x14x96-output:7x7x192-expand:3-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:14x14x96-output:7x7x192-expand:6-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:14x14x96-output:7x7x192-expand:6-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:14x14x96-output:7x7x192-expand:6-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:14x14x80-expand:3-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:14x14x80-expand:3-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:14x14x80-expand:3-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:14x14x80-expand:6-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:14x14x80-expand:6-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:14x14x80-expand:6-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:28x28x40-output:28x28x40-expand:3-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:28x28x40-output:28x28x40-expand:3-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:28x28x40-output:28x28x40-expand:3-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:28x28x40-output:28x28x40-expand:6-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:28x28x40-output:28x28x40-expand:6-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:28x28x40-output:28x28x40-expand:6-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:56x56x24-output:28x28x40-expand:3-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:56x56x24-output:28x28x40-expand:3-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:56x56x24-output:28x28x40-expand:3-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:56x56x24-output:28x28x40-expand:6-kernel:3-stride:2-idskip:0', 
'expanded_conv-input:56x56x24-output:28x28x40-expand:6-kernel:5-stride:2-idskip:0', 
'expanded_conv-input:56x56x24-output:28x28x40-expand:6-kernel:7-stride:2-idskip:0', 
'expanded_conv-input:56x56x24-output:56x56x24-expand:3-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:56x56x24-output:56x56x24-expand:3-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:56x56x24-output:56x56x24-expand:3-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:56x56x24-output:56x56x24-expand:6-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:56x56x24-output:56x56x24-expand:6-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:56x56x24-output:56x56x24-expand:6-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x192-expand:3-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x192-expand:3-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x192-expand:3-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x192-expand:6-kernel:3-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x192-expand:6-kernel:5-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x192-expand:6-kernel:7-stride:1-idskip:1', 
'expanded_conv-input:7x7x192-output:7x7x320-expand:3-kernel:3-stride:1-idskip:0', 
'expanded_conv-input:7x7x192-output:7x7x320-expand:3-kernel:5-stride:1-idskip:0', 
'expanded_conv-input:7x7x192-output:7x7x320-expand:3-kernel:7-stride:1-idskip:0', 
'expanded_conv-input:7x7x192-output:7x7x320-expand:6-kernel:3-stride:1-idskip:0', 
'expanded_conv-input:7x7x192-output:7x7x320-expand:6-kernel:5-stride:1-idskip:0', 
'expanded_conv-input:7x7x192-output:7x7x320-expand:6-kernel:7-stride:1-idskip:0'])
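One way to narrow this down is to find table entries that share the failing key's operation and input shape but differ in output shape. Below is a minimal sketch with hypothetical helper names. It assumes every key follows the `op-input:HxWxC-output:...` pattern seen above.

```python
def split_key(key):
    """'Conv-input:224x224x3-output:112x112x32' -> ('Conv', '224x224x3', '112x112x32')."""
    op, rest = key.split("-input:", 1)
    in_shape, rest = rest.split("-output:", 1)
    out_shape = rest.split("-", 1)[0]  # drop any -expand/-kernel/-stride/-idskip suffix
    return op, in_shape, out_shape


def near_misses(missing, table_keys):
    """Return table keys with the same op and input shape as `missing`."""
    op, in_shape, _ = split_key(missing)
    return [k for k in table_keys if split_key(k)[:2] == (op, in_shape)]


if __name__ == "__main__":
    # A few keys copied from the printed table (subset for brevity).
    table_keys = [
        "Conv-input:224x224x3-output:112x112x32",
        "Conv_1-input:7x7x320-output:7x7x1280",
        "Logits-input:7x7x1280-output:1000",
        "expanded_conv-input:112x112x32-output:112x112x16-expand:1-kernel:3-stride:1-idskip:0",
    ]
    missing = "Conv-input:224x224x3-output:112x112x24"
    print(near_misses(missing, table_keys))
    # ['Conv-input:224x224x3-output:112x112x32']
```

Run against the full key list, this returns only the 32-channel stem entry. That suggests the loaded table describes a network whose stem width differs from that of `ofa_mbv3_d234_e346_k357_w1.2`, rather than a table that is missing a single entry.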