huawei-noah / vega

AutoML tools chain
http://www.noahlab.com.hk/opensource/vega/
Other
845 stars 176 forks source link

Loading a fully trained CARS model in a Jupyter notebook cell #235

Closed: Stasolet closed this issue 2 years ago.

Stasolet commented 2 years ago

Hello, thanks for your previous answer and your work!

How can I load a fully trained CARS model into a variable in a Python script or a Jupyter notebook? I have model_0.pth and desc_0.json. I tried to initialize the model the same way the _init_model method of the ModelBuilder class does, and got this error: ValueError: can't find class type network class name PreOneStem in class registry. This is the code I use:

import json

import torch
import vega
from vega.common import Config
from vega.networks.model_config import ModelConfig
from vega.model_zoo import ModelZoo

model_weights_file = f'./output/fully_train/model_0.pth'
desc_file = f'./output/fully_train/desc_0.json'

with open(desc_file) as f:
    desc = json.load(f)

config = Config(ModelConfig().to_dict())
config.model_desc = desc
cars_model = ModelZoo.get_model(**config)
cars_state = torch.load(model_weights_file)
cars_model.load_state_dict(cars_state, False)

If I add `import vega.modules.preprocess` (the module where PreOneStem is defined), this code works. Can you point out the mistake in my code?

Full traceback:

ValueError                                Traceback (most recent call last)
/tmp/ipykernel_13251/2074894908.py in <module>
      1 config = Config(ModelConfig().to_dict())
      2 config.model_desc = desc
----> 3 cars_model = ModelZoo.get_model(**config)
      4 cars_state = torch.load(model_weights_file)
      5 cars_model.load_state_dict(cars_state, False)

/opt/conda/lib/python3.7/site-packages/vega/model_zoo/model_zoo.py in get_model(cls, model_desc, pretrained_model_file, head, is_fusion, **kwargs)
     66         except Exception as e:
     67             logging.error("Failed to get model, model_desc={}, msg={}".format(model_desc, str(e)))
---> 68             raise e
     69         logging.info("Model was created.")
     70         for k, v in kwargs.items():

/opt/conda/lib/python3.7/site-packages/vega/model_zoo/model_zoo.py in get_model(cls, model_desc, pretrained_model_file, head, is_fusion, **kwargs)
     63             raise ValueError("model desc can't be None when create model.")
     64         try:
---> 65             model = NetworkDesc(model_desc).to_model()
     66         except Exception as e:
     67             logging.error("Failed to get model, model_desc={}, msg={}".format(model_desc, str(e)))

/opt/conda/lib/python3.7/site-packages/vega/networks/network_desc.py in to_model(self)
     37         else:
     38             module = ClassFactory.get_cls(ClassType.NETWORK, "Module")
---> 39         model = module.from_desc(self._desc)
     40         if not model:
     41             raise Exception("Failed to create model, model desc={}".format(self._desc))

/opt/conda/lib/python3.7/site-packages/vega/modules/operators/functions/serializable.py in from_desc(cls, desc)
    209                 if not ClassFactory.is_exists(ClassType.NETWORK, cls_name):
    210                     raise ValueError("Network {} not exists.".format(cls_name))
--> 211                 module = ClassFactory.get_instance(ClassType.NETWORK, module_desc)
    212             modules[group_name] = module
    213             module.name = str(group_name)

/opt/conda/lib/python3.7/site-packages/vega/common/class_factory.py in get_instance(cls, type_name, params, **kwargs)
    216         params_sig = sig(t_cls).parameters if isfunction(t_cls) else sig(t_cls.__init__).parameters
    217         extra_param = {k: v for k, v in _params.items() if k not in params_sig}
--> 218         return cls._create_instance_params(params_sig, _params, extra_param, t_cls)
    219 
    220     @classmethod

/opt/conda/lib/python3.7/site-packages/vega/common/class_factory.py in _create_instance_params(cls, params_sig, _params, extra_param, t_cls)
    240         except Exception as ex:
    241             logging.error("Failed to create instance:{}".format(t_cls))
--> 242             raise ex
    243 
    244     @classmethod

/opt/conda/lib/python3.7/site-packages/vega/common/class_factory.py in _create_instance_params(cls, params_sig, _params, extra_param, t_cls)
    234                 return t_cls(**_params)
    235             # fun(a, b, c=None)
--> 236             instance = t_cls(**filter_params) if filter_params else t_cls()
    237             for k, v in extra_param.items():
    238                 setattr(instance, k, v)

/opt/conda/lib/python3.7/site-packages/vega/networks/super_network.py in __init__(self, stem, cells, head, init_channels, num_classes, auxiliary, search, aux_size, auxiliary_layer, drop_path_prob)
     38             self._auxiliary_layer = auxiliary_layer
     39         # Build stems part
---> 40         self.pre_stems = ClassFactory.get_instance(ClassType.NETWORK, stem)
     41         # Build cells part
     42         c_curr = self.pre_stems.output_channel

/opt/conda/lib/python3.7/site-packages/vega/common/class_factory.py in get_instance(cls, type_name, params, **kwargs)
    210         if kwargs:
    211             _params.update(kwargs)
--> 212         t_cls = cls.get_cls(type_name, t_cls_name)
    213         if type_name != ClassType.NETWORK:
    214             return t_cls(**_params) if _params else t_cls()

/opt/conda/lib/python3.7/site-packages/vega/common/class_factory.py in get_cls(cls, type_name, t_cls_name)
    194         # verify class
    195         if not cls.is_exists(type_name, t_cls_name):
--> 196             raise ValueError("can't find class type {} class name {} in class registry".format(type_name, t_cls_name))
    197         # get class
    198         t_cls = cls.__registry__.get(type_name).get(t_cls_name)

ValueError: can't find class type network class name PreOneStem in class registry
zhangjiajin commented 2 years ago

@Stasolet

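You need to call `vega.set_backend()` before creating the model. Setting the backend registers the backend-specific network modules (such as `PreOneStem`) in the class registry. That is why `ModelZoo.get_model` could not find `PreOneStem`, and why importing `vega.modules.preprocess` by hand also made your code work.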

For example:

import vega
from vega.model_zoo import ModelZoo

vega.set_backend("pytorch", "gpu")
model_desc = "<task path>/desc_0.json"
model_pth = "<task path>/model_0.pth"
model = ModelZoo.get_model(model_desc, model_pth)
Stasolet commented 2 years ago

Thanks!