rlworkgroup / garage

A toolkit for reproducible reinforcement learning research.
MIT License

Torch CategoricalCNNPolicy incorrectly reinitializes module in forward pass #2231

Open jamesborg46 opened 3 years ago

jamesborg46 commented 3 years ago

In torch/modules/categorical_cnn_module.py, the module is reinitialized on every forward pass of the network. I believe this should be moved to __init__. It might also be good to change the input_var parameter of categorical_cnn_policy, categorical_cnn_module, cnn_module, etc. to something like in_channels (or to rely on the observation space taken from env_spec, as other examples do), so that we no longer need to pass an observation when initializing the module, only the channel dimensions.

    def forward(self, observations):
        """Compute the action distributions from the observations.

        Args:
            observations (torch.Tensor): Batch of observations on default
                torch device.

        Returns:
            torch.distributions.Distribution: Batch distribution of actions.
            dict[str, torch.Tensor]: Additional agent_info, as torch Tensors.
                Do not need to be detached, and can be on any device.

        """
        module = CategoricalCNNModule(
            input_var=observations,
            output_dim=self._action_dim,
            kernel_sizes=self._kernel_sizes,
            strides=self._strides,
            hidden_channels=self._hidden_conv_channels,
            hidden_sizes=self._hidden_sizes,
            hidden_nonlinearity=self._hidden_nonlinearity,
            hidden_w_init=self._hidden_w_init,
            hidden_b_init=self._hidden_b_init,
            paddings=self._paddings,
            padding_mode=self._padding_mode,
            max_pool=self._max_pool,
            pool_shape=self._pool_shape,
            pool_stride=self._pool_stride,
            output_nonlinearity=self._output_nonlinearity,
            output_w_init=self._output_w_init,
            output_b_init=self._output_b_init,
            layer_normalization=self._layer_normalization,
            is_image=self._is_image)

        dist = module(observations)
        return dist, {}
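
To illustrate why this matters and what moving the construction into __init__ could look like, here is a minimal sketch in plain PyTorch. It is not garage's actual code: `_build_net`, `ReinitEveryCall`, `BuiltOnce`, and the `in_channels` / `n_actions` parameters are hypothetical names chosen for the example, and the tiny network only stands in for CategoricalCNNModule. The point it shows is that a submodule created inside forward is never registered on the policy, so its parameters are invisible to `parameters()` (and therefore to any optimizer built from it) and are freshly initialized on each call.

```python
import torch
from torch import nn


def _build_net(in_channels, n_actions):
    """Hypothetical stand-in for CategoricalCNNModule (not garage's API)."""
    return nn.Sequential(
        nn.Conv2d(in_channels, 4, kernel_size=3),
        nn.ReLU(),
        nn.AdaptiveAvgPool2d(1),  # (N, 4, H, W) -> (N, 4, 1, 1)
        nn.Flatten(),             # (N, 4, 1, 1) -> (N, 4)
        nn.Linear(4, n_actions),
    )


class ReinitEveryCall(nn.Module):
    """Mimics the reported pattern: a new submodule on every forward pass."""

    def __init__(self, in_channels, n_actions):
        super().__init__()
        self._in_channels = in_channels
        self._n_actions = n_actions

    def forward(self, observations):
        # A new network with new random weights is created on each call,
        # and it is never assigned to self, so it is not registered.
        net = _build_net(self._in_channels, self._n_actions)
        logits = net(observations)
        return torch.distributions.Categorical(logits=logits), {}


class BuiltOnce(nn.Module):
    """Sketch of the suggested fix: build the submodule once in __init__."""

    def __init__(self, in_channels, n_actions):
        super().__init__()
        # Assigning an nn.Module attribute registers its parameters.
        self._module = _build_net(in_channels, n_actions)

    def forward(self, observations):
        logits = self._module(observations)
        return torch.distributions.Categorical(logits=logits), {}


obs = torch.ones(2, 3, 8, 8)  # batch of 2 observations, 3 channels, 8x8
broken = ReinitEveryCall(in_channels=3, n_actions=5)
fixed = BuiltOnce(in_channels=3, n_actions=5)

# Number of parameter tensors visible to an optimizer.
n_broken_params = len(list(broken.parameters()))
n_fixed_params = len(list(fixed.parameters()))

# With the module built once, repeated calls reuse the same weights.
dist_a, _ = fixed(obs)
dist_b, _ = fixed(obs)
same = torch.equal(dist_a.logits, dist_b.logits)
```

In this sketch only the channel count is needed at construction time, which matches the in_channels suggestion above; how that value would be derived from env_spec in garage is left open here.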

haydenshively commented 3 years ago

Hi @Indoril007!

Thanks for pointing this out! We're looking into it on our end to see whether a simple fix is possible. Our team has a few deadlines coming up, so we can't promise a fix this week. However, there will be a huge push over the next few months to improve garage's support for image-based observations, CNN policies, and visual algorithms.

In the meantime, we welcome contributions from the community, so if you want to submit a PR, please go for it!

jamesborg46 commented 3 years ago

Actually, it seems #2189 might be addressing this issue anyway.