pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Feature Request] Use context managers to toggle the recurrent mode of RNN modules? #2562

Open thomasbbrunner opened 1 week ago

thomasbbrunner commented 1 week ago

Motivation

In TorchRL, users must manually set the mode of RNN modules using the set_recurrent_mode method, which toggles between processing steps individually and processing entire sequences of steps.

We believe that this approach has some issues:

  1. It requires keeping track of and maintaining two versions of the policy (yes, they share the same weights, but they are still two objects; see the sketch below).
  2. It is cumbersome when dealing with large policies with multiple sub-modules (you have to re-implement set_recurrent_mode for the policy).
  3. It seems to be easy to get wrong for people new to the code.
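
For reference, the current pattern looks roughly like this (a hedged sketch: the constructor arguments are illustrative and not copied from the TorchRL docs):

from tensordict.nn import TensorDictModule, TensorDictSequential
from torchrl.modules import LSTMModule, MLP

lstm = LSTMModule(input_size=4, hidden_size=8, in_key="observation", out_key="features")
head = TensorDictModule(MLP(in_features=8, out_features=2), in_keys=["features"], out_keys=["action"])

# Policy used to step the environment one transition at a time.
policy = TensorDictSequential(lstm, head)

# Policy used for training on whole sequences: set_recurrent_mode returns a
# second module object that shares the same weights.
train_policy = TensorDictSequential(lstm.set_recurrent_mode(True), head)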

Proposed Alternative

Could we leverage context managers for this, similar to how tensordict does it with set_interaction_type?

For instance:

input = TensorDict(...)
lstm = LSTMModule(...)
mlp = MLP(...)
policy = TensorDictSequential(lstm, mlp)

# By default, the lstm would not be in recurrent mode.
policy(input)

# Recurrent mode can be activated with a context manager.
with set_recurrent_mode(True):
    policy(input)

Have you considered this approach in the past?

Potential Implementation

The set_recurrent_mode context manager could be implemented in a similar fashion to set_interaction_type:

from typing import Any

# _DecoratorContextManager lets the same class be used as a context manager
# and as a decorator; in recent PyTorch versions it lives in
# torch.utils._contextlib (older versions expose it via torch.autograd.grad_mode).
from torch.utils._contextlib import _DecoratorContextManager

# Global flag consulted by RNN modules; False means step-by-step processing.
_RECURRENT_MODE: bool = False

class set_recurrent_mode(_DecoratorContextManager):
    def __init__(self, mode: bool = False) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> "set_recurrent_mode":
        # Needed for decorator usage: the base class re-instantiates the
        # context manager via clone() on every decorated call.
        return type(self)(self.mode)

    def __enter__(self) -> None:
        global _RECURRENT_MODE
        self.prev = _RECURRENT_MODE
        _RECURRENT_MODE = self.mode

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        global _RECURRENT_MODE
        _RECURRENT_MODE = self.prev
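
A minimal usage sketch, assuming the class above (describe and train_step are toy stand-ins, not TorchRL API), showing that the same object works as a context manager and as a decorator:

def describe() -> str:
    # Toy stand-in for a module that consults the global flag.
    return "recurrent" if _RECURRENT_MODE else "step-by-step"

print(describe())              # "step-by-step" (default)

with set_recurrent_mode(True):
    print(describe())          # "recurrent"

@set_recurrent_mode(True)
def train_step() -> str:
    return describe()

print(train_step())            # "recurrent"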

Potential Issues

  1. Unsure of the implications of this in distributed systems (is this thread- or process-safe? one thread-safe option is sketched below).
  2. Users could still forget to set this mode.
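
On point 1, one option (a sketch only, not a committed design) would be to store the flag in a contextvars.ContextVar instead of a plain global: each thread and asyncio task then sees its own value, so no lock is needed. Modules would read the flag with _RECURRENT_MODE.get(). This still does not cross process boundaries, so the distributed caveat stands.

import contextvars
from typing import Any

from torch.utils._contextlib import _DecoratorContextManager

_RECURRENT_MODE = contextvars.ContextVar("recurrent_mode", default=False)

class set_recurrent_mode(_DecoratorContextManager):
    def __init__(self, mode: bool = False) -> None:
        super().__init__()
        self.mode = mode

    def clone(self) -> "set_recurrent_mode":
        return type(self)(self.mode)

    def __enter__(self) -> None:
        # set() returns a token that restores the previous value on exit.
        self._token = _RECURRENT_MODE.set(self.mode)

    def __exit__(self, exc_type: Any, exc_value: Any, traceback: Any) -> None:
        _RECURRENT_MODE.reset(self._token)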


thomasbbrunner commented 1 week ago

If there's interest in building something like this, I'd be willing to invest time into it.

vmoens commented 1 week ago

Yep, we could perfectly do this; I don't think anyone would wish to run the RNN in both modes within the same function. I like CMs, but I know some people don't. The main issue is usually how to make distributed ops aware of the context, but I don't think that's a problem here.

Wanna work on this, or should I?

What's your opinion on these use cases:

policy = make_policy()

policy(input)  # what's the default?
policy.set_recurrent_mode(False)  # should this raise a deprecation warning?
with set_recurrent_mode(True):
    policy(input)  # does the decorator override the mode set in the line above?

with set_recurrent_mode(False):
    policy.set_recurrent_mode(True) # does this work?
    policy(input)

If the old API is to be deprecated, it's fine if the CM overrides the internal mode; but if we want both APIs to coexist, it can be tricky to decide which should prevail.

thomasbbrunner commented 1 week ago

I like CMs but I know some people don't.

The only thing I don't particularly like is the setting of a global variable. Would it make sense to have a lock for it?

The main issue is usually how to make distributed ops know what the context is, but I don't think it's a problem here.

Yes, this would definitely not work in that setting. It is only really applicable to local operations and, to be fair, the context in which this is active should ideally be short anyway.

The use-case I imagine is something like this:

for _ in range(num_steps):
    # Collect batch
    td = env.rollout(100, policy)

    # Train on batch
    with set_recurrent_mode(True):
        loss = loss_module(td)

    loss.backward()
    ...

To your questions:

policy(input) # what's the default?

I'd say we should keep the current default (recurrent mode off).

policy.set_recurrent_mode(False) # should raise a deprec warning?

I'd say we could support both approaches to setting the recurrent mode: the context manager is best suited to short-lived use-cases, while the method works better for distributed systems.


with set_recurrent_mode(True):
    policy(input)  # does the decorator override the mode set in the line above?

I'd say that the context takes precedence over the default recurrent mode.


with set_recurrent_mode(False):
    policy.set_recurrent_mode(True) # does this work?
    policy(input)

Uuuh, tricky. Based on the previous statement ("context takes precedence over the default recurrent mode") I'd say that this would run, but the context would still take precedence.

I'd suggest that the set_recurrent_mode method set the default recurrent mode, which can then be overridden by the context.

Maybe this should also be accompanied by a rename of the interface: set_recurrent_mode --> set_default_recurrent_mode.
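
A tiny sketch of the precedence rule being proposed here (all names are illustrative, not existing TorchRL API): the module falls back to its default only when no context is active.

_NO_CONTEXT = object()       # sentinel meaning "no set_recurrent_mode context is active"
_CONTEXT_MODE = _NO_CONTEXT  # would be set/reset by the context manager

def effective_recurrent_mode(module_default: bool) -> bool:
    # The context wins when present; otherwise the module-level default applies.
    return module_default if _CONTEXT_MODE is _NO_CONTEXT else _CONTEXT_MODE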

Wanna work on this, or should I?

I'd be interested, but it might take some time due to some other high-prio work. So feel free to take it over!

vmoens commented 1 week ago

The only thing I don't particularly like is the setting of a global variable. Would make sense to have a lock for it?

Sure, whatever works!

vmoens commented 1 week ago

@thomasbbrunner one thing I learned in the PyTorch team is that having 2 similar APIs to do one thing should be avoided when possible. I like the CM better, so I'm inclined to adopt it and deprecate the other. If you think it's a viable way, I can make a PR pretty quickly.

thomasbbrunner commented 5 days ago

@vmoens

one thing I learned in the PyTorch team is that having 2 similar APIs to do one thing should be avoided when possible.

Definitely agree with that! I also prefer the CM approach, so I'd be ok with deprecating the other.

At the same time, I feel like the two APIs serve slightly different purposes. I guess it's similar to the relationship between torch.no_grad() and the requires_grad argument of torch.Tensor and Parameter.
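
For concreteness, the analogy in plain PyTorch (a standard example, just to illustrate the parallel between a per-object default and a scoped override):

import torch

# requires_grad is a per-tensor default...
p = torch.nn.Parameter(torch.zeros(3))   # requires_grad=True by default
y = p * 2
print(y.requires_grad)                   # True

# ...while no_grad() temporarily overrides gradient tracking for a block.
with torch.no_grad():
    z = p * 2
print(z.requires_grad)                   # False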

Maybe for some use-cases (like distributed setups) it would be beneficial to set a default recurrent mode. I don't think the set_recurrent_mode method is ideal for this; I'd argue that a recurrent_mode argument in the constructor would work better. Either way, it's definitely something that could be implemented later if needed.

thomasbbrunner commented 5 days ago

Also, feel free to work on this, I won't have much time this week!