graphnet-team / graphnet

A deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0
90 stars, 92 forks

ParquetDataset class missing 'labels' argument #729

Closed: OscarBarreraGithub closed this issue 1 month ago

OscarBarreraGithub commented 3 months ago

In the example script 02_train_tito_model.py, labels are passed to the dataloader:

```python
'labels': {
    "direction": Direction(
        azimuth_key="injection_azimuth",
        zenith_key="injection_zenith",
    )
}
```
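For context, a direction label of this kind combines the two angle columns into a 3D unit vector. The snippet below is a minimal sketch, assuming the common spherical convention x = cos(azimuth)·sin(zenith), y = sin(azimuth)·sin(zenith), z = cos(zenith). `direction_from_angles` is a hypothetical helper for illustration, not GraphNeT's `Direction` implementation, which may differ in details such as tensor shapes.

```python
import math
from typing import Tuple


def direction_from_angles(azimuth: float, zenith: float) -> Tuple[float, float, float]:
    """Hypothetical helper: convert azimuth/zenith (radians) to a unit vector.

    Assumes the common spherical convention; this is not GraphNeT's API.
    """
    sin_zen = math.sin(zenith)
    x = math.cos(azimuth) * sin_zen
    y = math.sin(azimuth) * sin_zen
    z = math.cos(zenith)
    return (x, y, z)


# Illustrative event using the column names from the config above (values are made up).
event = {"injection_azimuth": 0.0, "injection_zenith": math.pi / 2}
print(direction_from_angles(event["injection_azimuth"], event["injection_zenith"]))
```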

SQLiteDataset handles these labels correctly, but loading with ParquetDataset raises:

```
TypeError: ParquetDataset.__init__() got an unexpected keyword argument 'labels'
```

This occurs when using the GraphNetDataModule, but I believe the bug lies in the ParquetDataset class itself: its __init__ does not accept labels as an argument, even though the base Dataset class documents it in its own __init__ ("labels: Dictionary of labels to be added to the dataset").
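The failure mode can be reproduced without GraphNeT. This is a minimal sketch using hypothetical stand-in classes (`BaseDataset`, `BrokenParquetDataset`), not the library's real ones: the base class accepts `labels`, but a subclass that spells out its own keyword arguments and omits `labels` rejects it before the base `__init__` ever runs.

```python
from typing import Any, Dict, Optional


class BaseDataset:
    """Hypothetical stand-in for a base Dataset that supports extra labels."""

    def __init__(self, path: str, labels: Optional[Dict[str, Any]] = None) -> None:
        self.path = path
        self.labels = labels or {}


class BrokenParquetDataset(BaseDataset):
    """Hypothetical subclass whose signature forgets the `labels` argument."""

    def __init__(self, path: str) -> None:
        # `labels` is never accepted here, so it can never reach the base class.
        super().__init__(path=path)


try:
    BrokenParquetDataset(path="data.parquet", labels={"direction": None})
except TypeError as err:
    # Python rejects the unknown keyword before any method body runs.
    print(err)
```

Because Python validates keyword arguments against the subclass signature, the base class's support for `labels` does not help unless the subclass accepts and forwards it.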

RasmusOrsoe commented 3 months ago

@OscarBarreraGithub, thanks for sharing this.

It looks like a simple oversight in the argument list of ParquetDataset: it's missing the entry labels: Optional[Dict[str, Any]] = None, which should also be passed to the super() call in its __init__.

Fixing this would require just two lines of code:

  1. Add the argument to ParquetDataset: labels: Optional[Dict[str, Any]] = None
  2. Propagate the argument to the super() call in ParquetDataset.__init__
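
The two steps above can be sketched with hypothetical stand-in classes rather than GraphNeT's actual ones. `BaseDataset`, `FixedParquetDataset`, and the `path` argument are assumptions for illustration; the point is only that the subclass accepts `labels` and forwards it to `super().__init__()`.

```python
from typing import Any, Dict, Optional


class BaseDataset:
    """Hypothetical stand-in for the base Dataset class."""

    def __init__(self, path: str, labels: Optional[Dict[str, Any]] = None) -> None:
        self.path = path
        # Store labels so they could later be attached to every event.
        self.labels = labels or {}


class FixedParquetDataset(BaseDataset):
    """Hypothetical Parquet subclass with the missing argument restored."""

    def __init__(
        self,
        path: str,
        labels: Optional[Dict[str, Any]] = None,  # step 1: accept the argument
    ) -> None:
        super().__init__(path=path, labels=labels)  # step 2: forward it to the base class


dataset = FixedParquetDataset(path="data.parquet", labels={"direction": "placeholder"})
print(dataset.labels)
```

Defaulting `labels` to `None` keeps existing callers that never pass it working unchanged.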

RasmusOrsoe commented 1 month ago

closed by #730