How can I use the is_terminated_term class together with the is_terminated function? I have a single environment with two agents, and I need to specify termination terms for each of them, not just for the environment as a whole. Would this be possible using is_terminated_term?
def is_terminated(env: ManagerBasedRLEnv) -> torch.Tensor:
    """Penalize terminated episodes that don't correspond to episodic timeouts."""
    return env.termination_manager.terminated.float()

class is_terminated_term(ManagerTermBase):
    """Penalize termination for specific terms that don't correspond to episodic timeouts.

    The parameters are as follows:

    * :attr:`term_keys`: The termination terms to penalize. This can be a string, a list of strings
      or regular expressions. Default is ".*" which penalizes all terminations.

    The reward is computed as the sum of the termination terms that are not episodic timeouts.
    This means that the reward is 0 if the episode is terminated due to an episodic timeout. Otherwise,
    if two termination terms are active, the reward is 2.
    """

    def __init__(self, cfg: RewardTermCfg, env: ManagerBasedRLEnv):
        # initialize the base class
        super().__init__(cfg, env)
        # find and store the termination terms
        term_keys = cfg.params.get("term_keys", ".*")
        self._term_names = env.termination_manager.find_terms(term_keys)

    def __call__(self, env: ManagerBasedRLEnv, term_keys: str | list[str] = ".*") -> torch.Tensor:
        # Return the unweighted reward for the termination terms
        reset_buf = torch.zeros(env.num_envs, device=env.device)
        for term in self._term_names:
            # Sums over terminations term values to account for multiple terminations in the same step
            reset_buf += env.termination_manager.get_term(term)

        return (reset_buf * (~env.termination_manager.time_outs)).float()
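
To make the per-agent idea concrete, here is a minimal, library-free sketch of what the class above appears to compute when it is given a different term_keys pattern for each agent. It is not Isaac Lab code: the helper names find_terms and terminated_term_reward, the termination term names (agent_a_fell and so on), and the use of plain lists instead of tensors are all assumptions made for illustration. It also assumes that term names are matched against the whole pattern, which may differ from how the real termination manager resolves them.

```python
import re
from typing import Dict, List, Sequence, Union


def find_terms(all_terms: Sequence[str], term_keys: Union[str, List[str]]) -> List[str]:
    """Hypothetical stand-in: return term names that fully match any given regex."""
    keys = [term_keys] if isinstance(term_keys, str) else list(term_keys)
    return [name for name in all_terms if any(re.fullmatch(key, name) for key in keys)]


def terminated_term_reward(
    term_values: Dict[str, List[bool]],
    time_outs: List[bool],
    term_names: List[str],
) -> List[float]:
    """Sum the selected termination flags per environment, zeroed on timeouts."""
    reward = [0.0] * len(time_outs)
    for name in term_names:
        # Summing lets several active terms in the same step count separately.
        for i, active in enumerate(term_values[name]):
            reward[i] += float(active)
    # Mirror the `~time_outs` mask: a timed-out environment yields no penalty.
    return [0.0 if timed_out else r for r, timed_out in zip(reward, time_outs)]


# Hypothetical termination terms for three parallel environments.
term_values = {
    "agent_a_fell": [True, False, False],
    "agent_a_out_of_bounds": [True, False, False],
    "agent_b_fell": [False, True, True],
    "time_out": [False, False, True],
}
time_outs = term_values["time_out"]

# One selection per agent, analogous to one reward term per agent.
agent_a_terms = find_terms(list(term_values), "agent_a_.*")
agent_b_terms = find_terms(list(term_values), "agent_b_.*")

reward_a = terminated_term_reward(term_values, time_outs, agent_a_terms)  # [2.0, 0.0, 0.0]
reward_b = terminated_term_reward(term_values, time_outs, agent_b_terms)  # [0.0, 1.0, 0.0]
```

Read this way, is_terminated and is_terminated_term look like alternatives rather than pieces to combine: the function always uses the aggregate terminated flag, while the class restricts the sum to the terms selected by term_keys. One reward term per agent, each with its own term_keys, would then give each agent its own penalty. Note that the timeout mask in the quoted code is still per environment, so this separates the penalties but not the reset itself.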