pytorch / rl

A modular, primitive-first, python-first PyTorch library for Reinforcement Learning.
https://pytorch.org/rl
MIT License

[Feature Request] `MultiDiscreteTensorSpec` in DQN #1077

Open matteobettini opened 1 year ago

matteobettini commented 1 year ago

Currently, DQN modules and losses want to know

 action_space (str, optional): The action space to be considered.
            Must be one of
            ``"one-hot"``, ``"mult_one_hot"``, ``"binary"`` or ``"categorical"``.

We need to add a multi-categorical option (`MultiDiscreteTensorSpec`) to this list and make it work in the modules and losses.

EDIT: following paragraph was solved

Plus, would it be possible to drop this argument and infer it from the action spec? It seems redundant that I need to feed my DQN loss both the spec and a string that repeats part of the spec.
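Supporting a multi-categorical action space in the value modules would mean, roughly, taking an independent argmax over each sub-action's slice of the Q-value vector. Here is a minimal sketch of that idea (a hypothetical helper, not TorchRL code), assuming per-sub-action sizes `nvec`:

```python
import torch

def multi_categorical_argmax(q_values, nvec):
    # Split the last dimension of the Q-value tensor into one chunk per
    # sub-action (chunk sizes given by nvec) and argmax each chunk
    # independently, yielding one discrete choice per sub-action.
    actions = []
    start = 0
    for n in nvec:
        chunk = q_values[..., start:start + n]
        actions.append(chunk.argmax(dim=-1))
        start += n
    return torch.stack(actions, dim=-1)

# Q-values for sub-action sizes [3, 4, 2]: 3 + 4 + 2 = 9 values total.
q = torch.tensor([0.1, 2.0, -1.0,  0.5, 0.2, 3.0, 0.0,  1.0, -2.0])
print(multi_categorical_argmax(q, [3, 4, 2]))  # tensor([1, 2, 0])
```

The same per-chunk split would apply when gathering the Q-value of a taken action inside the loss.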

vmoens commented 1 year ago

This choice is aimed at making the loss usable without requiring any other class from the library. If we want to infer it from the policy, which contains the spec, we're asking the user to:

That goes against the "adopt one component and not all" principle we aim to achieve. We could accept either a spec or a string in these classes, though.

matteobettini commented 1 year ago

Could this be optional, so that when it is not passed we attempt to infer it from the spec?

vmoens commented 1 year ago

The question is also how you fetch it when using a regular TensorDictSequential (not a SafeSequential), which has no concept of an action spec. I guess an informative error message like "you must either pass a spec, a string, or a policy with a spec" could help...

vmoens commented 1 year ago

Let's keep in mind that the use case we want to support is also DQNLoss(nn.Module). As long as this can be used without needing to learn anything about TensorDict, I'm OK with whichever option. Some time in the future I will do a tutorial, "DQN (or else) without TensorDict". I think it'd be fun!

matteobettini commented 1 year ago

Yeah, something like that. When you don't have the needed parameter, you try as hard as you can to find it, and if in the end you can't, you raise that error for the user.
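The lookup order discussed in this thread could be sketched as follows (a hypothetical helper, not TorchRL code): prefer an explicit string, then a spec passed to the loss, then a spec carried by the policy, and only then raise the informative error suggested above. The `action_space_str` attribute is an assumption for illustration.

```python
def resolve_action_space(action_space=None, spec=None, policy=None):
    # 1. An explicit string always wins (keeps the loss usable standalone).
    if action_space is not None:
        return action_space
    # 2. Fall back to a spec passed directly to the loss.
    if spec is not None and hasattr(spec, "action_space_str"):
        return spec.action_space_str
    # 3. Fall back to a spec attached to the policy (e.g. a SafeModule);
    #    a plain TensorDictSequential has no such attribute.
    policy_spec = getattr(policy, "spec", None)
    if policy_spec is not None and hasattr(policy_spec, "action_space_str"):
        return policy_spec.action_space_str
    # 4. Nothing found: raise the informative error discussed above.
    raise ValueError(
        "Could not determine the action space: pass an `action_space` "
        "string, a spec, or a policy that carries a spec."
    )
```

This keeps the "adopt one component" path (pass the string) intact while letting TorchRL users rely on their specs.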

matteobettini commented 1 year ago

So users of TorchRL can avoid passing those parameters, while external users just need to provide them.