thomasbbrunner opened 1 week ago
If there's interest in building something like this, I'd be willing to invest time into it.
Yep, we could perfectly do this; I don't think anyone would wish to run the RNN in both modes in the same function. I like CMs but I know some people don't. The main issue is usually how to make distributed ops know what the context is, but I don't think it's a problem here.
Wanna work on this, or should I?
What's your opinion on these use cases:
```python
policy = make_policy()
policy(input)  # what's the default?

policy.set_recurrent_mode(False)  # should raise a deprec warning?

with set_recurrent_mode(True):
    policy(input)  # does the decorator override the mode set in the line above?

with set_recurrent_mode(False):
    policy.set_recurrent_mode(True)  # does this work?
    policy(input)
```
If the old API is to be deprecated, it's fine if the CM overrides the internal mode, but if we want both APIs to coexist, it can be tricky to decide which should prevail.
> I like CMs but I know some people don't.
The only thing I don't particularly like is the setting of a global variable. Would it make sense to have a lock for it?
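For concreteness, something like this is what I have in mind (purely illustrative, not existing TorchRL code; all names are placeholders). A lock makes the flag flip atomic, but the value is still shared globally:

```python
# Hypothetical sketch: the CM flips a module-level flag. The lock makes
# reads/writes of the flag atomic, but the value remains shared, so
# concurrent callers would still see each other's setting.
import threading
from contextlib import contextmanager

_RECURRENT_MODE = False
_RECURRENT_MODE_LOCK = threading.Lock()

@contextmanager
def set_recurrent_mode(mode: bool = True):
    global _RECURRENT_MODE
    with _RECURRENT_MODE_LOCK:
        previous, _RECURRENT_MODE = _RECURRENT_MODE, mode
    try:
        yield
    finally:
        with _RECURRENT_MODE_LOCK:
            _RECURRENT_MODE = previous
```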
> The main issue is usually how to make distributed ops know what the context is, but I don't think it's a problem here.
Yes, this would def not work in this setting. This is only really applicable for local operations and, tbf, ideally the context where this is active should be short.
The use-case I imagine is something like this:
```python
for _ in range(num_steps):
    # Collect batch
    td = env.rollout(100, policy)
    # Train on batch
    with set_recurrent_mode(True):
        loss = loss_module(td)
    loss.backward()
    ...
```
To your questions:
```python
policy(input)  # what's the default?
```
I'd say we should keep the current default (recurrent mode off).
```python
policy.set_recurrent_mode(False)  # should raise a deprec warning?
```
I'd say we could support both approaches of setting the recurrent mode. The context should be used in short-lived use-cases, while the method is better in the case of distributed systems.
```python
with set_recurrent_mode(True):
    policy(input)  # does the decorator override the mode set in the line above?
```
I'd say that the context takes precedence over the default recurrent mode.
```python
with set_recurrent_mode(False):
    policy.set_recurrent_mode(True)  # does this work?
    policy(input)
```
Uuuh, tricky. Based on the previous statement ("context takes precedence over the default recurrent mode") I'd say that this would run, but the context would still take precedence.
I'd suggest that the `set_recurrent_mode` method would set the default recurrent mode, which can then be overridden by the context. Maybe this should also be accompanied by a change to the interface: `set_recurrent_mode` --> `set_default_recurrent_mode`.
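A tiny sketch of the precedence rule I have in mind (pure illustration; `None` stands for "no context active"):

```python
# Illustration only: the context, when active, overrides the module's default.
def resolve_recurrent_mode(context_mode, default_mode):
    return default_mode if context_mode is None else context_mode

assert resolve_recurrent_mode(None, False) is False   # no context: default applies
assert resolve_recurrent_mode(True, False) is True    # context overrides the default
assert resolve_recurrent_mode(False, True) is False   # even to switch the mode off
```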
> Wanna work on this, or should I?
I'd be interested, but it might take some time due to some other high-prio work. So feel free to take it over!
> The only thing I don't particularly like is the setting of a global variable. Would it make sense to have a lock for it?
Sure, whatever works!
@thomasbbrunner one thing I learned in the PyTorch team is that having 2 similar APIs to do one thing should be avoided when possible. I like the CM better, so I'm inclined to adopt it and deprecate the other. If you think it's a viable way, I can make a PR pretty quickly.
@vmoens
> one thing I learned in the PyTorch team is that having 2 similar APIs to do one thing should be avoided when possible.
Definitely agree with that! I also prefer the CM approach, so I'd be ok with deprecating the other.
At the same time, I feel like the two APIs serve slightly different purposes. I guess it's similar to `with torch.no_grad():` and the `requires_grad` argument for `torch.Tensor` and `Parameter`.
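To make the analogy concrete, `torch.no_grad()` locally overrides the per-tensor `requires_grad` setting, which is roughly the precedence we discussed for the recurrent mode:

```python
import torch

x = torch.ones(3, requires_grad=True)  # per-tensor default

y = (x * 2).sum()
print(y.requires_grad)  # True: outside the CM, the tensor's setting applies

with torch.no_grad():   # the CM takes precedence locally
    z = (x * 2).sum()
print(z.requires_grad)  # False: the CM overrides requires_grad=True inside the block
```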
Maybe for some use-cases (like distributed setups) it would be beneficial to set a default recurrent mode. I don't think that the `set_recurrent_mode` method is ideal for this. I'd argue that a `recurrent_mode` argument in the constructor would work better. Either way, def something that could be implemented if needed at a later point.
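A rough sketch of what a constructor-level default could look like (a purely hypothetical module, not an existing TorchRL class; the context would still be able to override this default):

```python
import torch
from torch import nn

class TinyRecurrentPolicy(nn.Module):
    """Hypothetical module: the default mode is fixed at construction time."""

    def __init__(self, hidden_size: int = 32, recurrent_mode: bool = False):
        super().__init__()
        self.default_recurrent_mode = recurrent_mode
        self.lstm = nn.LSTM(hidden_size, hidden_size, batch_first=True)

    def forward(self, x: torch.Tensor, hidden=None):
        if self.default_recurrent_mode:
            # sequence mode: x is [batch, time, features]
            out, hidden = self.lstm(x, hidden)
        else:
            # step mode: x is [batch, features]; add and strip a time dimension
            out, hidden = self.lstm(x.unsqueeze(1), hidden)
            out = out.squeeze(1)
        return out, hidden
```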
Also, feel free to work on this, I won't have much time this week!
Motivation
In TorchRL, users must manually set the mode of RNN modules using the `set_recurrent_mode` method. This toggles between processing steps individually or processing entire sequences of steps. We believe that this approach has some issues:
- … `set_recurrent_mode` for the policy).
Proposed Alternative
Can we leverage context managers for this? Similar to how `tensordict` does with `set_interaction_type`. For instance:
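A usage sketch of the idea (the factory helpers and the `set_recurrent_mode` context manager below are placeholders for whatever the final API would be, not existing functions):

```python
env = make_env()                     # placeholder: any environment
policy = make_recurrent_policy(env)  # placeholder: an RNN-based policy

# Data collection: steps are processed one at a time (recurrent mode off)
rollout = env.rollout(100, policy)

# Loss computation: whole sequences are processed under the proposed CM
with set_recurrent_mode(True):
    out = policy(rollout)
```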
Have you considered this approach in the past?
Potential Implementation
The `set_recurrent_mode` context manager could be implemented in a similar fashion to `set_interaction_type`:
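A minimal sketch, assuming a contextvars-based flag; tensordict's actual `set_interaction_type` internals may differ, and all names here are placeholders:

```python
import contextvars
from contextlib import ContextDecorator
from typing import Optional

# Placeholder storage for the mode; None means "no context active".
_RECURRENT_MODE = contextvars.ContextVar("recurrent_mode", default=None)

def recurrent_mode() -> Optional[bool]:
    """Return the mode set by the innermost active context, or None."""
    return _RECURRENT_MODE.get()

class set_recurrent_mode(ContextDecorator):
    """Context manager / decorator toggling the proposed recurrent mode."""

    def __init__(self, mode: bool = True):
        self.mode = mode

    def __enter__(self):
        self._token = _RECURRENT_MODE.set(self.mode)
        return self

    def __exit__(self, exc_type, exc_value, traceback):
        _RECURRENT_MODE.reset(self._token)
```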
Potential Issues
Checklist