facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

Make `create_one_dim_tr_model` recognize subclasses of `BasicEnsemble` #182

Closed FrankTianTT closed 1 year ago

FrankTianTT commented 1 year ago

🚀 Feature Request

`create_one_dim_tr_model` should check whether the class referenced by `_target_` is a subclass of `BasicEnsemble`, rather than only comparing the target string.

Motivation

In the current implementation, `create_one_dim_tr_model` detects a `BasicEnsemble` through the `_target_` attribute of `cfg.dynamics_model`, and the check is an exact string comparison against `"mbrl.models.BasicEnsemble"`. A class that inherits from `BasicEnsemble` has a different `_target_` string, so it is not recognized as an ensemble and `member_cfg` is never selected. This prevents developers from subclassing `BasicEnsemble` to add the behavior they need.

Pitch

Because of the Hydra version currently in use, the only available way to resolve `_target_` to a class seems to be `hydra.utils._locate`, even though calling a protected function directly goes against convention.

Let's replace

    if model_cfg._target_ == "mbrl.models.BasicEnsemble":
        model_cfg = model_cfg.member_cfg

with

    # Resolve `_target_` to a class so that subclasses of BasicEnsemble also match.
    if issubclass(hydra.utils._locate(model_cfg._target_), mbrl.models.BasicEnsemble):
        model_cfg = model_cfg.member_cfg
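To illustrate the difference between the two checks without depending on Hydra or mbrl-lib, here is a minimal, self-contained sketch. `locate_class` and `targets_subclass_of` are hypothetical helper names, not part of either library: `locate_class` is only a simplified stand-in for what `hydra.utils._locate` does, and standard-library classes play the roles of `BasicEnsemble` and its subclasses.

```python
import importlib


def locate_class(dotted_path: str) -> type:
    """Hypothetical stand-in for resolving a `_target_` string to a class."""
    module_name, _, attr_name = dotted_path.rpartition(".")
    obj = getattr(importlib.import_module(module_name), attr_name)
    if not isinstance(obj, type):
        raise TypeError(f"{dotted_path!r} does not refer to a class")
    return obj


def targets_subclass_of(target: str, base: type) -> bool:
    """Return True if `target` names `base` itself or any subclass of it."""
    try:
        return issubclass(locate_class(target), base)
    except (ImportError, AttributeError, TypeError, ValueError):
        # Unresolvable targets and non-class targets are treated as "no match".
        return False


# Here `dict` plays the role of BasicEnsemble and `collections.OrderedDict`
# plays the role of a user-defined subclass.
subclass_target = "collections.OrderedDict"

# The string comparison misses the subclass; the issubclass check accepts it.
string_match = subclass_target == "builtins.dict"
subclass_match = targets_subclass_of(subclass_target, dict)
unrelated_match = targets_subclass_of("collections.deque", dict)
function_match = targets_subclass_of("os.path.join", dict)
```

The sketch falls back to `False` when the target cannot be resolved or is not a class, because `issubclass` raises `TypeError` when its first argument is not a class. Whether the real function should fail silently or raise in that case is a design choice left to the PR.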

I will create a PR to submit my code.

Additional context

No additional context.