NVlabs / trajdata

A unified interface to many trajectory forecasting datasets.
Apache License 2.0

Feature: allows custom data in the batch #10

Closed: semitable closed this pull request 1 year ago

semitable commented 1 year ago

As discussed, this PR lets users add custom data to the batch. However, it does not yet support a custom collate function: it currently just uses np.stack, which assumes that the arrays for each extra have identical shapes across all batch elements.
I also added an example file, batch_custom_data.py, under examples/.
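To illustrate the np.stack limitation, here is a minimal sketch. It assumes that each batch element carries its extras as a dict of NumPy arrays. The `collate_extras` helper and the `speed`/`neighbors` keys are hypothetical and do not come from trajdata's actual code:

```python
from typing import Dict, List

import numpy as np


def collate_extras(elements: List[Dict[str, np.ndarray]]) -> Dict[str, np.ndarray]:
    """Hypothetical default collate: stack each extras key along a new batch axis."""
    keys = elements[0].keys()
    # np.stack requires every array for a given key to have exactly the same shape.
    return {key: np.stack([elem[key] for elem in elements]) for key in keys}


# Same-shaped arrays stack cleanly into shape (batch_size, ...).
fixed = [{"speed": np.zeros(3)}, {"speed": np.ones(3)}]
print(collate_extras(fixed)["speed"].shape)  # (2, 3)

# Arrays of different lengths (e.g. a variable number of neighbors) cannot be stacked.
ragged = [{"neighbors": np.zeros((2, 4))}, {"neighbors": np.zeros((5, 4))}]
try:
    collate_extras(ragged)
except ValueError as err:
    print("np.stack failed:", err)
```

Any extra whose size varies per agent or per scene therefore needs its own collate logic, such as padding.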

Adding a custom collate function should be simple enough, but I am worried it might complicate the API. For example, extras could become a Dict[str, Tuple[Callable, Callable]] (where the second callable is the custom collate function), or we could add a separate extras_collate_fns argument (but this might require users to double-check that its dictionary keys match those of extras, and it could create a lot of boilerplate). We can leave it for now, or, if you have any thoughts, we can work on it.
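The first option could look roughly like the sketch below, in which each extra is a (compute, collate) pair. Everything here is hypothetical and for illustration only: `build_extras`, `pad_collate`, the `ExtraFn`/`CollateFn` aliases, and the assumption that the compute callable maps one batch element to a NumPy array are not trajdata's real API.

```python
from typing import Any, Callable, Dict, List, Tuple

import numpy as np

# Hypothetical type aliases; not part of trajdata.
ExtraFn = Callable[[Any], np.ndarray]                 # computes one element's extra
CollateFn = Callable[[List[np.ndarray]], np.ndarray]  # merges per-element arrays into a batch


def pad_collate(arrays: List[np.ndarray], pad_value: float = 0.0) -> np.ndarray:
    """Pad arrays along axis 0 to the longest length, then stack them."""
    max_len = max(arr.shape[0] for arr in arrays)
    padded = []
    for arr in arrays:
        # Pad only the end of the first axis; leave the remaining axes unchanged.
        pad_width = [(0, max_len - arr.shape[0])] + [(0, 0)] * (arr.ndim - 1)
        padded.append(np.pad(arr, pad_width, constant_values=pad_value))
    return np.stack(padded)


def build_extras(
    elements: List[Any], extras: Dict[str, Tuple[ExtraFn, CollateFn]]
) -> Dict[str, np.ndarray]:
    """Option 1: each key owns its compute and collate functions in a single place."""
    batch = {}
    for name, (compute_fn, collate_fn) in extras.items():
        batch[name] = collate_fn([compute_fn(elem) for elem in elements])
    return batch


# Hypothetical batch elements, each holding a variable number of neighbor states.
elements = [np.zeros((2, 4)), np.ones((5, 4))]
extras = {
    # Fixed-size extra: plain np.stack is enough.
    "num_neighbors": (lambda e: np.array(e.shape[0]), np.stack),
    # Variable-size extra: needs a custom padding collate.
    "neighbors": (lambda e: e, pad_collate),
}
batch = build_extras(elements, extras)
print(batch["num_neighbors"])    # [2 5]
print(batch["neighbors"].shape)  # (2, 5, 4)
```

Under the second option, extras would keep its Dict[str, Callable] form, and a separate extras_collate_fns dict would need keys matching those of extras. That is the bookkeeping concern raised above. The tuple form avoids it at the cost of a wordier type for every extra, including the common case where np.stack is sufficient.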

Let me know of any needed changes.

BorisIvanovic commented 1 year ago

Closing this for now, will discuss more internally before merging it here. Thanks!