PKU-Alignment / omnisafe

[JMLR] OmniSafe is an infrastructural framework for accelerating SafeRL research.
https://www.omnisafe.ai
Apache License 2.0

[Feature Request] Will you be providing an off-policy version of the CRPO method at a later stage? #267

Closed guanjiayi closed 1 year ago

guanjiayi commented 1 year ago

Required prerequisites

Motivation

Solution

# Copyright 2023 OmniSafe Team. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of the Lagrangian version of Soft Actor-Critic algorithm."""

import torch

from omnisafe.algorithms import registry
from omnisafe.algorithms.off_policy.sac import SAC

@registry.register
# pylint: disable-next=too-many-instance-attributes, too-few-public-methods
class SACCRPO(SAC):
    """The Off-policy CRPO algorithm.

    References:
        - Title: CRPO: A New Approach for Safe Reinforcement Learning with Convergence Guarantee.
        - Authors: Tengyu Xu, Yingbin Liang, Guanghui Lan.
        - URL: `CRPO <https://arxiv.org/pdf/2011.05869.pdf>`_.
    """

    def _init(self) -> None:
        """Initialize an instance of :class:`SACCRPO`."""

        super()._init()
        self._rew_update: int = 0
        self._cost_update: int = 0

    def _init_log(self) -> None:
        """Log the SACCRPO specific information.

        +-----------------+--------------------------------------------+
        | Things to log   | Description                                |
        +=================+============================================+
        | Misc/RewUpdate  | The number of times the reward is updated. |
        +-----------------+--------------------------------------------+
        | Misc/CostUpdate | The number of times the cost is updated.   |
        +-----------------+--------------------------------------------+
        """
        super()._init_log()
        self._logger.register_key('Misc/RewUpdate')
        self._logger.register_key('Misc/CostUpdate')
        self._logger.register_key('Loss/loss_r_mean')
        self._logger.register_key('Loss/loss_c_max')
        self._logger.register_key('Loss/loss_c_mean')

    def _loss_pi(
        self,
        obs: torch.Tensor,
    ) -> torch.Tensor:
        r"""Compute `pi/actor` loss."""
        action = self._actor_critic.actor.predict(obs, deterministic=False)
        log_prob = self._actor_critic.actor.log_prob(action)
        q1_value_r, q2_value_r = self._actor_critic.reward_critic(obs, action)
        # Reward objective: the standard entropy-regularized SAC actor loss.
        loss_r = self._alpha * log_prob - torch.min(q1_value_r, q2_value_r)
        # Cost objective: the cost critic's Q-value for the sampled action.
        loss_c = self._actor_critic.cost_critic(obs, action)[0]

        # CRPO-style rectification: optimize the cost objective only when the
        # estimated cost exceeds the limit plus tolerance (and dominates the
        # reward objective); otherwise optimize the reward objective.
        if (
            loss_c.max().item() > self._cfgs.algo_cfgs.cost_limit + self._cfgs.algo_cfgs.tolerance
            and loss_c.mean().item() > loss_r.mean().item()
        ):
            self._cost_update += 1
            loss = loss_c
        else:
            self._rew_update += 1
            loss = loss_r

        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': loss_r.mean().item(),
                'Loss/loss_c_mean': loss_c.mean().item(),
                'Loss/loss_c_max': loss_c.max().item(),
            }
        )
        return loss.mean()

    def _log_when_not_update(self) -> None:
        super()._log_when_not_update()
        self._logger.store(
            {
                'Misc/RewUpdate': self._rew_update,
                'Misc/CostUpdate': self._cost_update,
                'Loss/loss_r_mean': 0,
                'Loss/loss_c_mean': 0,
                'Loss/loss_c_max': 0,
            }
        )
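
For reference, below is a minimal usage sketch of how this variant could be trained through OmniSafe's standard agent interface. It assumes the class above is registered under the name 'SACCRPO' (which would also require a matching default config file, omitted here) and that cost_limit and tolerance entries are added to its algo_cfgs defaults; those two keys are assumptions of this sketch, not existing SAC config entries.

# Minimal training sketch (assumptions noted above).
import omnisafe

env_id = 'SafetyPointGoal1-v0'
custom_cfgs = {
    'train_cfgs': {'total_steps': 1_000_000},
    'algo_cfgs': {
        'cost_limit': 25.0,  # assumed key consumed by SACCRPO._loss_pi
        'tolerance': 5.0,    # assumed key consumed by SACCRPO._loss_pi
    },
    'logger_cfgs': {'use_wandb': False},
}

agent = omnisafe.Agent('SACCRPO', env_id, custom_cfgs=custom_cfgs)
agent.learn()

The Misc/RewUpdate and Misc/CostUpdate counters can then be inspected in the logger output to see how often the cost branch is taken during training.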

Alternatives

No response

Additional context

No response

Gaiejj commented 1 year ago

Thank you for proactively implementing this new algorithm and for contributing to the advancement of safe reinforcement learning. Your implementation largely meets our criteria for an off-policy version of the CRPO algorithm; however, a few areas warrant refinement, as follows:

These suggestions should substantially improve the quality of your CRPO implementation. If any questions or uncertainties arise, please feel free to discuss them with us.

guanjiayi commented 1 year ago

Thank you for your reply, and for the valuable suggestions.