Emerge-Lab / gpudrive

GPU-acceleration of Nocturne via Madrona
https://arxiv.org/abs/2408.01584
MIT License

`GPUDriveTorchEnv` may return obs with extremely large values #278

Closed zhengzl18 closed 4 weeks ago

zhengzl18 commented 1 month ago

The partner_observations part of obs returned by get_obs() of GPUDriveTorchEnv tends to contain some values larger than 1e20. I'm guessing it's because the vehicles are not filtered according to their valid flags, am I right?

If so, would it be more reasonable to add a parameter called max_observed_agents, and build partner_observations by first filtering out invalid agents and then choosing the nearest max_observed_agents agents (padding up to max_observed_agents if necessary)?
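To make the proposal concrete, here is a minimal PyTorch sketch of the filter, select-nearest, and pad steps. It is not GPUDrive code: the function name, the `partner_valid` mask, and the `dist` tensor are hypothetical inputs, since the thread does not say which of the partner features hold validity or position.

```python
import torch
import torch.nn.functional as F


def select_nearest_partners(partner_obs, partner_valid, dist, max_observed_agents):
    """Keep the nearest valid partners per agent, zero-padded to a fixed count.

    partner_obs:   (num_worlds, num_agents, num_partners, num_features)
    partner_valid: (num_worlds, num_agents, num_partners) bool   [hypothetical input]
    dist:          (num_worlds, num_agents, num_partners) float  [hypothetical input]
    """
    num_partners, num_features = partner_obs.shape[2], partner_obs.shape[3]
    k = min(max_observed_agents, num_partners)

    # Push invalid partners to the end of the distance ranking.
    masked_dist = dist.masked_fill(~partner_valid, float("inf"))
    nearest_dist, idx = torch.topk(masked_dist, k, dim=2, largest=False)

    # Gather the feature rows of the k nearest partners.
    gather_idx = idx.unsqueeze(-1).expand(-1, -1, -1, num_features)
    selected = torch.gather(partner_obs, 2, gather_idx)

    # Slots that could only be filled by invalid partners become zeros.
    keep = torch.isfinite(nearest_dist).unsqueeze(-1)
    selected = torch.where(keep, selected, torch.zeros_like(selected))

    # Zero-pad the partner axis if fewer than max_observed_agents slots exist.
    return F.pad(selected, (0, 0, 0, max_observed_agents - k))


# Toy data: 1 world, 2 agents, 3 partner slots, 2 features; 1e20 marks padding.
partner_obs = torch.tensor([[[[1.0, 1.0], [1e20, 1e20], [3.0, 3.0]],
                             [[4.0, 4.0], [5.0, 5.0], [1e20, 1e20]]]])
partner_valid = torch.tensor([[[True, False, True], [True, True, False]]])
dist = torch.tensor([[[2.0, 0.1, 1.0], [1.0, 3.0, 0.0]]])

out = select_nearest_partners(partner_obs, partner_valid, dist, max_observed_agents=4)
print(out.shape)  # torch.Size([1, 2, 4, 2])
```

In this sketch the output always has max_observed_agents partner slots, sorted from nearest to farthest, with unused slots set to zero.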

daphnecor commented 1 month ago

That's correct. The very large values are due to initialization and come from padding agents, which are not valid and thus cannot be controlled.

To filter for valid agents, the easiest and recommended approach is to use cont_agent_mask, which is already initialized if you're using the gym env, or can be created with get_controlled_agents_mask() in the same file.

    def get_controlled_agents_mask(self):
        """Get the control mask."""
        return (self.sim.controlled_state_tensor().to_torch() == 1).squeeze(
            axis=2
        )

cont_agent_mask is a boolean tensor with the shape (num_worlds, kMaxAgentCount) where kMaxAgentCount is defined in consts.hpp. For example:

    env.cont_agent_mask.shape
    # torch.Size([10, 128])
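The mask construction can be reproduced on a toy tensor. This is only a sketch: `controlled_state` stands in for `self.sim.controlled_state_tensor().to_torch()`, and its shape (num_worlds, kMaxAgentCount, 1) is an assumption inferred from the `squeeze` call and the resulting mask shape.

```python
import torch

# Hypothetical stand-in for the simulator tensor: 2 worlds, 3 agent slots.
# 1 marks a controlled (valid) agent, 0 marks a padding slot.
controlled_state = torch.tensor([[[1], [1], [0]],
                                 [[1], [0], [0]]])

# Same expression as in get_controlled_agents_mask(), using dim instead of axis.
mask = (controlled_state == 1).squeeze(dim=2)

print(mask.shape)       # torch.Size([2, 3])
print(mask.dtype)       # torch.bool
print(mask.sum(dim=1))  # tensor([2, 1]) -> valid agents per world
```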

Valid agents, i.e. those that can be controlled, are marked True. Applying the mask filters out the padding agents:

    env.get_obs()[env.cont_agent_mask]

After that, you should no longer see the large values. Let me know if you still notice any oddities!

zhengzl18 commented 1 month ago

@daphnecor Thanks for the reply. Actually, I did slice the obs with env.cont_agent_mask, but the problem is that the partner obs part contains every agent's observations of every other agent, with the shape (num_worlds, kMaxAgentCount, kMaxAgentCount-1, 8). After slicing on the first two axes with env.cont_agent_mask, I still get the observations of those invalid agents on axis=2.
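The effect can be reproduced on toy tensors. The sketch below does not call GPUDrive. It assumes that padding entries hold a large value (1e20 here) and that each agent's partner slots list the other agents in index order. Both are illustrative assumptions.

```python
import torch

num_worlds, max_agents, num_features = 2, 4, 8
PAD = 1e20  # stand-in for the large values seen in padding entries

# World 0 has 2 valid agents, world 1 has 3.
cont_agent_mask = torch.tensor([[True, True, False, False],
                                [True, True, True, False]])

# Start with padding everywhere, then fill in valid observer/partner pairs.
partner_obs = torch.full(
    (num_worlds, max_agents, max_agents - 1, num_features), PAD
)
for w in range(num_worlds):
    for i in range(max_agents):
        others = [j for j in range(max_agents) if j != i]
        for slot, j in enumerate(others):
            if cont_agent_mask[w, i] and cont_agent_mask[w, j]:
                partner_obs[w, i, slot] = 1.0

# Boolean indexing collapses the first two axes to the valid observers only.
sliced = partner_obs[cont_agent_mask]
print(sliced.shape)               # torch.Size([5, 3, 8])
print(sliced.max().item() > 1e19)  # True: invalid partners remain on the partner axis
```

The mask removes invalid observers, but each remaining observer still carries a full row of kMaxAgentCount-1 partner slots, including the padding ones.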

daphnecor commented 1 month ago

Ah, I see. One simple temporary solution is to set all elements in the partner obs larger than a threshold to zero. In the meantime, we will look into this.
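A minimal sketch of that workaround in PyTorch follows. The function name and the default threshold of 1e6 are hypothetical. Pick a threshold comfortably above any legitimate observation value in your setup.

```python
import torch


def zero_out_large_values(partner_obs, threshold=1e6):
    """Return a copy of partner_obs with implausibly large entries set to zero.

    threshold is a hypothetical default, not a value from GPUDrive.
    Non-finite entries (inf, nan) are zeroed as well.
    """
    bad = (partner_obs.abs() > threshold) | ~torch.isfinite(partner_obs)
    return torch.where(bad, torch.zeros_like(partner_obs), partner_obs)


obs = torch.tensor([[0.5, 1e20],
                    [-3e25, 2.0]])
cleaned = zero_out_large_values(obs)
print(cleaned)  # tensor([[0.5000, 0.0000], [0.0000, 2.0000]])
```

This filter works element by element, so small values in the same padding row are left untouched.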

zhengzl18 commented 1 month ago

Looking forward to that!

aaravpandya commented 4 weeks ago

Hi @zhengzl18, I am now zeroing out all the irrelevant partner observations on the C++ side. This should fix your issue.

zhengzl18 commented 4 weeks ago

@aaravpandya That's great!