NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.
Apache License 2.0

Issues with SceneBatch.to_agent #37

Open bmacadam-sfu opened 4 months ago

bmacadam-sfu commented 4 months ago

Hi there,

I caught the following errors in the to_agent method for SceneBatch while trying to visualize some scenes:

  1. The index_neighbors function does not preserve StateArrays/StateTensors. This can be fixed by checking whether a StateTensor was passed in and re-wrapping the result with its format:
        def index_neighbors(x: Tensor | StateTensor) -> Tensor | StateTensor:
            # Keep every agent except the focal one: [batch_size, num_agents - 1, ...].
            neighbors = x[others_mask].reshape([batch_size, num_agents - 1] + list(x.shape[2:]))
            # Re-wrap the result so it keeps x's state format.
            if isinstance(x, StateTensor):
                neighbors = StateTensor.from_array(neighbors, x._format)
            return neighbors
  2. The index_agent function does not handle map names correctly; this can be fixed by wrapping each map_name in a singleton list:
           map_names=index_agent_list([[m] for m in self.map_names]),
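The issue does not spell out exactly how map names go wrong. One general Python pitfall that the singleton-list workaround sidesteps is that a string is itself a sequence, so a helper that iterates over or indexes into each batch entry will split a bare map name into characters. The sketch below uses a hypothetical `flatten_per_batch` helper (not part of trajdata) and illustrative nuScenes-style map names:

```python
def flatten_per_batch(rows):
    """Hypothetical helper: concatenate each batch element's entries into one flat list."""
    return [item for row in rows for item in row]


map_names = ["boston-seaport", "singapore-onenorth"]  # illustrative map names

# A bare string is a sequence, so each name is split into characters.
print(flatten_per_batch(map_names)[:3])  # ['b', 'o', 's']

# Wrapping each name in a singleton list keeps it atomic.
print(flatten_per_batch([[m] for m in map_names]))  # ['boston-seaport', 'singapore-onenorth']
```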
bmacadam-sfu commented 4 months ago

Of course, for the index_neighbors problem you would presumably want to fix the underlying issue with reshape and StateArrays/StateTensors rather than patching around it.
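For reference, the masking-and-reshape step in the index_neighbors workaround can be exercised outside trajdata. This is a minimal sketch assuming NumPy, with a hypothetical `FormattedArray` subclass standing in for StateArray and an illustrative format string; it is not trajdata's implementation. It masks on a plain-array view, reshapes to `[batch_size, num_agents - 1, ...]`, and only then restores the format:

```python
import numpy as np


class FormattedArray(np.ndarray):
    """Hypothetical stand-in for a format-carrying array such as StateArray."""

    _format: str = ""

    @classmethod
    def from_array(cls, array, fmt: str) -> "FormattedArray":
        # View the data as this subclass and attach the format string.
        out = np.asarray(array).view(cls)
        out._format = fmt
        return out


def index_neighbors(x, others_mask, batch_size, num_agents):
    """Drop each batch element's focal agent, keeping x's format if it has one."""
    # Mask and reshape a plain ndarray view so the subclass plays no part in either step.
    neighbors = np.asarray(x)[others_mask].reshape(
        [batch_size, num_agents - 1] + list(x.shape[2:])
    )
    if isinstance(x, FormattedArray):
        # Re-wrap the result with the original format.
        neighbors = FormattedArray.from_array(neighbors, x._format)
    return neighbors


batch_size, num_agents = 2, 3
agent_idx = np.array([0, 2])  # focal agent per batch element
others_mask = np.ones((batch_size, num_agents), dtype=bool)
others_mask[np.arange(batch_size), agent_idx] = False

x = FormattedArray.from_array(
    np.arange(batch_size * num_agents * 4, dtype=float).reshape(batch_size, num_agents, 4),
    "x,y,vx,vy",  # illustrative format string
)
neighbors = index_neighbors(x, others_mask, batch_size, num_agents)
print(type(neighbors).__name__, neighbors.shape, neighbors._format)
# FormattedArray (2, 2, 4) x,y,vx,vy
```

Doing the indexing and reshape on a plain array and re-wrapping once at the end keeps the subclass out of the operations that were causing trouble; a plain array passed in comes back as a plain array.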