Open levmckinney opened 2 years ago
Thanks for writing this up, Lev!
We were a bit hesitant about relying on modes previously as it seemed less explicit than "call the right function", and so likely to be a source of subtle bugs. However, if we centralize the responsibility for switching modes in the RewardSampler, then this may largely solve that problem.
The current modes seem a bit inflexible. In particular, do we always want the choice of whether or not dropout is on (effectively the train vs eval distinction in PyTorch) to depend on whether or not the data is on-policy? What if we train the reward model only on on-policy data, but still want dropout on? I might just make these two distinct modes (we can always have convenience functions that set both).
If we do use modes, let's always use context managers for them... with side effects the code becomes difficult to reason about quickly.
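A minimal sketch of what the context-manager approach could look like. The `dropout_enabled` and `on_policy` flags, and the `mode` method, are hypothetical names invented for illustration; the point is that the old state is always restored on exit, so the side effects stay local to the `with` block.

```python
from contextlib import contextmanager


class RewardNet:
    """Toy stand-in for the real RewardNet; only the mode flags matter here."""

    def __init__(self):
        # Hypothetical flags: dropout and on-policy handling are kept
        # independent, per the suggestion of two distinct modes.
        self.dropout_enabled = False
        self.on_policy = False

    @contextmanager
    def mode(self, *, dropout: bool, on_policy: bool):
        # Save the old flags, switch, and restore even on exceptions.
        old = (self.dropout_enabled, self.on_policy)
        self.dropout_enabled, self.on_policy = dropout, on_policy
        try:
            yield self
        finally:
            self.dropout_enabled, self.on_policy = old


net = RewardNet()
with net.mode(dropout=True, on_policy=True):
    assert net.dropout_enabled and net.on_policy
# Outside the block, both flags are back to their defaults.
assert not net.dropout_enabled and not net.on_policy
```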
I'm a bit fuzzy on how this would work. If I add shaping to a reward network, how should the algorithm know to pay attention to "reward_shaped" and not "reward", or whatever the original key was? Perhaps we could have a designated key "result" that can be overridden, and everything else is append-only? I'm still not sure how this would work if we applied the same wrapper twice: I guess we could add indexes to it, "reward_shaped_0", "reward_shaped_1"?
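The append-only convention with indexed keys could be implemented with a small helper along these lines (`append_indexed` is a hypothetical name invented for illustration, not part of imitation):

```python
def append_indexed(outputs: dict, base_key: str, value) -> str:
    """Store `value` under the first free indexed key, e.g.
    "reward_shaped_0", then "reward_shaped_1", and so on.
    Existing keys are never overwritten."""
    i = 0
    while f"{base_key}_{i}" in outputs:
        i += 1
    key = f"{base_key}_{i}"
    outputs[key] = value
    return key


outs = {"reward": 1.0}
append_indexed(outs, "reward_shaped", 2.0)  # first shaping wrapper
append_indexed(outs, "reward_shaped", 3.0)  # same wrapper applied twice
```

After both calls `outs` holds "reward", "reward_shaped_0" and "reward_shaped_1", so applying the same wrapper twice never clobbers an earlier result.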
Looking later, it seems the intention is to extract the relevant key using RewardSampler. I guess that could work, but many algorithms would need multiple reward samplers, so we'd have to pass the key into the algorithm?
> In addition, you would always know what tensor you were getting without having to check the type of the reward network.
I don't understand this -- isn't the tensor shape and dtype the same in all cases? Why should the callee care what the type of the network is?
I generally like this idea, although I do worry the indirection will make it less readable. I'm not sure this is actually better than having concrete implementations in the base class that do the mode futzing and call forward?
> I don't understand this -- isn't the tensor shape and dtype the same in all cases? Why should the callee care what the type of the network is?
Sorry this was not clear. Let me give an example to clarify.
Let's say I want to get the unshaped reward from a reward network. At the moment I have to check the net to see if it is an instance of a shaped reward net wrapper, and call forward on net.base or net depending on the result. With the new system I would just use the 'reward' key rather than the 'reward_shaped' key. The user would know what they were getting without having to check the type of the reward network and its wrappers.
As we add new wrappers, with the current implementation it will become increasingly difficult to get intermediate results.
> Let's say I want to get the unshaped reward from a reward network. At the moment I have to check the net to see if it is an instance of a shaped reward net wrapper, and call forward on net.base or net depending on the result. With the new system I would just use the 'reward' key rather than the 'reward_shaped' key. The user would know what they were getting without having to check the type of the reward network and its wrappers.
> As we add new wrappers, with the current implementation it will become increasingly difficult to get intermediate results.
Thanks for the clarification, I agree the current method of unwrapping to find a non-shaped base is hacky.
I think it's worth making concrete how this would actually work with the dict method. I believe "reward" corresponds to the base reward model before any wrappers in your convention. This isn't that hard to do with the current approach:

    while hasattr(reward, 'base'):
        reward = reward.base
The tricky thing is we often don't want the base-most reward model, but the first one that's not shaped. It might not be just "reward" and "reward_shaped", but "reward_affine_shaped_1". Do we want "reward", "reward_affine", "reward_affine_shaped_0", ...? If I understand correctly, the dictionary method effectively punts this problem to the user. That works well when it's a simple, statically configured network -- but not if the network architecture depends on the config, so there's no single consistent key that's the right one.
I think the dictionary approach is still, on balance, a bit easier to use if we often need to extract intermediate results, but right now we're only doing this in a single part of the codebase (AIRL.reward_test). Do you have other concrete use cases in mind for it? I think that's probably the crux for me of whether I want to move to a dictionary approach.
The only other concrete use case I can think of at the moment would be reward ensembles. During training we might want to detect whether we are using an ensemble and compute the loss for each member, rather than on the mean of the ensemble. At the moment in #460 I have to check the type of the reward net to see whether or not it is an ensemble. This is not very robust, since it will not work if the ensemble is wrapped by a shaping wrapper or a normalization wrapper. In the end, I will probably have to do the same thing as in AIRL.reward_test and strip off all the wrappers. However, with the dictionary approach I could just check whether something like the "ensemble_outputs" key was present and change the loss computation accordingly.
It would also remove the need for most of the code in serialize.py, since we would not need to treat different reward functions differently based on which wrappers they have. You could check whether the network has the correct functionality by dry-running some data through it and seeing whether it returns the right keys. Then you would simply return a reward evaluator.
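The dry-run check could be sketched as follows. Everything here is illustrative: `validate_reward_net` is a hypothetical helper, and the reward net is modeled as any callable mapping a (state, action, next_state, done) batch to a dict of outputs.

```python
def validate_reward_net(net, required_keys=("reward",)):
    """Run a tiny dummy batch through `net` and check its output dict
    contains the keys the evaluator needs. Sketch only: real inputs
    would be NumPy arrays shaped by the observation/action spaces."""
    dummy = [0.0, 0.0]  # stand-in for a batch of two transitions
    outputs = net(dummy, dummy, dummy, [False, False])
    missing = [k for k in required_keys if k not in outputs]
    if missing:
        raise ValueError(f"reward net missing output keys: {missing}")
    return outputs


# A toy net that only produces the base reward:
toy_net = lambda s, a, ns, d: {"reward": [0.0 for _ in s]}

validate_reward_net(toy_net)  # ok: "reward" is present
try:
    validate_reward_net(toy_net, required_keys=("reward_shaped",))
except ValueError:
    # No shaping key, so we would refuse to build a shaped evaluator.
    pass
```

This replaces per-wrapper branching in serialization with a single capability check against the produced keys.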
> The only other concrete use case I can think of at the moment would be reward ensembles. During training we might want to detect whether we are using an ensemble and compute the loss for each member, rather than on the mean of the ensemble. At the moment in #460 I have to check the type of the reward net to see whether or not it is an ensemble. This is not very robust, since it will not work if the ensemble is wrapped by a shaping wrapper or a normalization wrapper. In the end, I will probably have to do the same thing as in AIRL.reward_test and strip off all the wrappers. However, with the dictionary approach I could just check whether something like the "ensemble_outputs" key was present and change the loss computation accordingly.
The ensemble is a good use case. It's sad the training algorithm needs to care at all whether or not it's an ensemble model, but I guess that is more or less unavoidable :(
How should the ensemble + shaping work conceptually? I'd think we'd want to shape each ensemble member separately, and then compute the loss of member[i]+shaping[i] separately.
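Under the dictionary approach, the per-member loss computation being discussed could look roughly like this. The "ensemble_outputs" key and the `member_losses` helper are assumptions for illustration, not imitation's API:

```python
def member_losses(outputs, targets, loss_fn):
    """Compute one loss per ensemble member if the output dict exposes
    per-member rewards under a hypothetical "ensemble_outputs" key;
    otherwise fall back to a single loss on the mean "reward"."""
    if "ensemble_outputs" in outputs:
        return [loss_fn(member, targets) for member in outputs["ensemble_outputs"]]
    return [loss_fn(outputs["reward"], targets)]


def mse(pred, target):
    # Mean squared error over a batch of scalar rewards.
    return sum((p - t) ** 2 for p, t in zip(pred, target)) / len(pred)


outs = {"reward": [0.5], "ensemble_outputs": [[0.0], [1.0]]}
losses = member_losses(outs, [1.0], mse)  # one loss per ensemble member
```

The algorithm only branches on a key being present, so the check keeps working even if the ensemble is wrapped.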
> How should the ensemble + shaping work conceptually? I'd think we'd want to shape each ensemble member separately, and then compute the loss of member[i]+shaping[i] separately.
Good point, I don't think there is really a case where we will be wrapping the ensemble with a shaping function.
RewardNet Refactor
This issue proposes shrinking the RewardNet class to make it easier to extend and maintain. The RewardNet class acts as the base class for all PyTorch reward networks and reward network wrappers. It is currently responsible for outputting rewards with gradients for training, providing processed rewards without gradients to improve RL learning, and providing rewards at evaluation time via RewardFn.
Why is a refactor necessary?
The current design is a bit clunky, with forward, predict and predict_processed all having subtly different behavior. In particular, predict_processed taking different arguments in different subclasses breaks polymorphism and is evidence that the current design is not doing enough composition.
In addition, the current abstract class's __init__ is bloated with arguments that only apply to image-based environments. It should be smaller and apply equally well to tabular settings, image-based settings and continuous control. This is already causing problems in downstream repos: https://github.com/HumanCompatibleAI/reward-function-interpretability/blob/891ade369ae24d845a1f71933a585b885f78188a/src/reward_preprocessing/tabular_reward_net.py#L29-L31
Refactoring RewardNet would also be an opportunity to solve #427 by making the reward net interface smaller and more composable. We might be able to move some responsibilities of the reward network out to RewardFn.
This would not be a small refactor; RewardNets are core to imitation. However, I believe it is worth at least having a discussion of possible alternative designs before imitation goes to a 1.0 release.
Possible new design
Proposed modes: Eval, OnPolicy, Training.
Have a convention that reward wrappers should add to the dictionary rather than assign. For example, a shaping wrapper would add a key like "reward_shaped".

    def __call__(self, state, action, next_state, done) -> np.ndarray: ...
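Given that flat RewardFn-style signature, the key-selection idea discussed above could be sketched as a small adapter. `make_reward_fn` is a hypothetical name invented here; it just closes over a dict-returning net and picks one output key:

```python
def make_reward_fn(net, key="reward"):
    """Adapt a dict-returning reward net to the flat RewardFn signature,
    exposing only the chosen output key (illustrative sketch)."""

    def reward_fn(state, action, next_state, done):
        # The net returns the full output dict; the caller only
        # ever sees the one array they asked for.
        return net(state, action, next_state, done)[key]

    return reward_fn


# Toy dict-returning net with both a base and a shaped reward:
toy_net = lambda s, a, ns, d: {"reward": [1.0], "reward_shaped": [2.0]}

shaped_fn = make_reward_fn(toy_net, key="reward_shaped")
base_fn = make_reward_fn(toy_net)  # defaults to the unshaped "reward"
```

An algorithm that needs several views of the same network would build one such function per key, which is where the "pass the key into the algorithm" question from the discussion comes in.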