BASALT-2022-Karlsruhe / ka-basalt-2022


Figure out: How to finetune BC model using RL / Stable Baselines etc. [Deadline 25.9.] #64

Closed: rk1a closed this issue 2 years ago

lauritowal commented 2 years ago

@rk1a can you estimate a deadline please? :) Would help with knowing if we are on time

rk1a commented 2 years ago

If we want to use Stable Baselines 3, we need to implement a custom policy that we can pass to the algorithm (https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html). I looked into it a little bit, but it seems non-trivial because you need to provide both an actor and a critic network. I guess those are already present in the VPT model, but I couldn't check this yet.

In the VPT paper they used Phasic Policy Gradient (https://github.com/openai/phasic-policy-gradient) for the RL finetuning, which is not available in Stable Baselines 3 yet.

lauritowal commented 2 years ago

@mschweizer is helping here

rk1a commented 2 years ago

    from typing import Callable, Dict, List, Optional, Tuple, Type, Union

    import gym
    import torch as th
    from torch import nn

    from stable_baselines3 import PPO
    from stable_baselines3.common.policies import ActorCriticPolicy


    class CustomNetwork(nn.Module):
        """
        Custom network for policy and value function.
        It receives as input the features extracted by the feature extractor.

        :param feature_dim: dimension of the features extracted with the features_extractor (e.g. features from a CNN)
        :param last_layer_dim_pi: (int) number of units for the last layer of the policy network
        :param last_layer_dim_vf: (int) number of units for the last layer of the value network
        """

        def __init__(
            self,
            feature_dim: int,
            last_layer_dim_pi: int = 64,
            last_layer_dim_vf: int = 64,
        ):
            super(CustomNetwork, self).__init__()

            # IMPORTANT:
            # Save output dimensions, used to create the distributions
            self.latent_dim_pi = last_layer_dim_pi
            self.latent_dim_vf = last_layer_dim_vf

            # Policy network
            self.policy_net = nn.Sequential(
                nn.Linear(feature_dim, last_layer_dim_pi), nn.ReLU()
            )
            # Value network
            self.value_net = nn.Sequential(
                nn.Linear(feature_dim, last_layer_dim_vf), nn.ReLU()
            )

        def forward(self, features: th.Tensor) -> Tuple[th.Tensor, th.Tensor]:
            """
            :return: (th.Tensor, th.Tensor) latent_policy, latent_value of the specified network.
                If all layers are shared, then latent_policy == latent_value
            """
            return self.policy_net(features), self.value_net(features)

        def forward_actor(self, features: th.Tensor) -> th.Tensor:
            return self.policy_net(features)

        def forward_critic(self, features: th.Tensor) -> th.Tensor:
            return self.value_net(features)


    class CustomActorCriticPolicy(ActorCriticPolicy):
        def __init__(
            self,
            observation_space: gym.spaces.Space,
            action_space: gym.spaces.Space,
            lr_schedule: Callable[[float], float],
            net_arch: Optional[List[Union[int, Dict[str, List[int]]]]] = None,
            activation_fn: Type[nn.Module] = nn.Tanh,
            *args,
            **kwargs,
        ):
            super(CustomActorCriticPolicy, self).__init__(
                observation_space,
                action_space,
                lr_schedule,
                net_arch,
                activation_fn,
                # Pass remaining arguments to base class
                *args,
                **kwargs,
            )
            # Disable orthogonal initialization
            self.ortho_init = False

        def _build_mlp_extractor(self) -> None:
            self.mlp_extractor = CustomNetwork(self.features_dim)
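
For completeness, the SB3 docs then plug such a policy class directly into PPO, roughly like this (a minimal sketch on the docs' toy CartPole env, not our MineRL setup):

    # Toy usage from the SB3 custom-policy guide; CartPole is just a placeholder env.
    model = PPO(CustomActorCriticPolicy, "CartPole-v1", verbose=1)
    model.learn(total_timesteps=5000)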


- What parts of the architecture are we finetuning? (see behavioral_cloning.py)
- Is the architecture of the VPT policy model compatible with off-policy algorithms? Which ones?
lauritowal commented 2 years ago

@mschweizer @rk1a thanks! Keep in mind that we might want to use an off-policy algorithm to apply the PEBBLE method to BASALT, since you've mentioned the on-policy algorithm PPO above. We can, and probably should, try it with PPO first, since that might make the process easier.

EDIT: Alright, just saw the last comment:

is the architecture of the VPT policy model compatible with off-policy algorithms? Which ones?

lauritowal commented 2 years ago

@rk1a

...compatible with off-policy algorithms? Which ones?

Have a look at the algorithm used in the original PEBBLE paper? It might be SAC, as far as I remember.

rk1a commented 2 years ago

@lauritowal @mschweizer Yes, PEBBLE used SAC. The architecture should be compatible with both SAC and DDPG, which are both off-policy actor-critic methods. These algorithms are typically used for problems with purely continuous action spaces (in PEBBLE they evaluated on MuJoCo and the DeepMind Control Suite). The BASALT env action space is factored into a gym.spaces.Dict, and out of the box the Stable Baselines 3 implementations do not support Dict action spaces. So, in order to use off-policy algorithms, it seems we need to (1) learn how to use SAC/DDPG with factored action spaces and (2) create a new version of the Stable Baselines 3 implementations that can handle them.
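
As a quick illustration of option (1), gym itself ships flatten utilities for composite spaces, which at least shows what a flattened version of a factored action space looks like (a sketch using a toy stand-in for the BASALT space, which is pasted further down in this thread; nothing here is wired into SB3 yet):

    from gym import spaces
    from gym.spaces.utils import flatten, flatten_space

    # Toy stand-in for the factored BASALT action space.
    factored = spaces.Dict({
        "attack": spaces.Discrete(2),
        "jump": spaces.Discrete(2),
        "camera": spaces.Box(low=-180.0, high=180.0, shape=(2,)),
    })

    # flatten_space turns the Dict into a single Box (Discrete entries become one-hot).
    flat_space = flatten_space(factored)
    flat_action = flatten(factored, factored.sample())
    print(flat_space, flat_action.shape)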

rk1a commented 2 years ago

RE: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/issues/64#issuecomment-1255210887

Defining a wrapper for off-policy algorithms like SAC seems substantially harder than for on-policy ones like PPO: https://stable-baselines3.readthedocs.io/en/master/guide/custom_policy.html#off-policy-algorithms

One way would be to register a custom policy, as done here: https://github.com/DLR-RM/rl-baselines3-zoo/blob/feat/densemlp/utils/networks.py. But the SB3 authors warn that good knowledge of the algorithm is required, e.g. for SAC the output of the actor must go through a Tanh activation.
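
For the simpler case where we only need different actor/critic sizes (rather than reusing VPT weights), the same SB3 page shows that off-policy algorithms accept a net_arch dict via policy_kwargs. A minimal sketch on a placeholder continuous-control env, not BASALT:

    from stable_baselines3 import SAC

    # Separate layer sizes for the actor (pi) and the Q-networks (qf);
    # Pendulum-v1 is only a stand-in continuous-action env.
    policy_kwargs = dict(net_arch=dict(pi=[256, 256], qf=[256, 256]))
    model = SAC("MlpPolicy", "Pendulum-v1", policy_kwargs=policy_kwargs, verbose=1)
    model.learn(total_timesteps=1000)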

@lauritowal @mschweizer are you familiar enough with any of the off-policy algorithms that you would feel comfortable defining the custom actor/critic? In any case I think it would be a good idea to first try writing the wrapper for PPO using the template above.

lauritowal commented 2 years ago

In any case I think it would be a good idea to first try writing the wrapper for PPO using the template above.

I agree.

are you familiar enough with any of the off-policy algorithms that you would feel comfortable defining the custom actor/critic?

I could have a look at it after having tried PPO

lauritowal commented 2 years ago

@rk1a could you create a simple example, where one of our models is updated via PPO? A notebook should be enough


rk1a commented 2 years ago

@lauritowal @mschweizer I have started writing the wrapper: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/blob/RL_policy_finetuning/sb3_policy_wrapper.py. There are still some functions to be adapted, and I am not entirely sure whether we can make it compatible without also changing the PPO implementation itself. The current problem is that, as expected, the Dict action space is not supported by PPO.

lauritowal commented 2 years ago

The current problem is that, as expected, the Dict action space is not supported by PPO.

I flattened a dict observation to a simple array for a project once. That could probably be done for the action space too(?). Could you copy & paste the dict action space here, please? @rk1a

rk1a commented 2 years ago

@lauritowal This is the action space type:

    Dict(
        ESC:Discrete(2), attack:Discrete(2), back:Discrete(2),
        camera:Box(low=-180.0, high=180.0, shape=(2,)),
        drop:Discrete(2), forward:Discrete(2),
        hotbar.1:Discrete(2), hotbar.2:Discrete(2), hotbar.3:Discrete(2),
        hotbar.4:Discrete(2), hotbar.5:Discrete(2), hotbar.6:Discrete(2),
        hotbar.7:Discrete(2), hotbar.8:Discrete(2), hotbar.9:Discrete(2),
        inventory:Discrete(2), jump:Discrete(2), left:Discrete(2),
        pickItem:Discrete(2), right:Discrete(2), sneak:Discrete(2),
        sprint:Discrete(2), swapHands:Discrete(2), use:Discrete(2)
    )

This is an example action:

    OrderedDict([
        ('ESC', array(0)), ('attack', array(0)), ('back', array(1)),
        ('camera', array([-148.65605, 111.41331], dtype=float32)),
        ('drop', array(0)), ('forward', array(1)),
        ('hotbar.1', array(0)), ('hotbar.2', array(0)), ('hotbar.3', array(1)),
        ('hotbar.4', array(1)), ('hotbar.5', array(1)), ('hotbar.6', array(1)),
        ('hotbar.7', array(1)), ('hotbar.8', array(1)), ('hotbar.9', array(1)),
        ('inventory', array(0)), ('jump', array(1)), ('left', array(0)),
        ('pickItem', array(0)), ('right', array(0)), ('sneak', array(1)),
        ('sprint', array(1)), ('swapHands', array(0)), ('use', array(1))
    ])

I also think we should just try the flattened array approach. There even exists a converter method in MineRLAgent to do that.
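
To make the flattening idea concrete, here is a rough sketch of a gym ActionWrapper that exposes the Dict as a single flat Box (purely illustrative: it thresholds the binary buttons at 0.5 and is not the converter in MineRLAgent):

    import gym
    import numpy as np
    from gym import spaces


    class FlattenedActionWrapper(gym.ActionWrapper):
        """Sketch: expose the factored Dict action space as one flat Box.

        Binary buttons become thresholded entries, the two camera angles
        stay continuous. Illustrative only, not the MineRLAgent converter.
        """

        def __init__(self, env):
            super().__init__(env)
            # All Discrete(2) buttons in a fixed order; "camera" is the only Box.
            self.button_keys = [
                key for key, space in env.action_space.spaces.items()
                if isinstance(space, spaces.Discrete)
            ]
            n_buttons = len(self.button_keys)
            low = np.concatenate([np.zeros(n_buttons), [-180.0, -180.0]])
            high = np.concatenate([np.ones(n_buttons), [180.0, 180.0]])
            self.action_space = spaces.Box(
                low=low.astype(np.float32), high=high.astype(np.float32)
            )

        def action(self, flat_action):
            # Convert the flat vector back into the Dict the env expects.
            n_buttons = len(self.button_keys)
            buttons = flat_action[:n_buttons]
            camera = np.asarray(flat_action[n_buttons:], dtype=np.float32)
            dict_action = {
                key: int(value > 0.5)
                for key, value in zip(self.button_keys, buttons)
            }
            dict_action["camera"] = camera
            return dict_action

Wrapping the env like env = FlattenedActionWrapper(minerl_env) would then let PPO's Box support handle the action head, at the cost of treating the buttons as continuous values that get thresholded.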

rk1a commented 2 years ago

@lauritowal @mschweizer The SB3 and Gym wrappers are now functional. You can find an example of how to use them here: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/blob/RL_policy_finetuning/sb3_policy_wrapper.py#L153. I did all testing on CPU, so there might be minor issues when switching to GPU.

We can now wrap our baseline policies as SB3-compatible policies (for on-policy algorithms). The next step would be to integrate the wrapped policies and MineRL envs with the imitation library, as outlined here: https://github.com/BASALT-2022-Karlsruhe/ka-basalt-2022/issues/32#issuecomment-1248072360. @gekaklam, in this step we also need an interface between the web app for preference collection and the PrefRL algorithm.
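
Independent of the exact imitation API, one piece of that glue could be a thin wrapper that replaces the environment reward with whatever reward model the preference-learning loop produces. A rough sketch; the reward_model callable and its (obs, action) -> float signature are hypothetical, and it assumes the old gym step API:

    import gym


    class LearnedRewardWrapper(gym.Wrapper):
        """Sketch: substitute the env reward with a learned (preference-based) reward."""

        def __init__(self, env, reward_model):
            super().__init__(env)
            # reward_model: hypothetical callable mapping (obs, action) -> float,
            # e.g. a network trained from human preferences.
            self.reward_model = reward_model
            self._last_obs = None

        def reset(self, **kwargs):
            self._last_obs = self.env.reset(**kwargs)
            return self._last_obs

        def step(self, action):
            # Old gym API: (obs, reward, done, info); the env reward is discarded.
            obs, _, done, info = self.env.step(action)
            reward = float(self.reward_model(self._last_obs, action))
            self._last_obs = obs
            return obs, reward, done, info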

lauritowal commented 2 years ago

@rk1a awesome work !! Thanks