thu-ml / tianshou

An elegant PyTorch deep reinforcement learning library.
https://tianshou.org
MIT License

Dict obs and custom network implementation question #696

Closed SoMuchSerenity closed 2 years ago

SoMuchSerenity commented 2 years ago

Hi,

I am working on an environment that returns a dictionary observation, consisting of an image and some scalar variables. I have looked through all the issues and the documentation but couldn't find a related question. PPO is the algorithm I intend to work with.
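For concreteness, such a space might be declared as follows (the exact shapes are illustrative assumptions, not my real ones):

import numpy as np
from gym import spaces

# A hypothetical Dict observation space: an image plus two scalar
# status values (shapes chosen for illustration only).
observation_space = spaces.Dict({
    "image": spaces.Box(low=0, high=255, shape=(3, 224, 224), dtype=np.uint8),
    "status": spaces.Box(low=-np.inf, high=np.inf, shape=(2,), dtype=np.float32),
})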

import torch as th
import torch.nn as nn
import torchvision.models as model
from gym.spaces import Dict

class Net(nn.Module):
    def __init__(self, observation_space: Dict):
        super().__init__()
        weights = model.ResNet18_Weights.DEFAULT
        # ResNet-18 backbone with the final classification layer removed
        self.network = nn.Sequential(*list(model.resnet18(weights=weights).children())[:-1])
        self.network1 = nn.Linear(512, 2000)
        self.status_process = nn.Linear(2, 128)
        # tianshou's Actor/Critic wrappers read this attribute
        self.output_dim = 2000 + 128

    def forward(self, observations, state=None, info={}):
        img = th.as_tensor(observations["image"], dtype=th.float32)
        x = self.network1(self.network(img).flatten(1))  # (B, 512, 1, 1) -> (B, 2000)
        status = th.as_tensor(observations["status"], dtype=th.float32)
        y = self.status_process(status)  # (B, 2) -> (B, 128)
        # concatenate along the feature dimension, not dim 0
        return th.cat((x, y), dim=1), state

Generally, my pre-processing network would look like the above, where a CNN handles the image input and an MLP handles the two scalar inputs. After this, the following will be used to create the actor and critic:

actor = Actor(net, env.action_space.n, device=device).to(device)
critic = Critic(net, device=device).to(device)
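From there, a minimal sketch of assembling the PPO policy might look like the following (the optimizer, learning rate, and discount are placeholder assumptions in the style of the repo's test scripts):

import torch
from tianshou.policy import PPOPolicy
from tianshou.utils.net.common import ActorCritic

# Rough sketch; hyperparameters are placeholders, not recommendations.
optim = torch.optim.Adam(ActorCritic(actor, critic).parameters(), lr=3e-4)
policy = PPOPolicy(
    actor, critic, optim,
    dist_fn=torch.distributions.Categorical,  # discrete action space
    discount_factor=0.99,
)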

I also have a question regarding the Actor definition. I have checked the source code of Actor(), where the output dimension is defined as:

self.output_dim = int(np.prod(action_shape))

I don't quite understand this definition, as I would expect the output dimension to be 2 * action_shape when using a Gaussian policy.
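For reference, np.prod simply flattens whatever action_shape is given into a single output dimension:

import numpy as np

# action_shape may be an int or a tuple; either way the Actor's final
# linear layer gets one flattened output dimension:
int(np.prod(4))       # 4 discrete actions -> 4 logits
int(np.prod((2, 3)))  # a (2, 3) action shape -> 6 outputs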

Thanks in advance for the help!

Trinkle23897 commented 2 years ago

> I don't quite understand this definition, as I would expect the output dimension to be 2 * action_shape when using a Gaussian policy.

actor = ActorProb(net, env.action_space.shape, device=device).to(device)

See the example in test/continuous/test_ppo.py.
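The mean head of ActorProb has prod(action_shape) units, and sigma is produced separately, so output_dim is not doubled. A minimal sketch of the distribution wiring from that test script (actor, critic, and optim assumed to be built as above):

from torch.distributions import Independent, Normal
from tianshou.policy import PPOPolicy

# ActorProb returns a (mu, sigma) pair; PPO consumes it via dist_fn.
def dist(*logits):
    return Independent(Normal(*logits), 1)

policy = PPOPolicy(actor, critic, optim, dist_fn=dist)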

SoMuchSerenity commented 2 years ago

Thanks, Weng. I closed the issue because I have realised how flexible Tianshou is; the other RL library I was using was very rigid. I will look into the documentation and source code further and will come back to you if I have more questions. Thanks very much for your response! I have also tried envpool, as recommended in the documentation, but it is not supported on Windows at the moment. Great job on these libraries!