graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Adopt pytorch lightning DataModule instead of DataLoaders and Dataset #579

Closed: AMHermansen closed this issue 7 months ago.

AMHermansen commented 1 year ago

It would be beneficial to move away from using native PyTorch Dataset/DataLoader and use a PyTorch Lightning DataModule instead (see here), since this would allow GraphNeT to interact better with the rest of the Lightning framework. In particular, we would get access to the Lightning profiling tool (see here), which could shed some light on where it might make sense to optimize.

Describe the solution you'd like: Refactor the Dataset- and DataLoader-related classes into DataModules.

Additional context: Doing this would seem to solve #275.

What are your thoughts, @RasmusOrsoe?

RasmusOrsoe commented 1 year ago

I'm OK with such a refactor if

  1. We get new functionality from it, e.g. profiling (but is it worth the trouble?).
  2. We do not lose the ability to loop over batches from a dataloader in native PyTorch, i.e.
    for batch in dataloader:
       pred = model(batch)
AMHermansen commented 1 year ago

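It should be straightforward to get whichever DataLoader you wish (train/validation/test/prediction) using the DataModule's methods, such as train_dataloader(). You would simply do:

dm = GraphNeTDataModule(...)  # instantiated with your arguments
dm.setup("fit")  # builds the datasets; the Trainer normally calls this for you
train_dataloader = dm.train_dataloader()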
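
To illustrate that the native loop survives, here is a minimal sketch. `ToyDataModule` is a hypothetical stand-in that only mimics the `setup()`/`train_dataloader()` hooks (it does not subclass `LightningDataModule`), and `model` is a placeholder function; the point is that `train_dataloader()` hands back an ordinary `torch.utils.data.DataLoader`.

```python
from typing import List, Optional

import torch
from torch.utils.data import DataLoader


class ToyDataModule:
    """Hypothetical stand-in mimicking the LightningDataModule hooks."""

    def __init__(self, n_events: int, batch_size: int) -> None:
        self._n_events = n_events
        self._batch_size = batch_size
        self._train_dataset: Optional[List[float]] = None

    def setup(self, stage: str) -> None:
        # A plain list is a valid map-style dataset (__len__ + __getitem__).
        self._train_dataset = [float(i) for i in range(self._n_events)]

    def train_dataloader(self) -> DataLoader:
        return DataLoader(self._train_dataset, batch_size=self._batch_size)


def model(batch: torch.Tensor) -> torch.Tensor:
    # Placeholder "model": doubles every input.
    return 2 * batch


dm = ToyDataModule(n_events=10, batch_size=4)
dm.setup("fit")
predictions = []
for batch in dm.train_dataloader():  # the native PyTorch loop still works
    pred = model(batch)
    predictions.append(pred)

n_batches = len(predictions)
total = float(torch.cat(predictions).sum())
```

With 10 events and `batch_size=4`, the loop sees batches of 4, 4 and 2 events.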

samadpls commented 1 year ago

Hey @AMHermansen, can I work on this issue?

AMHermansen commented 1 year ago

Hello @samadpls

You are more than welcome to work on this issue. Right now, however, it isn't clear to me what actually needs to be changed to resolve this issue, so I think a good first step is to outline exactly what needs to be changed. Some initial ideas:

  1. Implement a DataModule class. Does this class require its own config, or is the config managed by the DataModule's Datasets?
  2. We need to figure out how the DataModule class should interact with the StandardModel class. In particular, we currently only implement training_step/validation_step out of the four steps train/validation/test/predict; should we also implement test_step/predict_step? Most of our prediction logic seems to be targeted at testing.

Maybe for a start we should be satisfied with implementing only the first step, and then we can consider how much of StandardModel (if anything at all) needs to be refactored to achieve harmonious interaction between the DataModule and the StandardModel.

@RasmusOrsoe @samadpls What are your thoughts?

RasmusOrsoe commented 1 year ago

@samadpls I think it's wonderful that you're considering taking this on, but I'd also like to point out that this might not be the easiest issue to work on. As @AMHermansen points out, data loading is a central part of the library and of course has to play well with existing functionality.

I've taken a closer look at the profiling page, and from what I gather purely from the docs, we should already be able to profile at least 2/3 of the library using PyTorch Lightning. What might be missing is detailed profiling of the Dataset and GraphDefinition; I suspect they might show up as a combined time in the advanced profiling mode.

I think we should check what level of detail we get in profiling today if we pass profiler="advanced" and profiler="simple" to Model.fit in GraphNeT, which should pass those arguments directly to the Trainer. Specifically, I think we should address:

  1. Does that work at all?
  2. What level of detail do we get out of the profiling and is that sufficient?

If the answer to both these questions is yes, then I think we can close this issue. @AMHermansen Have you tried this?

samadpls commented 1 year ago

Hello @RasmusOrsoe , @AMHermansen Based on the discussion so far, it seems like initializing the DataModule class is a significant step in the right direction. 😋

I also want to apologize for not mentioning this earlier, but I have an upcoming vacation starting this Monday, and I won't be available for the next 20 days. During this period, I might not be able to work actively due to my traveling commitments. However, I'm committed to making the most of the time I have before my trip and, once I return, I'd be excited to continue contributing to this project.

AMHermansen commented 1 year ago

@RasmusOrsoe It is currently possible to pass various profiler arguments to Model.fit; however, the profiling only contains information about the StandardModel and not about the Dataset. I can send you the log if you want to have a look yourself. The advanced option was rather verbose, so I might have missed it (but Ctrl+F only found one match for "Dataset").

So, to answer your questions:

  1. Yes, we can profile, but only the StandardModel and Callbacks.
  2. We only get profiling information once the data enters the StandardModel, so we're not able to figure out whether preprocessing or I/O is a limiting factor.

Depending on how big we expect the PR that solves this issue to be, it might make sense to create a separate branch so that multiple people can more easily add code.

RasmusOrsoe commented 1 year ago

@AMHermansen That's great news. Please post the logs here for reference, for both simple and advanced profiling. I think the time it takes to grab a batch should be in there somewhere; I can probably identify it.

AMHermansen commented 1 year ago

Advanced logs: perf_logs_adv.txt
Simple logs: perf_logs_simple.txt

RasmusOrsoe commented 1 year ago

@AMHermansen Thank you very much for these logs. Below is a snapshot of the simple log:

[Screenshot: excerpt of the simple profiling log]

Values seem to be shown cumulatively. I believe [_TrainingEpochLoop].train_dataloader_next profiles the time spent grabbing a batch, which would include time spent in Dataset and GraphDefinition. From this value alone we can indeed see whether this part of the code is a bottleneck. It does not, however, provide details on those two classes.
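
A rough, framework-free way to get the same kind of number as `train_dataloader_next` is to time `next()` on the batch iterator separately from the model call. This is a minimal sketch: `time_loop`, `slow_loader` and `fake_model` are hypothetical placeholders, with `time.sleep` simulating expensive I/O or preprocessing.

```python
import time
from typing import Callable, Dict, Iterable, Iterator


def time_loop(loader: Iterable, model: Callable) -> Dict[str, float]:
    """Accumulate time spent fetching batches vs. running the model."""
    fetch_time = 0.0
    model_time = 0.0
    iterator: Iterator = iter(loader)
    while True:
        start = time.perf_counter()
        try:
            batch = next(iterator)  # data loading (cf. train_dataloader_next)
        except StopIteration:
            break
        fetch_time += time.perf_counter() - start

        start = time.perf_counter()
        model(batch)  # compute
        model_time += time.perf_counter() - start
    return {"fetch": fetch_time, "model": model_time}


def slow_loader(n_batches: int, delay: float) -> Iterator[int]:
    # Hypothetical loader whose batches are expensive to produce.
    for i in range(n_batches):
        time.sleep(delay)
        yield i


def fake_model(batch: int) -> int:
    # Hypothetical cheap model.
    return batch * 2


timings = time_loop(slow_loader(n_batches=5, delay=0.01), fake_model)
```

If `timings["fetch"]` dominates, data loading is the bottleneck, but, as noted above, this says nothing about which function inside the Dataset or GraphDefinition is slow.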

I personally think this is sufficient; if one wants detailed profiling of Dataset or GraphDefinition, one could use cProfile and PyTorch's own profiler in sessions that are independent of the training itself, which should be simpler anyway.
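
A minimal sketch of such a standalone session with the standard-library `cProfile`/`pstats` modules (`torch.profiler` would be the PyTorch-native alternative). `ToyDataset`, with its `_query` (I/O stand-in) and `_build_graph` (GraphDefinition stand-in) steps, is hypothetical; the point is that the report breaks the time down per function rather than per batch.

```python
import cProfile
import io
import pstats
from typing import List


class ToyDataset:
    """Hypothetical dataset with separate 'query' and 'graph building' steps."""

    def __init__(self, n_events: int) -> None:
        self._n_events = n_events

    def __len__(self) -> int:
        return self._n_events

    def _query(self, index: int) -> List[float]:
        # Stand-in for the I/O part (e.g. an SQLite query).
        return [float(index + k) for k in range(200)]

    def _build_graph(self, pulses: List[float]) -> List[float]:
        # Stand-in for GraphDefinition-style preprocessing.
        return sorted(p * p for p in pulses)

    def __getitem__(self, index: int) -> List[float]:
        return self._build_graph(self._query(index))


def load_all(dataset: ToyDataset) -> int:
    return sum(len(dataset[i]) for i in range(len(dataset)))


dataset = ToyDataset(n_events=500)
profiler = cProfile.Profile()
profiler.enable()
n_items = load_all(dataset)
profiler.disable()

stream = io.StringIO()
stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
stats.print_stats(15)  # top 15 entries by cumulative time
report = stream.getvalue()
```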

Is there some additional functionality we could gain from this suggested refactor that would merit such a seemingly large change? If not, perhaps we should instead add a section on code profiling to getting_started.md with a few snippets and explanations. What do you think, @AMHermansen?

AMHermansen commented 1 year ago

I didn't spot those values on my brief look through the profiling files.

While we might be able to see whether it is a bottleneck, if we ever arrive at a situation where it is, it will be difficult to figure out which functions need to be optimized, and that is a big part of the reason for having a profiler.

I'm not sure we need to make a huge refactor, it might be sufficient to make a very lightweight DataModule wrapper around our already existing Dataset/DataLoader classes. Such a DataModule might be sufficient for the profiling tool to "look into" the datasets.

The only other functionality I'm aware of is the LightningCLI, which could remove quite a bit of the argparse boilerplate in our training scripts. I nonetheless think it would be a good idea to add a section about profiling to the README.

RasmusOrsoe commented 10 months ago

@AMHermansen To follow up on our conversation yesterday, I've included some very specific pseudo-code that details how I think we could use the DataModule from Lightning. While I initially thought this was an unnecessary addition to the library, I've found that this module can simplify and formalize the many smaller utility functions we currently have in graphnet.training.utils, so thank you for that!

I think the most elegant solution would be a GraphNeTDataModule that is independent of the data backend of the Dataset class. This class would be in charge of "realizing" the choices made by the user, so its main functionality would include:

  1. Instantiating datasets
  2. Splitting/handling selections
  3. Creation of DataLoaders.
  4. Basic sanity checks

I converged on a syntax like so:

from graphnet... import GraphNeTDataModule
from graphnet... import SQLiteDataset

dataset_reference = SQLiteDataset  # should not be instantiated, just referenced
dataset_args = {'path': 'mydatabase.db',
                'pulsemap': 'my_pulsemap',
                'truth': [...],
                # ... further Dataset arguments
                'graph_definition': ...,  # a GraphDefinition instance
                }
dataloader_args = {'batch_size': 32,  # should not contain `dataset = ..`
                   'num_workers': 10,
                   # ... further DataLoader arguments
                   'collate_fn': my_func}

dm = GraphNeTDataModule(dataset_reference=dataset_reference,
                        selection=[0, 10, 100, 21, 23, 60],  # events used for train/val
                        test_selection=[90, 54, 20],  # events used for testing
                        dataset_args=dataset_args,
                        dataloader_args=dataloader_args)

Using GraphNeTDataModule on Parquet files should be as easy as changing the dataset reference to ParquetDataset and pointing dataset_args['path'] to the corresponding files.

For use cases where we want to mix multiple datasets in an EnsembleDataset, the syntax should be just as intuitive:

# Multiple data sources (EnsembleDataset)

dataset_args = {'path': ['database-1.db', 'database-2.db', ..., 'database-n.db'],
                'pulsemap': 'my_pulsemap',
                'truth': [...],
                # ... further Dataset arguments
                'graph_definition': ...,
                }
dm = GraphNeTDataModule(dataset_reference=dataset_reference,
                        selection=[[0, 10, 100, 21, 23, 60], ..., [68, 213, 51, 4, 12, 5]],  # multiple selections passed
                        test_selection=[[90, 54, 20], [2131, 4, 1, 23, 1, 1]],  # multiple selections passed
                        dataset_args=dataset_args,
                        dataloader_args=dataloader_args)

From those considerations, I arrived at the following pseudo-code for GraphNeTDataModule:

from typing import Dict, Any, Optional, List, Tuple, Type, Union
import lightning as L
from torch.utils.data import DataLoader
from copy import deepcopy
from sklearn.model_selection import train_test_split
import pandas as pd

from graphnet.data.dataset import Dataset, EnsembleDataset, SQLiteDataset, ParquetDataset
from graphnet.utilities.logging import Logger

class GraphNeTDataModule(L.LightningDataModule, Logger):
    """General class for DataLoader construction."""

    def __init__(self,
                 dataset_reference: Type[Dataset],  # e.g. SQLiteDataset or ParquetDataset
                 selection: Optional[Union[List[int], List[List[int]]]],
                 test_selection: Optional[Union[List[int], List[List[int]]]],
                 dataloader_args: Dict[str, Any],
                 dataset_args: Dict[str, Any],
                 train_val_split: Tuple[float, float] = (0.9, 0.1),
                 split_seed: int = 42) -> None:
        """Create dataloaders from dataset.

        Args:
            dataset_reference: A non-instantiated reference to the dataset class.
            selection: (Optional) a list of event IDs used for training and validation.
            test_selection: (Optional) a list of event IDs used for testing.
            dataloader_args: Arguments for torch.utils.data.DataLoader.
                Should not contain `dataset`, as this is set by this class.
                See https://pytorch.org/docs/stable/data.html#torch.utils.data.DataLoader
            dataset_args: Arguments to instantiate graphnet.data.dataset.Dataset with.
            train_val_split: Fractions of `selection` used for training and validation.
            split_seed: Seed used for shuffling and splitting selections into train/validation.
        """
        # Initialise the LightningDataModule base class before setting attributes.
        super().__init__()
        self._dataset = dataset_reference
        self._selection = selection
        self._train_val_split = train_val_split
        self._test_selection = test_selection
        self._dataloader_args = dataloader_args
        self._dataset_args = dataset_args
        self._rng = split_seed

        # If multiple dataset paths are given, we should use EnsembleDataset
        self._use_ensemble_dataset = isinstance(self._dataset_args['path'], list)

    def prepare_data(self) -> None:
        ###  Download method for curated datasets. Method for download is
        # likely dataset-specific, so we can leave it as-is
        pass

    def setup(self, stage: str) -> None:
        """Prepare Datasets for DataLoaders.

        Args:
            stage: Lightning stage; one of "fit", "validate", "test" or "predict".
        """
        # Sanity checks
        self._validate_dataset_class()
        self._validate_dataset_args()
        self._validate_dataloader_args()

        # Case-handling of selection arguments
        self._resolve_selections()

        # Creation of Datasets
        self._train_dataset = self._create_dataset(self._train_selection)
        self._val_dataset = self._create_dataset(self._val_selection)
        # Only build a test dataset if a test selection was given; otherwise
        # a single Dataset would silently default to all events.
        self._test_dataset = (self._create_dataset(self._test_selection)
                              if self._test_selection is not None
                              else None)

    def train_dataloader(self) -> DataLoader:
        return self._create_dataloader(self._train_dataset)

    def val_dataloader(self) -> DataLoader:
        return self._create_dataloader(self._val_dataset)

    def test_dataloader(self) -> DataLoader:
        if self._test_dataset is None:
            raise RuntimeError("No `test_selection` was given, so no test "
                               "dataloader is available.")
        return self._create_dataloader(self._test_dataset)

    def teardown(self, stage: str) -> None:
        # Lightning passes the current stage. Could be used to shut down
        # SQLite connections after training.
        pass

    def _create_dataloader(self, dataset: Union[Dataset, EnsembleDataset]) -> DataLoader:
        return DataLoader(dataset=dataset, **self._dataloader_args)

    def _validate_dataset_class(self) -> None:
        """Sanity checks on the dataset reference (self._dataset).

        Is it a GraphNeT-compatible dataset?
        Has the class already been instantiated?
        Did the user try to pass EnsembleDataset?
        """
        return

    def _validate_dataset_args(self) -> None:
        """Sanity checks on the arguments for the dataset reference."""
        if isinstance(self._dataset_args['path'], list):
            n_paths = len(self._dataset_args['path'])
            if self._selection is not None and len(self._selection) != n_paths:
                # The number of dataset paths must equal the number of selections.
                message = (f"The number of dataset paths ({n_paths}) does not "
                           f"match the number of selections ({len(self._selection)}).")
                self.error(message)
                raise ValueError(message)

            if self._test_selection is not None and len(self._test_selection) != n_paths:
                # The number of dataset paths must equal the number of test selections.
                message = (f"The number of dataset paths ({n_paths}) does not "
                           f"match the number of test selections "
                           f"({len(self._test_selection)}). If you'd like to test "
                           f"on only a subset of the {n_paths} datasets, please "
                           "provide empty test selections for the others.")
                self.error(message)
                raise ValueError(message)

    def _validate_dataloader_args(self) -> None:
        """Sanity check on `dataloader_args`."""
        if 'dataset' in self._dataloader_args:
            message = "`dataloader_args` must not contain `dataset`."
            self.error(message)
            raise ValueError(message)

    def _resolve_selections(self) -> None:
        if self._test_selection is None:
            self.warning_once(f"{self.__class__.__name__} did not receive an "
                              "argument for `test_selection` and will therefore "
                              "not have a test dataloader available.")
        if self._selection is not None:
            # Split the selection into train/validation
            if self._use_ensemble_dataset:
                # Split every selection
                self._train_selection = []
                self._val_selection = []
                for selection in self._selection:
                    train_selection, val_selection = self._split_selection(selection=selection)
                    self._train_selection.append(train_selection)
                    self._val_selection.append(val_selection)
            else:
                # Split the only selection we got
                self._train_selection, self._val_selection = self._split_selection(
                    selection=self._selection)
        else:
            # If not provided, we infer it by grabbing all event IDs in the dataset.
            self.info(f"{self.__class__.__name__} did not receive an argument "
                      "for `selection`. Selection will automatically be created "
                      f"with a split of train: {self._train_val_split[0]} and "
                      f"validation: {self._train_val_split[1]}.")
            self._train_selection, self._val_selection = self._infer_selections()

    def _split_selection(self, selection: List[int]) -> Tuple[List[int], List[int]]:
        """Split train selection into train/validation.

        Args:
            selection: Training selection to be split.

        Returns:
            training selection, validation selection.
        """
        train_selection, val_selection = train_test_split(selection, 
                                                          train_size = self._train_val_split[0], 
                                                          test_size=self._train_val_split[1], 
                                                          random_state=self._rng)
        return train_selection, val_selection

    def _infer_selections(self) -> Tuple[Union[List[int], List[List[int]]],
                                         Union[List[int], List[List[int]]]]:
        """Automatically infer training and validation selections.

        Returns:
            Training selection, validation selection
        """
        if self._use_ensemble_dataset:
            # We must iterate through the dataset paths and infer a train/val
            # selection for each.
            train_selections, val_selections = [], []
            for dataset_path in self._dataset_args['path']:
                train_selection, val_selection = self._infer_selections_on_single_dataset(dataset_path)
                train_selections.append(train_selection)
                val_selections.append(val_selection)
            return train_selections, val_selections
        # Infer selection on a single dataset
        return self._infer_selections_on_single_dataset(self._dataset_args['path'])

    def _infer_selections_on_single_dataset(self,
                                            dataset_path: str) -> Tuple[List[int], List[int]]:
        """Automatically infer training and validation selection on a single Dataset."""
        tmp_args = deepcopy(self._dataset_args)
        tmp_args['path'] = dataset_path
        tmp_dataset = self._dataset(**tmp_args)

        all_events = tmp_dataset._get_all_indices()  # unshuffled list

        # Shuffle with a fixed seed. A Series (not a DataFrame) keeps
        # `.tolist()` flat, i.e. a list of ints rather than a list of rows.
        all_events = pd.Series(all_events).sample(frac=1,
                                                  replace=False,
                                                  random_state=self._rng).tolist()
        return self._split_selection(selection=all_events)

    def _create_dataset(self, 
                        selection : Union[List[int], List[List[int]]]) -> Union[EnsembleDataset, Dataset]:
        """Instantiate `dataset_reference`.

        Args:
            selection: The selected event id's. 

        Returns:
            A dataset, either an instance of `EnsembleDataset` or `Dataset`.
        """
        if self._use_ensemble_dataset:
            # Construct multiple datasets and pass to EnsembleDataset
            # At this point, we have checked that len(selection) == len(dataset_args['path'])
            datasets = []
            for dataset_idx in range(len(selection)):
                datasets.append(self._create_single_dataset(selection=selection[dataset_idx],
                                                            path=self._dataset_args['path'][dataset_idx]))

            dataset = EnsembleDataset(datasets)

        else:
            # Construct single dataset
            dataset = self._create_single_dataset(selection = selection,
                                                  path = self._dataset_args['path'])
        return dataset

    def _create_single_dataset(self,
                               selection : List[int],
                               path: str) -> Dataset:
        """Instantiate a single `Dataset`.

        Args:
            selection: A selection for a single dataset.
            path: path to a single dataset

        Returns:
            An instance of `Dataset`.
        """
        tmp_args = deepcopy(self._dataset_args)
        tmp_args['path'] = path
        tmp_args['selection'] = selection
        return self._dataset(**tmp_args)

One could consider moving the selection and test_selection arguments into dataset_args to reduce the number of arguments to GraphNeTDataModule. One could also consider allowing these selection arguments to be paths to .csv files that contain the selection.
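
A minimal sketch of how a selection argument could accept either an in-memory list or a path to a .csv file. The helper name `resolve_selection` and the file format (a header row, then one integer event ID per row in the first column) are assumptions for illustration, not an agreed-upon GraphNeT format.

```python
import csv
import os
import tempfile
from typing import List, Union


def resolve_selection(selection: Union[List[int], str]) -> List[int]:
    """Return a list of event IDs from a list or a path to a .csv file.

    Assumed (hypothetical) file format: a header row followed by one
    integer event ID per row in the first column.
    """
    if isinstance(selection, str):
        with open(selection, newline="") as f:
            reader = csv.reader(f)
            next(reader)  # skip the header row
            return [int(row[0]) for row in reader if row]
    return list(selection)


# Usage: write a small selection file, then resolve both forms.
with tempfile.TemporaryDirectory() as tmp_dir:
    path = os.path.join(tmp_dir, "selection.csv")
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["event_no"])
        writer.writerows([[0], [10], [100]])
    from_file = resolve_selection(path)

from_list = resolve_selection([0, 10, 100])
```

Inside the DataModule, such a helper would run before the selections are split, so everything downstream only ever sees lists.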

@AMHermansen is this along the lines of what you were thinking of? Would it be compatible with the CLI and profiling?

AMHermansen commented 10 months ago

Hello @RasmusOrsoe, I've now had some time to look at the proposed layout, and overall I agree with the many points you're making. I also like the suggested functionality of the DataModule.

I have some minor suggestions.

I think it is necessary to have the option of giving different dataloader_kwargs for train/val/test/predict, since shuffle is often enabled for training but disabled for the remaining options. I also believe it adds minimal extra code, and we can make it so that if only one set of dataloader_kwargs is passed, those kwargs are passed to all dataloaders.

It would be simpler to make it work with the LightningCLI if the DataModule could read selections directly from a .csv file. This might naturally introduce some limitations on the exact format of the .csv (how many columns, whether a header is included, the delimiter, etc.). To make it easier to get the format right, it might be an idea to include a staticmethod on the DataModule, selection_to_csv(selection, file_name) -> None (we need to decide whether selection should be a pd.Series or an np.ndarray), which saves the selection in the correct format.

I also think it might make sense to make _create_dataset public, since there might be use cases where an end user would want to get the dataset and use it outside a dataloader (perhaps for query benchmarking). (This would mostly be beneficial if we could somehow avoid having a dataset_reference argument and automatically infer the Dataset class. This might not be very scalable, however, since there might be a future where we have multiple datasets that rely on an SQLite backend, for example.)

RasmusOrsoe commented 10 months ago

I think it is necessary to have the option of giving different dataloader_kwargs for train/val/test/predict, since shuffle is often enabled for training but disabled for the remaining options. I also believe it adds minimal extra code, and we can make it so that if only one set of dataloader_kwargs is passed, those kwargs are passed to all dataloaders.

I think this is a good point. Besides the shuffle arg, I cannot, off the top of my head, recall other important arguments that are not shared between the dataloaders. I would much rather add an argument like shuffle_test_dataloader: Optional[bool] = False than restructure dataloader_kwargs to accept multiple dictionaries of arguments.

It would be simpler to make it work with the LightningCLI if the DataModule could read selections directly from a .csv file. This might naturally introduce some limitations on the exact format of the .csv (how many columns, whether a header is included, the delimiter, etc.). To make it easier to get the format right, it might be an idea to include a staticmethod on the DataModule, selection_to_csv(selection, file_name) -> None (we need to decide whether selection should be a pd.Series or an np.ndarray), which saves the selection in the correct format.

I think that's a good idea. But the usage of this function would be limited to saving selections that the module either received or inferred after being instantiated.

I also think it might make sense to make _create_dataset public, since there might be use cases where an end user would want to get the dataset and use it outside a dataloader (perhaps for query benchmarking). (This would mostly be beneficial if we could somehow avoid having a dataset_reference argument and automatically infer the Dataset class. This might not be very scalable, however, since there might be a future where we have multiple datasets that rely on an SQLite backend, for example.)

I'm not a big fan of exposing that method - the main point of the class is to create the dataloaders for you given a set of arguments. So intended usage (from what I'm proposing) is not to import it to create datasets. However, we could add properties to the class such that one could access the datasets that it created the dataloaders from. Something like:

from graphnet.data import GraphNeTDataModule

dm = GraphNeTDataModule(...)
train_dataset = dm.train_dataset
valid_dataset = dm.validation_dataset
test_dataset = dm.test_dataset
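
A minimal sketch of the read-only properties, shown on a hypothetical `ToyDataModule` rather than the real class, with plain lists standing in for datasets. Properties let users inspect what the module built without turning the class into a dataset factory.

```python
from typing import List, Optional


class ToyDataModule:
    """Hypothetical module exposing the datasets it built as read-only properties."""

    def __init__(self, events: List[int]) -> None:
        self._events = events
        self._train_dataset: Optional[List[int]] = None
        self._val_dataset: Optional[List[int]] = None

    def setup(self, stage: str) -> None:
        # Stand-in for the real train/validation split.
        split = int(0.8 * len(self._events))
        self._train_dataset = self._events[:split]
        self._val_dataset = self._events[split:]

    @property
    def train_dataset(self) -> Optional[List[int]]:
        return self._train_dataset

    @property
    def validation_dataset(self) -> Optional[List[int]]:
        return self._val_dataset


dm = ToyDataModule(events=list(range(10)))
dm.setup("fit")
train_dataset = dm.train_dataset
valid_dataset = dm.validation_dataset
```

Because no setters are defined, assigning to `dm.train_dataset` raises an `AttributeError`, which keeps the module the single owner of its datasets.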

I think it's perfectly reasonable to require the users to instantiate their datasets on their own if they need them outside this context.

@AMHermansen How does this sound?

AMHermansen commented 10 months ago

I think this is a good point. Besides the shuffle arg, I cannot, off the top of my head, recall other important arguments that are not shared between the dataloaders. I would much rather add an argument like shuffle_test_dataloader: Optional[bool] = False than restructure dataloader_kwargs to accept multiple dictionaries of arguments.

I have previously been looking at other ways of doing "sequence length bucketing", and one way, other than what is currently implemented in GraphNeT, is to create a custom Sampler class that is responsible for grouping events of similar length together. Such a sampler would differ between train/val etc., since it is responsible for returning the correct indices for the Dataset object. It is also sometimes possible to run with a slightly larger batch size for validation/test, since you're not storing the computation graph in VRAM. I think it would take minimal work to change the input type from Dict to List[Dict]. You could do something like

def __init__(self, ..., dataloader_kwargs: Union[Dict, List[Dict]]):
    ...
    if isinstance(dataloader_kwargs, dict):
        # Copy, so the train/val/test dicts can be modified independently.
        # The hardcoded 3 could be changed to the desired number (maybe 2, 3, or 4).
        dataloader_kwargs = [dict(dataloader_kwargs) for _ in range(3)]
    # From now on dataloader_kwargs is a list containing dicts for train/val/test
    ...
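
A minimal sketch of the sequence-length bucketing idea as a batch sampler. It relies on `torch.utils.data.DataLoader` accepting any iterable of index lists as `batch_sampler`; to stay dependency-free it does not subclass `torch.utils.data.Sampler`. `LengthBucketBatchSampler` and its `lengths` input (e.g. number of pulses per event) are hypothetical. Note how the validation instance uses `shuffle=False` and a larger `batch_size`, which is exactly why per-loader settings are useful.

```python
import random
from typing import Iterator, List, Sequence


class LengthBucketBatchSampler:
    """Yield batches of dataset indices whose events have similar lengths."""

    def __init__(self, lengths: Sequence[int], batch_size: int,
                 shuffle: bool = True, seed: int = 42) -> None:
        self._lengths = lengths
        self._batch_size = batch_size
        self._shuffle = shuffle
        self._rng = random.Random(seed)

    def _batches(self) -> List[List[int]]:
        # Sort indices by event length, then cut into consecutive batches,
        # so each batch contains events of similar length (less padding).
        order = sorted(range(len(self._lengths)), key=lambda i: self._lengths[i])
        return [order[i:i + self._batch_size]
                for i in range(0, len(order), self._batch_size)]

    def __iter__(self) -> Iterator[List[int]]:
        batches = self._batches()
        if self._shuffle:
            # Shuffle the order of batches, not their contents.
            self._rng.shuffle(batches)
        return iter(batches)

    def __len__(self) -> int:
        return (len(self._lengths) + self._batch_size - 1) // self._batch_size


lengths = [5, 100, 7, 98, 6, 101]  # e.g. number of pulses per event
train_sampler = LengthBucketBatchSampler(lengths, batch_size=3, shuffle=True)
val_sampler = LengthBucketBatchSampler(lengths, batch_size=6, shuffle=False)
train_batches = list(train_sampler)
val_batches = list(val_sampler)
```

It would be used as `DataLoader(dataset, batch_sampler=train_sampler)`; PyTorch treats `batch_sampler` as mutually exclusive with `batch_size`, `shuffle`, `sampler` and `drop_last`, so those keys would have to be absent from that loader's kwargs.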

I think that's a good idea. But the usage of this function would be limited to saving selections that the module either received or inferred after being instantiated.

I thought it could be a staticmethod, so that you could look at the data beforehand and make a bunch of selections (which might take some computation time), and once you found the event_nos you like, it could be used like:

good_event_numbers = make_expensive_cuts(data)
GraphNeTDataModule.selection_to_csv(good_event_numbers, "/path/to/selection.csv")

And then later when you want to instantiate the DataModule you could do:

dm = GraphNeTDataModule(..., selection="/path/to/selection.csv")

The point of this is that it makes implementing a CLI significantly easier, since the CLI roughly does the following by default:

dm_kwargs = get_dm_kwargs_from_config(config)
lightning_model_kwargs = get_model_kwargs_from_config(config)
trainer_kwargs = get_trainer_kwargs_from_config(config)

dm = SomeDataModule(**dm_kwargs)
model = SomeLightningModel(**lightning_model_kwargs)
trainer = Trainer(**trainer_kwargs)

trainer.fit(model, dm)

So by far the easiest option would be if the selection could somehow be stored in a format that is YAML-config friendly.

I'm not a big fan of exposing that method - the main point of the class is to create the dataloaders for you given a set of arguments. So intended usage (from what I'm proposing) is not to import it to create datasets. However, we could add properties to the class such that one could access the datasets that it created the dataloaders from. Something like:

I have also thought more about this, and I agree with you. My original idea was to make the DataModule a Dataset factory, but this is only really beneficial if we can automatically infer the Dataset class from the path. And even if we could do that, there are better ways of implementing a factory pattern in Python.

Let me know what you think.

RasmusOrsoe commented 10 months ago

I have previously been looking at other ways of doing "sequence length bucketing", and one way, other than what is currently implemented in GraphNeT, is to create a custom Sampler class that is responsible for grouping events of similar length together. Such a sampler would differ between train/val etc., since it is responsible for returning the correct indices for the Dataset object. It is also sometimes possible to run with a slightly larger batch size for validation/test, since you're not storing the computation graph in VRAM. I think it would take minimal work to change the input type from Dict to List[Dict]. You could do something like

def __init__(self, ..., dataloader_kwargs: Union[Dict, List[Dict]]):
    ...
    if isinstance(dataloader_kwargs, dict):
        # Copy, so the train/val/test dicts can be modified independently.
        # The hardcoded 3 could be changed to the desired number (maybe 2, 3, or 4).
        dataloader_kwargs = [dict(dataloader_kwargs) for _ in range(3)]
    # From now on dataloader_kwargs is a list containing dicts for train/val/test
    ...

I think that's a valid point, so we should allow complete control of each dataloader. However, I'm not a fan of letting the user provide a list of dictionaries, because their order is opaque and we won't be able to write good checks to make sure the right dict goes to the right dataloader. Instead, we could allow a dict of dicts, e.g. kwargs['val_loader'] = {'batch_size': 1, ...}, but I don't find that to be a particularly clean solution either.

I think the least messy and most intuitive way would be to introduce three separate args: train_dataloader_kwargs: Dict[str, Any], validation_dataloader_kwargs: Optional[Dict[str, Any]] and test_dataloader_kwargs: Optional[Dict[str, Any]]. If only train_dataloader_kwargs is given, we will use these settings for all dataloaders but force shuffle to be False.
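
A minimal sketch of the fallback rule proposed here: validation and test reuse the training kwargs unless given explicitly, with `shuffle` forced to `False`. The helper name `resolve_dataloader_kwargs` is hypothetical.

```python
from typing import Any, Dict, Optional, Tuple

Kwargs = Dict[str, Any]


def resolve_dataloader_kwargs(
    train_dataloader_kwargs: Kwargs,
    validation_dataloader_kwargs: Optional[Kwargs] = None,
    test_dataloader_kwargs: Optional[Kwargs] = None,
) -> Tuple[Kwargs, Kwargs, Kwargs]:
    """Return kwargs for the train, validation and test dataloaders."""

    def fallback() -> Kwargs:
        # Reuse the training settings, but never shuffle outside training.
        return {**train_dataloader_kwargs, "shuffle": False}

    val = (validation_dataloader_kwargs
           if validation_dataloader_kwargs is not None else fallback())
    test = (test_dataloader_kwargs
            if test_dataloader_kwargs is not None else fallback())
    return dict(train_dataloader_kwargs), val, test


train_kw, val_kw, test_kw = resolve_dataloader_kwargs(
    {"batch_size": 32, "num_workers": 10, "shuffle": True},
    validation_dataloader_kwargs={"batch_size": 64},
)
```

Here the explicitly given validation kwargs are used as-is, while the test loader inherits the training settings without shuffling.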

I thought it could be a staticmethod, so that you could look at the data beforehand and make a bunch of selections (which might take some computation time), and once you found the event_nos you like, it could be used like:

good_event_numbers = make_expensive_cuts(data)
GraphNeTDataModule.selection_to_csv(good_event_numbers, "/path/to/selection.csv")

I would find it more intuitive if that function was a separate thing and not bound to this module. Like:

from graphnet.training.utils import save_selection
good_event_numbers = make_expensive_cuts(data)
save_selection(good_event_numbers, "/path/to/selection.csv")

We could then use this function in GraphNeTDataModule.
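For illustration, a standalone save_selection along the lines proposed above could be as small as this. graphnet.training.utils.save_selection is only a proposal in this thread, so the body below is a hypothetical sketch, and the CSV layout (a header row "event_no" followed by one event number per row) is an assumption.

```python
import csv
from typing import Iterable


def save_selection(event_nos: Iterable[int], path: str) -> None:
    """Write a selection of event numbers to a CSV file.

    Hypothetical sketch: one header row ("event_no") followed by one
    event number per row.
    """
    with open(path, "w", newline="") as f:
        writer = csv.writer(f)
        writer.writerow(["event_no"])
        for event_no in event_nos:
            # Cast to plain int so the file only ever contains integers.
            writer.writerow([int(event_no)])
```

Keeping the writer in one function means GraphNeTDataModule and user scripts would share a single definition of the file format.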

So by far the easiest option would be if the selection could somehow be stored in a format that is YAML-config friendly.

I trust a file path to a file that contains the selection will satisfy this?

AMHermansen commented 10 months ago

I think the least messy and most intuitive way would be to introduce three separate args: train_dataloader_kwargs: Dict[str, Any], validation_dataloader_kwargs: Optional[Dict[str, Any]] and test_dataloader_kwargs: Optional[Dict[str, Any]]. If only train_dataloader_kwargs is given, we will use these settings for all dataloaders but force shuffle to be False for the validation and test dataloaders.

I can definitely see this working. We could consider adding a 4th argument, common_dataloader_kwargs: Optional[Dict[str, Any]], so that any key not found in train/validation/test_dataloader_kwargs would default back to common_dataloader_kwargs, and maybe treat shuffle in a special way. But I'm also completely fine with limiting it to just the 3 options.
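The optional fourth argument could be resolved per split roughly as follows. Both common_dataloader_kwargs and the helper name merge_dataloader_kwargs are hypothetical, and treating shuffle as "True for training, False otherwise, unless set explicitly for that split" is just one possible special case.

```python
from typing import Any, Dict, Optional


def merge_dataloader_kwargs(
    common_dataloader_kwargs: Optional[Dict[str, Any]],
    specific_kwargs: Optional[Dict[str, Any]],
    is_train: bool,
) -> Dict[str, Any]:
    """Merge split-specific kwargs on top of shared defaults (hypothetical)."""
    specific = specific_kwargs or {}
    # Start from the shared defaults, then let split-specific keys win.
    merged: Dict[str, Any] = dict(common_dataloader_kwargs or {})
    merged.update(specific)
    # Special-case shuffle: unless the split sets it explicitly, only the
    # training dataloader shuffles.
    if "shuffle" not in specific:
        merged["shuffle"] = is_train
    return merged
```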

I would find it more intuitive if that function was a separate thing and not bound to this module. Like:

from graphnet.training.utils import save_selection
good_event_numbers = make_expensive_cuts(data)
save_selection(good_event_numbers, "/path/to/selection.csv")

We could then use this function in GraphNeTDataModule.

I don't have particularly strong opinions about where such a function is stored. To me it was just more natural to couple it to the DataModule, since in my mind, it would almost always be used in relation to that class. Either saving the found selection, or saving a selection for future use by the class.

I trust a file path to a file that contains the selection will satisfy this?

Yes exactly, my point was that a filepath to the selection would be a lot easier to work with, compared to a list of integers :-)
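To keep the YAML config simple, the DataModule could accept either a list of event numbers or a path to a file containing them, and resolve it on construction. resolve_selection is a hypothetical helper, and the CSV format with an "event_no" header column is an assumption.

```python
import csv
from typing import List, Optional, Union


def resolve_selection(
    selection: Optional[Union[str, List[int]]],
) -> Optional[List[int]]:
    """Turn a selection argument into a list of event numbers (or None)."""
    if selection is None or isinstance(selection, list):
        # Already usable as-is: no selection, or an in-memory list.
        return selection
    # Otherwise treat it as a path to a CSV file with an "event_no" column.
    with open(selection, newline="") as f:
        return [int(row["event_no"]) for row in csv.DictReader(f)]
```

In a YAML config, selection: /path/to/selection.csv then stays a short string instead of a long list of integers.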

RasmusOrsoe commented 10 months ago

Alright. It looks like we now have a well-defined task. @samadpls are you still interested in working on this?

samadpls commented 10 months ago

Alright. It looks like we now have a well-defined task. @samadpls are you still interested in working on this?

Yes, I'm still interested 😋

RasmusOrsoe commented 10 months ago

@samadpls Great! I've assigned the issue to you. The thread should contain enough details to get you started. Let us know if you require anything!

samadpls commented 10 months ago

Thank you for assigning the task to me. I have read the thread and appreciate the detailed conversation. I have a request: it would be great if I could get an invitation to the Slack channel for quick questions or suggestions. I tried to log in via the link mentioned in the README, but access is restricted to specific organization domains.

RasmusOrsoe commented 10 months ago

@samadpls you should have received an invitation now. Let me know if there are any other issues!

RasmusOrsoe commented 8 months ago

@samadpls as a follow-up to our conversation at the last dev meeting, here is a list of unit tests for your datamodule. Notice that this is pseudo-code that I did not actually run, so you might have to adjust it slightly.

from typing import Union, Dict, Any, List

from torch.utils.data import SequentialSampler
import pandas as pd
import sqlite3
from copy import deepcopy

from graphnet.constants import EXAMPLE_DATA_DIR
from graphnet.data.constants import FEATURES, TRUTH
from graphnet.data.dataset import SQLiteDataset, ParquetDataset
from graphnet.data.datamodule import GraphNeTDataModule

def extract_all_events_ids(file_path: str, truth_table: str = 'mc_truth') -> List[int]:
    """Extract all available event ids."""
    if file_path.endswith('.parquet'):
        selection = pd.read_parquet(file_path)['event_no'].tolist()
    elif file_path.endswith('.db'):
        # truth_table is now an explicit argument instead of a global lookup
        with sqlite3.connect(file_path) as conn:
            query = f'SELECT event_no FROM {truth_table}'
            selection = pd.read_sql(query, conn)['event_no'].tolist()
    else:
        raise ValueError(f"file extension not accepted: {file_path.split('.')[-1]}")
    return selection

def single_dataset_without_selections(dataset_ref: Union[SQLiteDataset, ParquetDataset],
                                    dataset_kwargs: Dict[str, Any],
                                    dataloader_kwargs: Dict[str, Any]) -> None:
    """Test that default behavior of DataModule works as expected.

    No selection is given: DataModule should automatically extract
    all available events in the dataset and partition them into a
    training / validation split.

    Only arguments to the training dataloader are given: DataModule
    should set validation dataloader settings equal to training dataloader
    settings, but with shuffle = False.
    """

    # Only training_dataloader args
    # Default values should be assigned to validation dataloader
    dm = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs)

    train_dataloader = dm.train_dataloader
    val_dataloader = dm.val_dataloader

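    try:
        # Should fail because we provided no test selection
        _ = dm.test_dataloader
    except Exception:
        pass
    else:
        # Raised outside the try, so it is not swallowed by the except clause
        raise AssertionError("test_dataloader should not exist without a test selection")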

    # validation loader should have shuffle = False by default
    assert isinstance(val_dataloader.sampler, SequentialSampler)

    # Should have identical batch_size
    assert val_dataloader.batch_size == train_dataloader.batch_size

    # Training dataloader should contain more batches
    assert len(train_dataloader) > len(val_dataloader)

    return

def single_dataset_with_selections(dataset_ref: Union[SQLiteDataset, ParquetDataset],
                                dataset_kwargs: Dict[str, Any],
                                dataloader_kwargs: Dict[str, Any]) -> None:

    """Test that selection functionality of DataModule behaves as expected."""

    # extract all events
    file_path = dataset_kwargs['path']
    selection = extract_all_events_ids(file_path = file_path)

    test_selection = selection[0:10]
    train_val_selection = selection[10:]

    # Only training_dataloader args
    # Default values should be assigned to validation dataloader
    dm = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs,
                            selection = train_val_selection,
                            test_selection = test_selection)

    train_dataloader = dm.train_dataloader
    val_dataloader = dm.val_dataloader
    test_dataloader = dm.test_dataloader

    # Check that the training and validation dataloader contains
    # the same number of events as was given in the selection.
    assert len(train_dataloader.dataset) + len(val_dataloader.dataset) == len(train_val_selection)

    # Check that the number of events in the test dataset is equal to the
    # number of events given in the selection.
    assert len(test_dataloader.dataset) == len(test_selection)

    # Training dataloader should have more batches
    assert len(train_dataloader) > len(val_dataloader)

def test_dataloader_args(dataset_ref: Union[SQLiteDataset, ParquetDataset],
                        dataset_kwargs: Dict[str, Any],
                        dataloader_kwargs: Dict[str, Any]) -> None:
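    """Test that arguments to dataloaders are propagated correctly."""
    # Copy first so the caller's dict is not mutated for the following tests
    dataloader_kwargs = deepcopy(dataloader_kwargs)
    val_dataloader_kwargs = deepcopy(dataloader_kwargs)
    test_dataloader_kwargs = deepcopy(dataloader_kwargs)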

    # Setting batch sizes to different values
    val_dataloader_kwargs['batch_size'] = 1
    test_dataloader_kwargs['batch_size'] = 2
    dataloader_kwargs['batch_size'] = 3

    dm = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs,
                            validation_dataloader_kwargs = val_dataloader_kwargs,
                            test_dataloader_kwargs = test_dataloader_kwargs)

    # Check that the resulting dataloaders have the right batch sizes
    assert dm.train_dataloader.batch_size == dataloader_kwargs['batch_size'] 
    assert dm.val_dataloader.batch_size == val_dataloader_kwargs['batch_size']
    assert dm.test_dataloader.batch_size == test_dataloader_kwargs['batch_size']

def test_ensemble_dataset_without_selections(dataset_ref: Union[SQLiteDataset, ParquetDataset],
                                            dataset_kwargs: Dict[str, Any],
                                            dataloader_kwargs: Dict[str, Any]) -> None:
    """Test ensemble dataset functionality without selections."""

    # Make dataloaders from single dataset
    dm_single = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = deepcopy(dataset_kwargs),
                            train_dataloader_kwargs = dataloader_kwargs)

    # Copy dataset path twice; mimic ensemble dataset behavior
    ensemble_dataset_kwargs = deepcopy(dataset_kwargs)
    dataset_path = ensemble_dataset_kwargs['path']
    ensemble_dataset_kwargs['path'] = [dataset_path, dataset_path]

    # Create dataloaders from multiple datasets
    dm_ensemble = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = ensemble_dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs)

    # Test that the ensemble dataloaders contain more batches
    assert len(dm_single.train_dataloader) < len(dm_ensemble.train_dataloader)
    assert len(dm_single.val_dataloader) < len(dm_ensemble.val_dataloader)

def test_ensemble_dataset_with_selections(dataset_ref: Union[SQLiteDataset, ParquetDataset],
                                            dataset_kwargs: Dict[str, Any],
                                            dataloader_kwargs: Dict[str, Any]) -> None:
    """Test ensemble dataset functionality with selections."""
    # extract all events
    file_path = dataset_kwargs['path']
    selection = extract_all_events_ids(file_path = file_path)

    # Copy dataset path twice; mimic ensemble dataset behavior
    ensemble_dataset_kwargs = deepcopy(dataset_kwargs)
    dataset_path = ensemble_dataset_kwargs['path']
    ensemble_dataset_kwargs['path'] = [dataset_path, dataset_path]

    # Pass two datasets but only one selection; should fail:
    try:
        _ = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = ensemble_dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs,
                            selection = selection)
    except Exception:
        pass
    else:
        raise AssertionError("Two datasets with a single selection should fail")

    # Pass two datasets and two selections; should work:
    selection_1 = selection[0:20]
    selection_2 = selection[0:10]
    dm = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = ensemble_dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs,
                            selection = [selection_1, selection_2])
    n_events_in_dataloaders = len(dm.train_dataloader.dataset) + len(dm.val_dataloader.dataset)

    # Check that the number of events in train/val match
    assert n_events_in_dataloaders == len(selection_1) + len(selection_2)

    # Pass two datasets, two selections and two test selections; should work
    dm2 = GraphNeTDataModule(dataset_reference = dataset_ref,
                            dataset_args = ensemble_dataset_kwargs,
                            train_dataloader_kwargs = dataloader_kwargs,
                            selection = [selection, selection],
                            test_selection = [selection_1, selection_2])

    # Check that the number of events in test dataloaders are correct.
    n_events_in_test_dataloaders = len(dm2.test_dataloader.dataset)
    assert n_events_in_test_dataloaders == len(selection_1) + len(selection_2)


if __name__ == '__main__':

    for dataset_ref in [SQLiteDataset, ParquetDataset]:
        # Grab public dataset paths. dataset_ref is a class, not an instance,
        # so compare identities instead of using isinstance.
        if dataset_ref is SQLiteDataset:
            data_path = f"{EXAMPLE_DATA_DIR}/sqlite/prometheus/prometheus-events.db"
        elif dataset_ref is ParquetDataset:
            data_path = f"{EXAMPLE_DATA_DIR}/parquet/prometheus/prometheus-events.parquet"

        # Setup basic inputs; can be altered by individual tests
        dataset_kwargs = {'truth_table': 'mc_truth',
                          'pulsemaps': 'total',
                          'truth': TRUTH.PROMETHEUS,
                          'features': FEATURES.PROMETHEUS,
                          'path': data_path}

        dataloader_kwargs = {'batch_size': 2,
                             'num_workers': 1}

        # Run each test given arguments
        single_dataset_without_selections(dataset_ref = dataset_ref,
                                dataset_kwargs = dataset_kwargs,
                                dataloader_kwargs = dataloader_kwargs)

        single_dataset_with_selections(dataset_ref = dataset_ref,
                                dataset_kwargs = dataset_kwargs,
                                dataloader_kwargs = dataloader_kwargs)

        test_dataloader_args(dataset_ref = dataset_ref,
                                dataset_kwargs = dataset_kwargs,
                                dataloader_kwargs = dataloader_kwargs)

        test_ensemble_dataset_without_selections(dataset_ref = dataset_ref,
                                dataset_kwargs = dataset_kwargs,
                                dataloader_kwargs = dataloader_kwargs)

        test_ensemble_dataset_with_selections(dataset_ref = dataset_ref,
                                dataset_kwargs = dataset_kwargs,
                                dataloader_kwargs = dataloader_kwargs)
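The first test expects that, without a selection, the DataModule extracts all event ids and partitions them into training and validation sets. A minimal sketch of such a reproducible split is below; the helper name split_selection, the 0.9 training fraction and the seed of 42 are assumptions, not necessarily what GraphNeTDataModule ends up using.

```python
import random
from typing import List, Sequence, Tuple


def split_selection(
    event_nos: Sequence[int], train_fraction: float = 0.9, seed: int = 42
) -> Tuple[List[int], List[int]]:
    """Shuffle event ids reproducibly and split them into train/validation."""
    if not 0.0 < train_fraction < 1.0:
        raise ValueError("train_fraction must be strictly between 0 and 1")
    shuffled = list(event_nos)
    # A dedicated Random instance keeps the split independent of global state.
    random.Random(seed).shuffle(shuffled)
    n_train = int(round(train_fraction * len(shuffled)))
    return shuffled[:n_train], shuffled[n_train:]
```

With a fraction above 0.5, the training split is always the larger one, which is what the "Training dataloader should contain more batches" assertions rely on when both dataloaders share a batch size.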