alex-petrenko / sample-factory

High throughput synchronous and asynchronous reinforcement learning
https://samplefactory.dev
MIT License

State-action value function #274

Open paLeziart opened 1 year ago

paLeziart commented 1 year ago

Hello,

I had a quick question about the form of the value function. Right now, by default, it is an action value function with a linear layer that receives the output of the decoder. I was wondering if it would be possible to get a state-action value function without undesired side effects. I tried an implementation with a custom model and it seems to run fine, but I'm not sure what is happening under the hood with the subclasses, so I was wondering whether it could cause issues I have not seen yet.

I tried something as follows. Most of CustomActorCriticSharedWeights is a copy-paste of ActorCriticSharedWeights, with minor changes. The network is a simple two-layer MLP with a hidden size of 128. The gist of it is that the core now also outputs its own input in its forward function, torch.cat((self.core_head(head_output), head_output), dim=1), i.e. the state and action concatenated. That way the value function receives the state-action pair, and I give only the action part to the action parametrization, self.action_parameterization(decoder_output[:, : -self.n_encoder_out]).

Do you think it is the proper way to do it? Could it have any ill-effect? Thanks for the help!

import torch
from torch import nn

# Note: import paths below assume the sample-factory 2.x module layout; adjust if your version differs.
from sample_factory.algo.utils.tensor_dict import TensorDict
from sample_factory.model.actor_critic import ActorCriticSharedWeights
from sample_factory.model.core import ModelCore
from sample_factory.model.model_utils import nonlinearity
from sample_factory.utils.typing import ActionSpace, Config, ObsSpace

class CustomCoreWithState(ModelCore):
    def __init__(self, cfg: Config, input_size: int):
        super().__init__(cfg)
        # build custom core architecture
        n_actions = 6
        self.core_output_size = n_actions + input_size
        self.core_head = nn.Sequential(
            nn.Linear(input_size, 128),
            nonlinearity(cfg),
            nn.Linear(128, n_actions),
            nonlinearity(cfg),
        )

    def forward(self, head_output, fake_rnn_states):
        # custom forward logic
        ### Here I cat the state (encoder output) with the actions (core output) ###
        return (
            torch.cat((self.core_head(head_output), head_output), dim=1),
            fake_rnn_states,
        )

class CustomActorCriticSharedWeights(ActorCriticSharedWeights):
    def __init__(
        self,
        model_factory,
        obs_space: ObsSpace,
        action_space: ActionSpace,
        cfg: Config,
    ):
        super().__init__(model_factory, obs_space, action_space, cfg)

        self.n_encoder_out = self.encoder.get_out_size()
        self.core = CustomCoreWithState(cfg, self.n_encoder_out)

        self.decoder = model_factory.make_model_decoder_func(
            cfg, self.core.get_out_size()
        )

        decoder_out_size: int = self.decoder.get_out_size()
        self.critic_linear = nn.Linear(decoder_out_size, 1)
        self.action_parameterization = self.get_action_parameterization(
            decoder_out_size - self.n_encoder_out
        )
        self.apply(self.initialize_weights)

    def forward_tail(
        self,
        core_output,
        values_only: bool,
        sample_actions: bool,  # encoder_output: Tensor,
    ) -> TensorDict:
        decoder_output = self.decoder(core_output)
        values = self.critic_linear(decoder_output).squeeze()
        result = TensorDict(values=values)
        if values_only:
            return result

        ### Here I split decoder output to only give the action part to the action parametrization ###
        (
            action_distribution_params,
            self.last_action_distribution,
        ) = self.action_parameterization(decoder_output[:, : -self.n_encoder_out])
        result["action_logits"] = action_distribution_params

        self._maybe_sample_actions(sample_actions, result)
        return result

For reference, here's the network summary printed to the terminal. The state is of size 25 and the action of size 6.

[2023-06-30 10:00:32,475][51057] CustomActorCriticSharedWeights(
  (obs_normalizer): ObservationNormalizer(
    (running_mean_std): RunningMeanStdDictInPlace(
      (running_mean_std): ModuleDict(
        (obs): RunningMeanStdInPlace()
      )
    )
  )
  (returns_normalizer): RecursiveScriptModule(original_name=RunningMeanStdInPlace)
  (encoder): MultiInputEncoder(
    (encoders): ModuleDict(
      (obs): MlpEncoder(
        (mlp_head): Identity()
      )
    )
  )
  (core): CustomCoreWithState(
    (core_head): Sequential(
      (0): Linear(in_features=25, out_features=128, bias=True)
      (1): ReLU()
      (2): Linear(in_features=128, out_features=6, bias=True)
      (3): ReLU()
    )
  )
  (decoder): MlpDecoder(
    (mlp): Identity()
  )
  (critic_linear): Linear(in_features=31, out_features=1, bias=True)
  (action_parameterization): ActionParameterizationContinuousNonAdaptiveStddev(
    (distribution_linear): Linear(in_features=6, out_features=6, bias=True)
  )
)
alex-petrenko commented 12 months ago

Hi @paLeziart ! First of all, to avoid any confusion, I usually refer to the value function in actor-critic methods as just the "value function", or "state value function". The "state-action value function", also known as the Q-function or Q(s,a), is typically found in off-policy Q-learning methods like DQN.
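
For reference, the textbook definitions (for a policy π and discount factor γ; nothing sample-factory-specific here):

V^\pi(s) = \mathbb{E}_\pi\left[ \sum_{t \ge 0} \gamma^t r_t \mid s_0 = s \right]
Q^\pi(s, a) = \mathbb{E}_\pi\left[ \sum_{t \ge 0} \gamma^t r_t \mid s_0 = s, a_0 = a \right]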

Could you describe what you're trying to accomplish with this?

Is this so that gradients of the critic affect the actor encoder, but not the other way around? I.e. you want separate-weights actor critic, but not entirely separate? Sorry, I'm not quite sure what this architecture is doing.

Also, just in case, there's a parameter --actor_critic_share_weights=False that enables a totally separate architecture for the actor & critic, but I'm sure you're aware of it!

paLeziart commented 11 months ago

Hi @alex-petrenko ! Thank you for your answer!

Sorry, the confusion might be on my side. From what I understand, the goal of the critic is to estimate the value function, which is the advantage function in the case of PPO, i.e. how "good" an action is depending on the current state. Then, considering that the architecture of the network is state -> encoder -> core -> decoder -> actions -> critic_linear -> scalar value (based on the input/output dimensions), I was wondering how the single critic_linear layer was supposed to do that when it only has access to the output actions, without the state or encoder information. So what I was trying to do was to pass the state along with the actions to that last linear layer so that it can estimate the advantage of an action depending on the state.


I guess I don't fully understand the behaviour of that linear critic and how it is connected to the rest.

I am aware of the actor_critic_share_weights parameter, your documentation is quite nice! :) In the case of actor_critic_share_weights=False, my question would be reversed: now the critic has a full encoder + core + decoder + one linear layer to estimate the advantage based on the input state, but it no longer has access to the actions (since the weights are not shared, its decoder no longer outputs the actions sent to the robot; only the actor's decoder does).

Sorry to bother you; I guess it's a simple thing I am not seeing and the critic can actually estimate the advantage properly, otherwise this whole thing would not work for the provided examples.

alex-petrenko commented 11 months ago

From what I understand the goal of the critic is to estimate the value function, which is the advantage function in the case of the PPO, i.e. how "good" an action is depending on the current state

Not exactly, "value function" is the estimate of discounted return from current moment to the end of the episode. "Advantage" is the difference between the real value that occurred (i.e. sum of all rewads) and this predicted return. I.e. positive advantage == we did better than we expected.

state -> encoder -> core -> decoder -> actions -> critic_linear -> scalar value (based on the input/output dimensions), I was wondering how the single critic_linear layer was supposed to do so when it had only access to the output actions, without the state or encoder information. So what I was trying to do was to pass the state along with the actions to that last linear layer so that it can estimate the advantage of an action depending on the state.

I think you misinterpreted some of the code.

The architecture should be the following. For the shared-weights actor-critic (--actor_critic_share_weights=True): obs -> encoder -> core -> decoder -> value | action logits (simultaneously produced from the decoder output by two separate linear layers)

For the separate actor-critic (--actor_critic_share_weights=False):

obs -> (actor) encoder -> (actor) core -> (actor) decoder -> action logits
obs -> (critic) encoder -> (critic) core -> (critic) decoder -> value
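
To make the shared-weights case concrete, here is a stripped-down sketch (illustrative only, not sample-factory's actual classes; the layer names and sizes here are made up):

import torch
from torch import nn

class SharedWeightsSketch(nn.Module):
    """Illustrative only: one shared trunk, two heads reading the same features."""

    def __init__(self, obs_dim: int, hidden_dim: int, num_action_logits: int):
        super().__init__()
        # stands in for encoder -> core -> decoder
        self.trunk = nn.Sequential(
            nn.Linear(obs_dim, hidden_dim), nn.ReLU(),
            nn.Linear(hidden_dim, hidden_dim), nn.ReLU(),
        )
        self.critic_linear = nn.Linear(hidden_dim, 1)                # value head: V(s)
        self.action_head = nn.Linear(hidden_dim, num_action_logits)  # policy head

    def forward(self, obs: torch.Tensor):
        features = self.trunk(obs)
        value = self.critic_linear(features).squeeze(-1)  # no actions involved
        action_logits = self.action_head(features)        # produced from the same features
        return action_logits, value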

You are right to note that in neither scheme does the critic have access to the actions. This is by design. The critic represents the state-based value function (aka the "value function", or V in the literature); this is not the Q(s,a) function found in some other algorithms.

I don't have a good intuition about what happens if you add the currently sampled actions to the state. This is just not something that's normally done. It is common to add the actions from the previous timestep (you can easily do it on the env side, e.g. by appending the previous action to the observation, as sketched below), but it is not common to add actions from the current timestep.
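
For example, a minimal sketch of such a wrapper (assuming a gymnasium env with 1-D Box observation and action spaces; the class and variable names here are made up, and this is not part of sample-factory):

import numpy as np
import gymnasium as gym

class PrevActionObsWrapper(gym.Wrapper):
    """Appends the previous action to each observation (illustrative sketch)."""

    def __init__(self, env):
        super().__init__(env)
        self._act_dim = env.action_space.shape[0]
        low = np.concatenate([env.observation_space.low, env.action_space.low]).astype(np.float32)
        high = np.concatenate([env.observation_space.high, env.action_space.high]).astype(np.float32)
        self.observation_space = gym.spaces.Box(low=low, high=high, dtype=np.float32)

    def reset(self, **kwargs):
        obs, info = self.env.reset(**kwargs)
        zero_action = np.zeros(self._act_dim, dtype=np.float32)
        return np.concatenate([obs, zero_action]).astype(np.float32), info

    def step(self, action):
        obs, reward, terminated, truncated, info = self.env.step(action)
        obs = np.concatenate([obs, np.asarray(action, dtype=np.float32)]).astype(np.float32)
        return obs, reward, terminated, truncated, info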

The mathematical formulation requires the value function to predict the value given the environment state, and the actions taken by the agent are not part of this state. Unless I'm missing something, you probably don't even want that.

If you actually want to do that (which I think you do not), you would have to add the current actions to the state in a rather hacky way. You'd first need to sample the actions during the forward pass (when your agent decides what to do in the environment), and then during training you would need to add the same actions to the state to properly estimate the gradients. Again, I would not advise you to do that.

Sorry to bother, I guess it's a simple thing I am not seeing and the critic can actually properly estimate the advantage, otherwise this whole thing would not work for the provided examples.

The advantages are calculated here given the value estimate and bootstrapped discounted return: https://github.com/alex-petrenko/sample-factory/blob/7e1e69550f4de4cdc003d8db5bb39e186803aee9/sample_factory/algo/learning/learner.py#L986
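
In spirit (simplified; the actual code uses GAE, bootstrapping for truncated episodes, return normalization, etc.), it boils down to something like:

import torch

def simple_advantages(rewards, values, dones, gamma=0.99):
    """Illustrative only: discounted returns minus the critic's value estimates."""
    returns = torch.zeros_like(rewards)
    running_return = torch.zeros_like(rewards[0])
    for t in reversed(range(rewards.shape[0])):
        running_return = rewards[t] + gamma * (1.0 - dones[t]) * running_return
        returns[t] = running_return
    return returns - values  # positive => the rollout went better than the critic predicted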

As you can see, no need to add actions to state.

Let me know if this is a satisfactory explanation. There is a good chance I didn't fully understand your idea. Good luck!

paLeziart commented 10 months ago

@alex-petrenko Thank you for your detailed answer! I was indeed a bit confused about the shared/not-shared pipelines and about what the critic is trying to learn (V(s) instead of Q(s,a)), but with your explanation and a bit of reading on actor-critic methods and PPO theory it is now clearer.

alex-petrenko commented 10 months ago

Glad I could help!