If we want to use Stable Baselines 3 we need to implement a custom policy that we can pass to the algorithm: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html I looked into it a little bit, but it seems non-trivial, because you need to provide both an actor and a critic network. I guess those are already present in the VPT model, but I couldn't check this yet.
In the VPT paper they used phasic policy gradient (https://github.com/openai/phasic-policy-gradient) for the RL finetuning which is not available in Stable Baselines 3 yet.
@mschweizer is helping here
Does the VPT model come with a value network?
Do we want to use PPO from Stable Baselines 3 or some other algorithm/implementation, e.g. phasic policy gradient?
How do we wrap `MineRLAgent` into an `ActorCriticPolicy`?
Next steps:
- Look into `MinecraftPolicy` (the lowest level of the agent)
```python
from typing import Callable, Dict, List, Optional, Tuple, Type, Union
import gym
import torch as th
from torch import nn

from stable_baselines3 import PPO
from stable_baselines3.common.policies import ActorCriticPolicy


class CustomNetwork(nn.Module):
    """
    Custom network for policy and value function.
    It receives as input the features extracted by the feature extractor.

    :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
    :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
    :param last_layer_dim_vf: (int) number of units for the last layer of the value network
    """

    def __init__(
        self,
        feature_dim: int,
        last_layer_dim_pi: int = 64,
        last_layer_dim_vf: int = 64,
    ):
        super(CustomNetwork, self).__init__()

        # IMPORTANT:
        # Save output dimensions, used to create the distributions
        self.latent_dim_pi = last_layer_dim_pi
        self.latent_dim_vf = last_layer_dim_vf

        # Policy network
        self.policy_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
        )
        # Value network
        self.value_net = nn.Sequential(
            nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
        )

    def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
        """
        :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
            If all layers are shared, then latent_policy == latent_value
        """
        return self.policy_net(features), self.value_net(features)

    def forward_actor(self, features: th.Tensor) -> th.Tensor:
        return self.policy_net(features)

    def forward_critic(self, features: th.Tensor) -> th.Tensor:
        return self.value_net(features)


class CustomActorCriticPolicy(ActorCriticPolicy):
    def __init__(
        self,
        observation_space: gym.spaces.Space,
        action_space: gym.spaces.Space,
        lr_schedule: Callable[[float], float],
        net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
        activation_fn: Type[nn.Module] = nn.Tanh,
        *args,
        **kwargs,
    ):
        super(CustomActorCriticPolicy, self).__init__(
            observation_space,
            action_space,
            lr_schedule,
            net_arch,
            activation_fn,
            # Pass remaining arguments to base class
            *args,
            **kwargs,
        )
        # Disable orthogonal initialization
        self.ortho_init = False

    def _build_mlp_extractor(self) -> None:
        self.mlp_extractor = CustomNetwork(self.features_dim)
```
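For reference, the template above is used by passing the policy class directly to PPO (this is the usage from the SB3 custom-policy docs; `CartPole-v1` is only a stand-in so the snippet runs on its own, for us the env would be the wrapped MineRL/BASALT environment):

```python
# Hand the custom policy class to PPO; SB3 instantiates it internally.
model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
model.learn(5000)
```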
- What parts of the architecture are we finetuning? (see `behavioral_cloning.py`)
- Is the architecture of the VPT policy model compatible with off-policy algorithms? Which ones?
@mschweizer @rk1a thanks! Take into account that we might want to use an off-policy algorithm to apply the PEBBLE method to BASALT, since you've mentioned the on-policy algo PPO above. We can / should of course try it with PPO first, since that might make the process easier.
EDIT: Alright, just saw the last comment:

> Is the architecture of the VPT policy model compatible with off-policy algorithms? Which ones?
@rk1a

> ...compatible with off-policy algorithms? Which ones?

Have a look at the algo used in the original PEBBLE paper? It might be SAC, as far as I remember.
@lauritowal @mschweizer Yes, PEBBLE used SAC. The architecture should be compatible with both SAC and DDPG, which are both off-policy actor-critic methods. These algos are typically employed for problems with purely continuous action spaces (in PEBBLE they used MuJoCo and the DeepMind Control Suite). The BASALT env action space is factored into a `gym.spaces.Dict`, and out of the box the Stable Baselines 3 implementations do not support Dict action spaces. So in order to use off-policy algos it seems that we need to (1) learn how to use SAC/DDPG with factored action spaces and (2) create a new version of the Stable Baselines 3 implementations that can do so.
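One possible direction (an untested sketch, not a decided approach): instead of changing the SB3 implementations, flatten the factored action space on the environment side with a `gym.ActionWrapper`, so that SAC/DDPG only ever see a `Box`. Note that `gym.spaces.flatten` one-hot encodes the `Discrete` entries, so a continuous actor output would still need to be discretized before `unflatten`:

```python
import gym
import numpy as np


class FlattenActionWrapper(gym.ActionWrapper):
    """Hypothetical sketch: expose a factored Dict action space as one flat Box."""

    def __init__(self, env: gym.Env):
        super().__init__(env)
        self._dict_space = env.action_space
        # flatten_space concatenates all entries into a single Box
        # (Discrete entries become one-hot segments).
        self.action_space = gym.spaces.flatten_space(self._dict_space)

    def action(self, action: np.ndarray):
        # Map the flat vector back to the Dict format the env expects.
        # A real implementation would first round/argmax the one-hot segments,
        # since a continuous actor will not produce exact one-hot values.
        return gym.spaces.unflatten(self._dict_space, action)
```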
RE: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/issues/64#issuecomment-1255210887

- `MinecraftPolicy` defines the VPT model architecture, using `recurrence_type="transformer"`, which is built on a `MaskedAttention` module (multi-layer LSTMs are also supported).
- `MinecraftAgentPolicy` wraps it and adds:
  - `pi_head` -> action logits
  - `vf_head` -> value function prediction
  - `get_log_prob_of_action`, `kl_of_action_dists`
  - no-grad helpers: `act` returns the action for a given obs, `v` returns the value for a given obs
- `MineRLAgent` keeps a `hidden_state`, which holds the intermediate outputs of the Residual Recurrent blocks; it is updated every time the agent takes an action. It also uses an `ActionTransformer` that converts between the action (numpy) array format used by the policy and the action dict format used in the MineRL environment.

Can we reuse the `ActionTransformer`? It seems possible, but I am not sure if it is useful compared to other solutions.

RE: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/issues/64#issuecomment-1255210887

Defining a wrapper for off-policy algorithms like SAC seems substantially harder than for on-policy algorithms like PPO: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#off-policy-algorithms
One way would be to register a custom policy, as done here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/feat/densemlp/utils/networks.py But the authors of SB3 warn that good knowledge of the algorithm is required, e.g. for SAC the output of the actor must go through a Tanh activation.
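If we only need to resize or tweak the SAC networks (rather than swap in the VPT backbone), SB3 also accepts a `policy_kwargs` dict, which avoids registering a full custom policy. A minimal sketch, using `Pendulum-v1` as a stand-in env with a Box action space (for BASALT this would be the flattened MineRL env):

```python
import torch as th
from stable_baselines3 import SAC

# Customize SAC's actor and Q-networks via policy_kwargs instead of
# subclassing SACPolicy.
policy_kwargs = dict(
    net_arch=dict(pi=[256, 256], qf=[256, 256]),  # actor / Q-network layer sizes
    activation_fn=th.nn.ReLU,
)
model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
model.learn(total_timesteps=1_000)
```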
@lauritowal @mschweizer are you familiar enough with any of the off-policy algorithms that you would feel comfortable defining the custom actor/critic? In any case I think it would be a good idea to first try writing the wrapper for PPO using the template above.
In any case I think it would be a good idea to first try writing the wrapper for PPO using the template above.
I agree.
are you familiar enough with any of the off-policy algorithms that you would feel comfortable defining the custom actor/critic?
I could have a look at it after having tried PPO
@rk1a could you create a simple example, where one of our models is updated via PPO? A notebook should be enough
@lauritowal @mschweizer I have started writing the wrapper: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/blob/RL_policy_finetuning/sb3_policy_wrapper.py There are still some functions to be adapted and I am not entirely sure if we can make it compatible without also changing the PPO implementation itself. The current problem is that as expected the Dict action space is not supported by PPO.
> The current problem is that, as expected, the Dict action space is not supported by PPO.

I flattened a dict observation to a simple array for a project once. That could probably be done for the action space too (?). Could you copy & paste the dict action here, please? @rk1a
@lauritowal This is the type:

```
Dict(ESC:Discrete(2), attack:Discrete(2), back:Discrete(2), camera:Box(low=-180.0, high=180.0, shape=(2,)), drop:Discrete(2), forward:Discrete(2), hotbar.1:Discrete(2), hotbar.2:Discrete(2), hotbar.3:Discrete(2), hotbar.4:Discrete(2), hotbar.5:Discrete(2), hotbar.6:Discrete(2), hotbar.7:Discrete(2), hotbar.8:Discrete(2), hotbar.9:Discrete(2), inventory:Discrete(2), jump:Discrete(2), left:Discrete(2), pickItem:Discrete(2), right:Discrete(2), sneak:Discrete(2), sprint:Discrete(2), swapHands:Discrete(2), use:Discrete(2))
```

This is an example action:

```
OrderedDict([('ESC', array(0)), ('attack', array(0)), ('back', array(1)), ('camera', array([-148.65605, 111.41331], dtype=float32)), ('drop', array(0)), ('forward', array(1)), ('hotbar.1', array(0)), ('hotbar.2', array(0)), ('hotbar.3', array(1)), ('hotbar.4', array(1)), ('hotbar.5', array(1)), ('hotbar.6', array(1)), ('hotbar.7', array(1)), ('hotbar.8', array(1)), ('hotbar.9', array(1)), ('inventory', array(0)), ('jump', array(1)), ('left', array(0)), ('pickItem', array(0)), ('right', array(0)), ('sneak', array(1)), ('sprint', array(1)), ('swapHands', array(0)), ('use', array(1))])
```
I also think we should just try the flattened array approach. There even exists a converter method in `MineRLAgent` to do that.
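Quick sanity check of what the flattened array approach looks like with plain `gym.spaces` utilities (an illustrative subset of the action Dict above, not the `MineRLAgent` converter itself):

```python
import gym
from gym.spaces import Box, Dict, Discrete

# Illustrative subset of the BASALT action space posted above.
action_space = Dict({
    "attack": Discrete(2),
    "camera": Box(low=-180.0, high=180.0, shape=(2,)),
    "forward": Discrete(2),
})

sample = action_space.sample()                       # OrderedDict, as used by MineRL
flat = gym.spaces.flatten(action_space, sample)      # 1-D array; Discrete entries are one-hot
restored = gym.spaces.unflatten(action_space, flat)  # back to the OrderedDict format
print(gym.spaces.flatten_space(action_space))        # Box of the concatenated size (here 6)
```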
@lauritowal @mschweizer The SB3 and Gym wrappers are now functional. You can find an example of how to use them here: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/blob/RL_policy_finetuning/sb3_policy_wrapper.py#L153 I did all testing using CPU so there might be minor issues when switching to GPU.
We can now wrap our baseline policies as SB3-compatible policies (for on-policy algorithms). The next step would be to integrate the wrapped policies and MineRL envs with imitation as outlined here: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/issues/32#issuecomment-1248072360 @gekaklam in this step we also need to integrate an interface between the web app for preference collection and the PrefRL algorithm.
@rk1a awesome work!! Thanks

@rk1a can you estimate a deadline, please? :) It would help with knowing whether we are on time.