HumanCompatibleAI / imitation

Clean PyTorch implementations of imitation and reward learning algorithms
https://imitation.readthedocs.io/
MIT License

RewardNet refactor #464

Open levmckinney opened 2 years ago

levmckinney commented 2 years ago

RewardNet Refactor

This issue proposes shrinking the RewardNet class to make it easier to extend and maintain. RewardNet is the base class for all PyTorch reward networks and reward network wrappers. It is currently responsible for outputting rewards with gradients for training the reward model, providing processed rewards without gradients for RL training, and providing rewards at evaluation time via RewardFn.

Why is a refactor necessary?

The current design is a bit clunky, with forward, predict and predict_processed all having subtly different behavior. In particular, predict_processed taking different arguments in different subclasses breaks polymorphism and is evidence that the current design does not do enough composition.
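For reference, the current surface looks roughly like this (a simplified paraphrase, not the exact signatures; the extra keyword arguments to predict_processed are the part that varies by subclass):

import numpy as np
import torch as th
from torch import nn

class RewardNet(nn.Module):
    def forward(self, state, action, next_state, done) -> th.Tensor:
        """Reward with gradients, used when training the reward model."""
    def predict(self, state, action, next_state, done) -> np.ndarray:
        """Reward with no gradients, e.g. for evaluation."""
    def predict_processed(self, state, action, next_state, done, **kwargs) -> np.ndarray:
        """Processed reward handed to the RL learner; **kwargs differ per subclass."""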

In addition, the current abstract class's __init__ is bloated with arguments that only apply to image-based environments. It should be smaller and apply equally well to tabular settings, image-based settings and continuous control. This is already causing problems in downstream repos: https://github.com/HumanCompatibleAI/reward-function-interpretability/blob/891ade369ae24d845a1f71933a585b885f78188a/src/reward_preprocessing/tabular_reward_net.py#L29-L31

Refactoring RewardNet would also be an opportunity to solve #427 by making the reward net interface smaller and more composable. We might be able to move some responsibilities of the reward network out to RewardFn.

This would not be a small refactor; RewardNets are core to imitation. However, I believe it is worth at least discussing possible alternative designs before imitation goes to a 1.0 release.

Possible new design

from typing import Mapping
from torch import Tensor, nn

class RewardNet(nn.Module):
    # Proposal: forward returns a dict of named outputs, e.g. {"reward": ...}.
    def forward(self, state, action, next_state, done) -> Mapping[str, Tensor]:
        ...

class RewardNetWrapper(RewardNet):
    # Wrappers call the wrapped net and add keys to its output dict.
    def forward(self, state, action, next_state, done) -> Mapping[str, Tensor]:
        ...
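To make the wrapper contract concrete, here is a hypothetical shaping wrapper built on the sketch above (the class, the potential module and the key names are made up for illustration): it keeps the base network's keys and appends its own.

from typing import Mapping
from torch import Tensor, nn

class ShapedWrapper(RewardNetWrapper):
    """Hypothetical illustration: potential-based shaping over a base RewardNet."""
    def __init__(self, base: RewardNet, potential: nn.Module, gamma: float = 0.99):
        super().__init__()
        self.base, self.potential, self.gamma = base, potential, gamma

    def forward(self, state, action, next_state, done) -> Mapping[str, Tensor]:
        out = dict(self.base(state, action, next_state, done))  # keep base keys, e.g. "reward"
        shaping = (
            self.gamma * (1 - done.float()) * self.potential(next_state).squeeze(-1)
            - self.potential(state).squeeze(-1)
        )
        out["reward_shaped"] = out["reward"] + shaping
        return out

A consumer could then read out["reward"] for evaluation and out["reward_shaped"] for training, without inspecting the wrapper stack.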
AdamGleave commented 2 years ago

Thanks for writing this up, Lev!

Modes

We were a bit hesitant about relying on modes previously, as it seemed less explicit than "call the right function" and therefore likely to be a source of subtle bugs. However, if we centralize the responsibility for switching modes in the RewardSampler, then this may largely solve that problem.

The current modes seem a bit inflexible. In particular, do we always want the choice of whether dropout is on (effectively the train vs eval distinction in PyTorch) to depend on whether the data is on-policy? What if we train the reward model only on on-policy data, but still want dropout on? I might just make these two distinct modes (we can always have convenience functions that set both).

If we do use modes, let's always use context managers for them... with side effects, the code quickly becomes difficult to reason about.
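For instance (everything here is hypothetical, not an existing imitation API), a context manager keeps the mode switch scoped and automatically undone:

from contextlib import contextmanager

@contextmanager
def reward_mode(net: RewardNet, *, train: bool, on_policy: bool):
    # Hypothetical: set the PyTorch train/eval state and an on-policy flag,
    # restoring both on exit so no caller is left in a surprising mode.
    prev_train, prev_on_policy = net.training, getattr(net, "on_policy", False)
    net.train(train)
    net.on_policy = on_policy
    try:
        yield net
    finally:
        net.train(prev_train)
        net.on_policy = prev_on_policy

# e.g. dropout on even though the data is on-policy:
# with reward_mode(net, train=True, on_policy=True):
#     out = net(state, action, next_state, done)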

Dictionary

I'm a bit fuzzy on how this would work. If I add shaping to a reward network, how should the algorithm know to pay attention to "reward_shaped" and not "reward" (or whatever the original key was)? Perhaps we could have a designated key, "result", that can be overridden, while everything else is append-only? I'm still not sure how this would work if we applied the same wrapper twice: I guess we could add indices, "reward_shaped_0", "reward_shaped_1"?

Looking later, it seems the intention is to extract the relevant key using the RewardSampler. I guess that could work, but many algorithms would need multiple reward samplers, so we'd have to pass the key into the algorithm?

In addition, you would always know what tensor you were getting without having to check the type of the reward network.

I don't understand this -- isn't the tensor shape and dtype the same in all cases? Why should the callee care what the type of the network is?

RewardSampler

I generally like this idea, although I do worry the indirection will make it less readable. I'm not sure this is actually better than having concrete implementations in the base class that do the mode futzing and call forward?
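To make the comparison concrete, a minimal sketch of both options (names hypothetical): a RewardSampler that owns the key selection and the mode futzing, versus a concrete method on the base class doing the same work with one less object to pass around.

import numpy as np
import torch as th

class RewardSampler:
    """Hypothetical: algorithms call sample() and never touch modes or keys."""
    def __init__(self, net: RewardNet, key: str = "reward"):
        self.net, self.key = net, key

    def sample(self, state, action, next_state, done) -> np.ndarray:
        was_training = self.net.training
        self.net.eval()
        with th.no_grad():
            out = self.net(state, action, next_state, done)[self.key]
        self.net.train(was_training)
        return out.detach().cpu().numpy()

# Alternative: a concrete RewardNet.predict(...) in the base class with the same
# body, so callers get the convenience without the extra indirection.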

levmckinney commented 2 years ago

I don't understand this -- isn't the tensor shape and dtype the same in all cases? Why should the callee care what the type of the network is?

Sorry, this was not clear. Let me give an example to clarify.

Let's say I want to get the unshaped reward from a reward network. At the moment, I have to check whether the net is an instance of a shaped reward net wrapper and call forward on net.base or net depending on the result. With the new system, I would just use the 'reward' key rather than the 'reward_shaped' key. The user would know what they were getting without having to check the type of the reward network and its wrappers.

With the current implementation, getting intermediate results will become increasingly difficult as we add new wrappers.
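In code, the difference is roughly this (wrapper and key names hypothetical):

def unshaped_reward_current(net, state, action, next_state, done):
    # Current approach: inspect the type and unwrap before calling forward.
    base = net.base if isinstance(net, ShapedRewardNetWrapper) else net
    return base(state, action, next_state, done)

def unshaped_reward_dict(net, state, action, next_state, done):
    # Dict approach: the unshaped reward is always under the same key,
    # regardless of how many wrappers sit on top.
    return net(state, action, next_state, done)["reward"]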

AdamGleave commented 2 years ago

Let's say I want to get the unshaped reward from a reward network. At the moment, I have to check whether the net is an instance of a shaped reward net wrapper and call forward on net.base or net depending on the result. With the new system, I would just use the 'reward' key rather than the 'reward_shaped' key. The user would know what they were getting without having to check the type of the reward network and its wrappers.

With the current implementation, getting intermediate results will become increasingly difficult as we add new wrappers.

Thanks for the clarification, I agree the current method of unwrapping to find a non-shaped base is hacky.

I think it's worth making concrete how this would actually work with the dict method. I believe that, in your convention, "reward" corresponds to the base reward model before any wrappers. This isn't that hard to do with the current approach:

# Unwrap nested wrappers until we reach the base reward model.
while hasattr(reward, 'base'):
    reward = reward.base

The tricky thing is that we often don't want the base-most reward model, but the first one that's not shaped. It might not be just "reward" and "reward_shaped", but "reward_affine_shaped_1". Do we want "reward", "reward_affine", "reward_affine_shaped_0", ...? If I understand correctly, the dictionary method effectively punts this problem to the user, which works well when it's a simple, statically configured network -- but not if the network architecture depends on the config, so there is no single key that is consistently the right one.

I think the dictionary approach is still, on balance, a bit easier to use if we often need to extract intermediate results, but right now we're only doing this in a single part of the codebase (AIRL.reward_test). Do you have other concrete use cases in mind for it? That's probably the crux for me of whether I want to move to a dictionary approach.

levmckinney commented 2 years ago

The only other concrete use case I can think of at the moment is reward ensembles. During training we might want to detect whether we are using an ensemble and compute the loss for each member rather than on the mean of all members. At the moment in #460 I have to check the type of the reward net to see whether it is an ensemble. This is not very robust, since it will not work if the ensemble is wrapped by a shaping wrapper or a normalization wrapper. In the end, I will probably have to do the same thing as in AIRL.reward_test and strip off all the wrappers. However, with the dictionary approach I could just check whether something like an "ensemble_outputs" key is present and change the loss computation accordingly.
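For example (key and variable names hypothetical, and the loss is just a placeholder), the training step could branch on the output dict instead of on types:

import torch.nn.functional as F

out = net(state, action, next_state, done)
if "ensemble_outputs" in out:
    # Per-member predictions of shape (n_members, batch): average the loss over
    # members and batch; wrappers that preserve the key do not break this.
    members = out["ensemble_outputs"]
    loss = F.mse_loss(members, target_reward.expand_as(members))
else:
    loss = F.mse_loss(out["reward"], target_reward)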

levmckinney commented 2 years ago

It would also remove the need for most of the code in serialize.py, since we would not need to treat reward functions differently based on which wrappers they have. You could check whether the network has the required functionality by dry-running some data through it and seeing whether it returns the right keys, then simply return a reward evaluator.
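A rough sketch of that check (helper name and details hypothetical; it ignores imitation's observation/action preprocessing):

import torch as th

def provides_keys(net: RewardNet, keys, venv) -> bool:
    # Dry-run a single dummy transition through the network and check that the
    # requested keys come back in the output dict.
    obs = th.as_tensor(venv.observation_space.sample()).unsqueeze(0).float()
    act = th.as_tensor(venv.action_space.sample()).unsqueeze(0).float()
    done = th.zeros(1)
    with th.no_grad():
        out = net(obs, act, obs, done)
    return all(k in out for k in keys)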

AdamGleave commented 2 years ago

The only other concrete use case I can think of at the moment is reward ensembles. During training we might want to detect whether we are using an ensemble and compute the loss for each member rather than on the mean of all members. At the moment in #460 I have to check the type of the reward net to see whether it is an ensemble. This is not very robust, since it will not work if the ensemble is wrapped by a shaping wrapper or a normalization wrapper. In the end, I will probably have to do the same thing as in AIRL.reward_test and strip off all the wrappers. However, with the dictionary approach I could just check whether something like an "ensemble_outputs" key is present and change the loss computation accordingly.

The ensemble is a good use case. It's sad the training algorithm needs to care at all whether or not it's an ensemble model, but I guess that is more or less unavoidable :(

How should the ensemble + shaping work conceptually? I'd think we'd want to shape each ensemble member separately, and then compute the loss of member[i]+shaping[i] separately.
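To spell out what I mean (a hypothetical sketch, not existing code):

def per_member_shaped_rewards(members, potentials, gamma, state, action, next_state, done):
    # Hypothetical: shape each ensemble member with its own potential, so the loss
    # can be computed on member_reward[i] + shaping[i] individually.
    return [
        member(state, action, next_state, done)["reward"]
        + gamma * (1 - done.float()) * potential(next_state).squeeze(-1)
        - potential(state).squeeze(-1)
        for member, potential in zip(members, potentials)
    ]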

levmckinney commented 2 years ago

How should the ensemble + shaping work conceptually? I'd think we'd want to shape each ensemble member separately, and then compute the loss of member[i]+shaping[i] separately.

Good point. I don't think there is really a case where we will be wrapping an ensemble with a shaping function.