facebookresearch / mbrl-lib

Library for Model Based RL
MIT License

Make `create_one_dim_tr_model` recognize subclasses of `BasicEnsemble` #183

Closed FrankTianTT closed 1 year ago

FrankTianTT commented 1 year ago

Motivation and Context / Related issue

In the current implementation, `create_one_dim_tr_model` identifies the `BasicEnsemble` class through the `_target_` attribute of `cfg.dynamics_model`, and the check relies entirely on string comparison. As a result, a class that inherits from `BasicEnsemble` is not recognized as an ensemble, which prevents developers from subclassing `BasicEnsemble` to add the functionality they need. Instead of comparing strings, we should check whether the class that `_target_` refers to is a subclass of `BasicEnsemble`. The corresponding issue is #182.
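The difference between the two checks can be illustrated with a minimal, self-contained sketch. This is not the code from this PR: the helper names `resolve_target` and `target_is_subclass` are hypothetical, and standard-library classes stand in for the real ones (`dict` plays the role of `BasicEnsemble`, `collections.OrderedDict` the role of a user-defined subclass) so that the snippet runs without mbrl-lib installed.

```python
import importlib


def resolve_target(target: str):
    """Import the object named by a dotted path such as 'pkg.module.ClassName'."""
    module_name, _, attr_name = target.rpartition(".")
    if not module_name:
        raise ValueError(f"expected a dotted path, got {target!r}")
    return getattr(importlib.import_module(module_name), attr_name)


def target_is_subclass(target: str, base: type) -> bool:
    """Return True if the class named by `target` is `base` or a subclass of it."""
    cls = resolve_target(target)
    return isinstance(cls, type) and issubclass(cls, base)


# Stand-ins (assumptions for illustration only):
# `dict` plays the role of BasicEnsemble, OrderedDict that of a custom subclass.
base_target = "builtins.dict"
custom_target = "collections.OrderedDict"

# String comparison misses the subclass; the issubclass check recognizes it.
string_match = custom_target == base_target
subclass_match = target_is_subclass(custom_target, dict)
print(string_match, subclass_match)  # False True
```

In a Hydra-based configuration the import step could probably be delegated to `hydra.utils.get_class` rather than a hand-written resolver; the sketch uses `importlib` only to stay dependency-free.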

How Has This Been Tested (if it applies)

In `tests/core/test_common_utils.py`, add a class `CustomEnsemble` that inherits from `BasicEnsemble` and create it using `create_one_dim_tr_model`. Then check that `dynamics_model.model.in_size` and `dynamics_model.model.out_size` are correct.