zhengzl18 closed this issue 4 weeks ago.
That's correct. The very large values are initialization values that come from padding agents, which are not valid and therefore cannot be controlled.
To filter for valid agents, the easiest and recommended approach is to use `cont_agent_mask`, which is already initialized if you're using the gym env, or can be created with `get_controlled_agents_mask()` in the same file.
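```python
def get_controlled_agents_mask(self):
    """Get the control mask."""
    # controlled_state_tensor() carries a trailing singleton axis; squeezing it
    # yields a boolean mask of shape (num_worlds, kMaxAgentCount).
    return (self.sim.controlled_state_tensor().to_torch() == 1).squeeze(axis=2)
```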
`cont_agent_mask` is a boolean tensor with shape `(num_worlds, kMaxAgentCount)`, where `kMaxAgentCount` is defined in `consts.hpp`. For example:
```python
>>> env.cont_agent_mask.shape
torch.Size([10, 128])
```
Valid agents, which can be controlled, are indicated by `True`. After applying this filter, you should no longer see the large values:

```python
env.get_obs()[env.cont_agent_mask]
```

Let me know if you still notice any oddities after this!
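The effect of the mask can be illustrated with synthetic tensors. This is a minimal sketch that assumes only the shapes described above. The toy sizes and the `1e20` fill value are made up for illustration and are not real GPUDrive data.

```python
import torch

# Toy sizes for illustration only (the real mask is e.g. [10, 128]).
num_worlds, max_agents, obs_dim = 2, 4, 3

# Stand-in for env.get_obs(): one observation vector per agent slot.
obs = torch.randn(num_worlds, max_agents, obs_dim)

# Stand-in for env.cont_agent_mask: True marks a valid, controllable agent.
cont_agent_mask = torch.tensor([[True, True, False, False],
                                [True, False, False, False]])

# Mimic the symptom: padding slots hold huge placeholder values.
obs[~cont_agent_mask] = 1e20

# Boolean indexing with a (num_worlds, max_agents) mask collapses the first
# two axes into a single axis containing only the valid agents.
valid_obs = obs[cont_agent_mask]
print(valid_obs.shape)  # torch.Size([3, 3])
```

Because the mask has two axes, the result loses the per-world structure. Keep the mask around if you need to scatter results back into `(num_worlds, kMaxAgentCount, ...)` form.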
@daphnecor Thanks for the reply. Actually, I did slice the obs with `env.cont_agent_mask`, but the problem is that the partner obs part contains every agent's observations of every other agent, with shape `(num_worlds, kMaxAgentCount, kMaxAgentCount-1, 8)`. After slicing the first two axes with `env.cont_agent_mask`, I still get the observations of the invalid agents along axis 2.
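In other words, a 2-D mask removes only invalid *observers*. The partner axis still has `kMaxAgentCount - 1` slots, and some of them describe padding agents. The sketch below also masks the partner axis, but it rests on a hypothetical ordering that the thread does not confirm. It assumes agent `i`'s partner slots list all other agents in index order, with `i` itself skipped. The toy sizes and the `1e20` fill are illustrative only.

```python
import torch

# Toy sizes for illustration only; 8 partner features as in the thread.
num_worlds, max_agents, feat = 2, 4, 8
cont_agent_mask = torch.tensor([[True, True, False, False],
                                [True, True, True, False]])

# partner_obs[w, i, j]: agent i's view of its j-th partner (self excluded).
partner_obs = torch.randn(num_worlds, max_agents, max_agents - 1, feat)

# ASSUMPTION (hypothetical ordering): partner slot j of agent i refers to
# agent j if j < i, else to agent j + 1, i.e. agent i's own index is skipped.
i = torch.arange(max_agents).unsqueeze(1)       # (A, 1)
j = torch.arange(max_agents - 1).unsqueeze(0)   # (1, A-1)
partner_idx = j + (j >= i).long()               # (A, A-1) partner agent ids

# Look up each partner's validity in its own world: (W, A, A-1).
partner_valid = cont_agent_mask[:, partner_idx]

# Mimic the symptom: partner slots that point at padding agents are huge.
partner_obs[~partner_valid] = 1e20

# Zero out invalid partners, then keep only valid observers.
cleaned = partner_obs.masked_fill(~partner_valid.unsqueeze(-1), 0.0)
valid_partner_obs = cleaned[cont_agent_mask]    # (num_valid, A-1, feat)
print(valid_partner_obs.shape)  # torch.Size([5, 3, 8])
```

If the simulator orders partner slots differently, only `partner_idx` needs to change.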
Ah, I see. One simple temporary solution is to set all elements in the partner obs larger than a threshold to zero. In the meantime, we will look into this.
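The threshold workaround might look like the sketch below. The helper name and the default threshold of `1e10` are hypothetical. The threshold only needs to sit well above ordinary feature magnitudes and well below the `> 1e20` padding values reported in this issue.

```python
import torch

def zero_large_partner_obs(partner_obs: torch.Tensor,
                           threshold: float = 1e10) -> torch.Tensor:
    """Temporary workaround: zero every element whose magnitude exceeds threshold.

    The default threshold is an arbitrary illustrative value, assumed to lie
    far above real partner features and far below the padding values.
    """
    too_large = partner_obs.abs() > threshold
    return partner_obs.masked_fill(too_large, 0.0)


partner_obs = torch.tensor([[0.5, -2.0, 3e20, 1.0],
                            [-4e20, 7.0, 0.0, 2e21]])
print(zero_large_partner_obs(partner_obs))
```

A padding partner's 8-feature vector might mix huge and ordinary entries. In that case, a safer variant may be to zero the whole vector by using `too_large.any(dim=-1, keepdim=True)` as the mask.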
Looking forward to that!
Hi @zhengzl18, I am now zeroing out all the irrelevant partner observations on the C++ side. This should fix your issue.
@aaravpandya That's great!
The `partner_observations` part of `obs` returned by `get_obs()` of `GPUDriveTorchEnv` tends to contain some values larger than `1e20`. I'm guessing this is because the vehicles are not filtered according to their `valid` flags, am I right? If so, would it be more reasonable to add a parameter called `max_observed_agents` and build `partner_observations` by first filtering out invalid agents, then choosing the nearest `max_observed_agents` agents (padding to `max_observed_agents` if necessary)?
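The proposal can be sketched in PyTorch. This is a hypothetical helper, not part of GPUDrive's API. The function name, the `(W, A, 2)` position tensor and the per-agent feature tensor are all assumptions; only `max_observed_agents` comes from the proposal. The helper excludes self and invalid agents, keeps the nearest `max_observed_agents` by Euclidean distance, and zero-pads the rest.

```python
import torch

def nearest_partner_obs(positions: torch.Tensor,
                        features: torch.Tensor,
                        valid_mask: torch.Tensor,
                        max_observed_agents: int) -> torch.Tensor:
    """Hypothetical sketch: each agent observes its nearest valid partners.

    positions:  (W, A, 2) xy position of every agent slot
    features:   (W, A, F) features describing each agent as a partner
    valid_mask: (W, A) bool, True for real (non-padding) agents
    returns:    (W, A, K, F) with K = max_observed_agents, zero-padded
    """
    W, A, F = features.shape
    K = max_observed_agents
    # Neutralize padding positions so huge values cannot overflow the distances.
    positions = positions.masked_fill(~valid_mask.unsqueeze(-1), 0.0)
    dist = torch.cdist(positions, positions)                  # (W, A, A)
    # Exclude self, invalid partners, and rows of invalid observers.
    eye = torch.eye(A, device=positions.device).bool()
    excluded = (eye.unsqueeze(0)
                | ~valid_mask.unsqueeze(1)
                | ~valid_mask.unsqueeze(2))
    dist = dist.masked_fill(excluded, float("inf"))
    # The k nearest candidates in ascending distance (at most A - 1 exist).
    k = min(K, A - 1)
    near_dist, near_idx = torch.topk(dist, k, dim=-1, largest=False)
    # Gather features[w, near_idx[w, i, n]] -> (W, A, k, F).
    world_idx = torch.arange(W, device=features.device).view(W, 1, 1)
    gathered = features[world_idx, near_idx]
    # Slots that landed on an excluded candidate become zero padding.
    gathered = gathered.masked_fill(torch.isinf(near_dist).unsqueeze(-1), 0.0)
    # Pad up to K slots when fewer than K candidates exist.
    if k < K:
        pad = gathered.new_zeros(W, A, K - k, F)
        gathered = torch.cat([gathered, pad], dim=2)
    return gathered


pos = torch.tensor([[[0.0, 0.0], [1.0, 0.0], [5.0, 0.0], [1e20, 1e20]]])
valid = torch.tensor([[True, True, True, False]])
feat = torch.tensor([[[10.0], [20.0], [30.0], [1e20]]])
out = nearest_partner_obs(pos, feat, valid, max_observed_agents=3)
print(out[0, 0, :, 0])  # agent 0 sees agent 1, then agent 2, then padding
```

A fixed `max_observed_agents` keeps the partner axis a static size that no longer depends on `kMaxAgentCount`. Sorting by distance puts the most relevant partners first.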