Closed akmandor closed 2 years ago
MultiInputPolicy with a custom feature extractor is what you are looking for, yes. You can specify how each observation key is treated (for example, a CNN for the observation under "key1", whose output is then concatenated with the "key2" observation). The docs have an example of how to do this: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#multiple-inputs-and-dictionary-observations
But the docs do not make clear how to create the model with a custom MultiInputPolicy.
1) When I create the model as below:
policy_kwargs = dict(features_extractor_class=CustomCombinedExtractor, features_extractor_kwargs=dict(features_dim=n_actions),)
model = PPO("MultiInputPolicy", env)
I got the following error:
"Error: unknown policy type MultiInputPolicy,the only registed policy type are: ['MlpPolicy', 'CnnPolicy']!"
2) Instead of "MultiInputPolicy", if I use the class name as below:
model = PPO(CustomCombinedExtractor, env)
I get the following error:
Traceback (most recent call last):
File ".../training.py", line 335, in
model = PPO(CustomCombinedExtractor, env, learning_rate=learning_rate, n_steps=n_steps, batch_size=batch_size, ent_coef=ent_coef, tensorboard_log=tensorboard_log_path, device="cuda", verbose=1)
File "/home/akmandor/.local/lib/python3.8/site-packages/stable_baselines3/ppo/ppo.py", line 95, in __init__
super(PPO, self).__init__(
File "/home/akmandor/.local/lib/python3.8/site-packages/stable_baselines3/common/on_policy_algorithm.py", line 76, in __init__
super(OnPolicyAlgorithm, self).__init__(
File "/home/akmandor/.local/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 156, in __init__
env = self._wrap_env(env, self.verbose, monitor_wrapper)
File "/home/akmandor/.local/lib/python3.8/site-packages/stable_baselines3/common/base_class.py", line 209, in _wrap_env
env = ObsDictWrapper(env)
File "/home/akmandor/.local/lib/python3.8/site-packages/stable_baselines3/common/vec_env/obs_dict_wrapper.py", line 28, in __init__
self.obs_dim = venv.observation_space.spaces["observation"].shape[0]
KeyError: 'observation'
You need to upgrade your SB3 version. Please format your code using markdown, as shown in the issue template.
Question
Let's say I have a 1-D observation vector with n (516) elements. I would like to pass the first k (512) elements to a CNN, concatenate the CNN output with the remaining n-k (4) elements, and pass the result into an FC network.
Main question: What is the right way to implement this custom network within the stable-baselines3 architecture?
My approaches and side questions:
Additional context
Using the guidelines in "Custom Policy Network" in the documentation, I implemented the following custom policy:
I set up my model using the Custom1DCNNPolicy as follows:
However, the network fails to learn the task, as shown in the following result plot:
To check the validity of the data (observations), I also trained using only the FC network, and in that case training succeeds, as shown in the following plot:
I also tried to train using different parameters (learning rate, channel inputs outputs, kernel sizes, etc.), but the results are very similar to the failing plot above.
Please also note that my desired network architecture has already been implemented in Stable Baselines using CnnPolicy, as in https://stable-baselines.readthedocs.io/en/master/misc/projects.html#train-a-ros-integrated-mobile-robot-differential-drive-to-avoid-dynamic-objects, with the custom policy class given below.
In this implementation example, the input observation is the laser scan data concatenated with the waypoints as a 1-D vector. The first 3 layers are 1-D CNN layers, while layers 4 and 5 are FC. The laser scan portion of the input observation is fed through the 3 CNN layers, and the output is then concatenated with the rest of the observation (the 1-D vectorized waypoint data) and fed into the 2 FC layers.
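When porting such an architecture, the input size of the first FC layer depends on how much the three convolution layers shrink the scan. The short sketch below computes it with the standard output-length formula for a 1-D convolution. The kernel sizes, strides, and channel count are hypothetical placeholders, not the values used in the linked project.

```python
def conv1d_out_len(length, kernel_size, stride=1, padding=0, dilation=1):
    """Output length of a 1-D convolution (same formula as torch.nn.Conv1d)."""
    return (length + 2 * padding - dilation * (kernel_size - 1) - 1) // stride + 1


scan_len = 512        # laser scan part of the observation
waypoint_len = 4      # remaining part of the observation
out_channels = 32     # hypothetical channel count of the last conv layer
conv_layers = [(5, 2), (3, 2), (3, 1)]  # hypothetical (kernel_size, stride)

length = scan_len
for kernel_size, stride in conv_layers:
    length = conv1d_out_len(length, kernel_size, stride)

# Flattened CNN output plus the raw waypoint values feed the first FC layer.
fc_in = out_channels * length + waypoint_len
print(length, fc_in)  # 124 3972
```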
Checklist