NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.
Apache License 2.0

KeyError when loading Waymo data #44

Open bshah-tri opened 1 month ago

bshah-tri commented 1 month ago

Hi everyone,

I am encountering an issue while trying to load Waymo Open Motion Dataset scenarios with trajdata. When loading the dataset (10 samples), I run into a KeyError because some agent indices are not found in a pandas DataFrame:

Calculating Agent Data (Serially):   0%|                                              | 0/2989 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "examples/batch_example.py", line 38, in <module>
    main()
  File "examples/batch_example.py", line 12, in main
    dataset = UnifiedDataset(
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/trajdata/dataset.py", line 347, in __init__
    scene_paths: List[Path] = self._preprocess_scene_data(
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/trajdata/dataset.py", line 954, in _preprocess_scene_data
    scene: Scene = agent_utils.get_agent_data(
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/trajdata/utils/agent_utils.py", line 47, in get_agent_data
    agent_list, agent_presence = raw_dataset.get_agent_info(
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/trajdata/dataset_specific/waymo/waymo_dataset.py", line 300, in get_agent_info
    all_agent_data_df.drop(index=agents_to_remove, inplace=True)
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/pandas/util/_decorators.py", line 331, in wrapper
    return func(*args, **kwargs)
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/pandas/core/frame.py", line 5399, in drop
    return super().drop(
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/pandas/util/_decorators.py", line 331, in wrapper
    return func(*args, **kwargs)
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/pandas/core/generic.py", line 4505, in drop
    obj = obj._drop_axis(labels, axis, level=level, errors=errors)
  File "/home/bhavya.shah/Desktop/git/testenv/lib/python3.8/site-packages/pandas/core/generic.py", line 4575, in _drop_axis
    raise KeyError(f"{labels} not found in axis")
KeyError: '[68 81 82 84 85 86 87 88 89 90 91 92 93 94 96 98 101 102 103 104 105 106\n 107 108 110 111 112 113 114 116 118 120 121 123 124 125 128 129 130 131\n 132 134 136 137 138 139 142 2705 2706 2707 143 144] not found in axis'

Origination: The error originates from the following line in src/trajdata/dataset_specific/waymo/waymo_dataset.py (in get_agent_info, according to the traceback above):

all_agent_data_df.drop(index=agents_to_remove, inplace=True)
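Note that the label list in the KeyError above may not be the list of missing agents. As far as I can tell, when the DataFrame index is non-unique (for example, one row per agent per timestep), recent pandas versions print every label passed to drop, not only the absent ones. The sketch below narrows it down. It assumes agent IDs sit in the first level of the index, and missing_labels is a hypothetical helper for illustration, not part of trajdata:

```python
import pandas as pd


def missing_labels(df: pd.DataFrame, labels) -> list:
    """Return the entries of `labels` absent from the first level of df's index.

    Hypothetical debugging helper; assumes agent IDs are index level 0.
    """
    # get_level_values(0) works for both a flat Index and a MultiIndex.
    present = set(df.index.get_level_values(0))
    return [label for label in labels if label not in present]


# Toy frame standing in for all_agent_data_df: agent 5 has two rows
# (non-unique index) and agent 7 has already been filtered out.
df = pd.DataFrame({"x": [0.0, 1.0, 2.0]}, index=pd.Index([5, 5, 9], name="agent_id"))
print(missing_labels(df, [5, 7, 9]))  # [7]
```

Calling missing_labels(all_agent_data_df, agents_to_remove) right before the failing line should show which agent IDs are actually gone.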

Observation: It seems that the agents_to_remove list contains index labels that are not present in the all_agent_data_df DataFrame. The DataFrame is filtered by the code below, which keeps only rows with no NaN values, so any agent whose rows all contain a NaN disappears from the DataFrame entirely:

# This does exactly the same as dropna(...), but we're keeping the mask around
# for later use with agent_ids.
mask = pd.notna(all_agent_data_df).all(axis=1, bool_only=False)
all_agent_data_df = all_agent_data_df.loc[mask]
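This mechanism can be reproduced with plain pandas. The following minimal sketch uses a toy DataFrame: the column names, agent IDs, and the flat, non-unique agent_id index are assumptions, and the real all_agent_data_df may be indexed differently (e.g. with a MultiIndex). It also assumes agents_to_remove still contains an agent whose rows were all dropped by the NaN filter:

```python
import numpy as np
import pandas as pd

# Toy stand-in for all_agent_data_df: one row per agent per timestep,
# and agent 7 has a NaN in every one of its rows.
df = pd.DataFrame(
    {"x": [0.0, 1.0, np.nan, np.nan], "y": [0.0, 1.0, 2.0, 3.0]},
    index=pd.Index([5, 5, 7, 7], name="agent_id"),
)

# Same filtering as in the issue: keep only rows with no NaN values.
mask = pd.notna(df).all(axis=1, bool_only=False)
df = df.loc[mask]

# Agent 7 is stale: it no longer exists in df's index after the filter.
agents_to_remove = [5, 7]

error = None
try:
    df.drop(index=agents_to_remove)
except KeyError as err:
    error = err  # keep a reference; `err` is cleared after the except block
print(repr(error))
```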

Workaround: I was able to work around the issue by replacing the line mentioned under Origination with the code below. I still need to check in more detail whether this is appropriate:

for agent in agents_to_remove:
    try:
        all_agent_data_df.drop(index=[agent], inplace=True)
    except KeyError:
        # agent_id was already removed from all_agent_data_df (e.g. by the NaN filter)
        pass
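The per-agent loop can be collapsed into a single call. This is a sketch, not a confirmed fix for trajdata. It assumes that silently skipping agents that are already gone is the desired behavior. Option 1 uses the errors="ignore" argument of pandas' DataFrame.drop, which drops the labels that exist and ignores the rest. Option 2 passes only labels that are present, so the default errors="raise" still catches anything else unexpected. The toy data below is made up for illustration:

```python
import numpy as np
import pandas as pd

# Toy stand-in for all_agent_data_df after the NaN filter: agent 7 is gone.
all_agent_data_df = pd.DataFrame(
    {"x": [0.0, 1.0, 2.0]},
    index=pd.Index([5, 5, 9], name="agent_id"),
)
agents_to_remove = [5, 7]  # 7 no longer exists in the index

# Option 1: let pandas skip labels that are not present.
dropped_ignore = all_agent_data_df.drop(index=agents_to_remove, errors="ignore")

# Option 2: keep only labels that are actually in the index, then drop them
# with pandas' default errors="raise".
labels = np.asarray(agents_to_remove)
present = np.isin(labels, all_agent_data_df.index)
dropped_filtered = all_agent_data_df.drop(index=labels[present])

print(dropped_ignore.index.tolist())    # [9]
print(dropped_filtered.index.tolist())  # [9]
```

In the original line this would read all_agent_data_df.drop(index=agents_to_remove, inplace=True, errors="ignore"). Whether ignoring the missing agents is correct, rather than also filtering agents_to_remove with the NaN mask upstream, is a question for the maintainers.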

Any guidance or help for the above would be useful. Thank you!

Zhenlin-Xu commented 1 week ago

Hi, I am encountering the same issue. Did you solve the problem?