graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Refactor DataConverter such that it generalizes to other experiments #634

Closed RasmusOrsoe closed 5 months ago

RasmusOrsoe commented 7 months ago

In our current implementation of DataConverter, we assume that the input is i3 files and provide no way for a user to override that assumption. This makes it difficult for other experiments to integrate their own file reading and extraction into the data conversion. This has been a challenge for @kaareendrup and, recently, for Jorge from km3net. I suggest that we refactor DataConverter under the philosophy:

Converting an arbitrary file type to a graphnet-supported backend should be easy.

To achieve this, I think we can refactor DataConverter into modularized parts:

A quick diagram illustrating the dataflow is attached below: [dataflow diagram attached to the issue]

This refactor would significantly lower the "technical price" of implementing new input file types/sources in GraphNeT, as it would only require knowledge of how to open a file and extract the quantities of interest. "What you get for your trouble" is multiprocessing, consistent event_no assignment, and the ability to save to any of the implemented FileSaveMethods. I think that's a good bargain.
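From the user's perspective, the intended workflow could then look roughly like this (a pseudocode sketch; `MyExperimentReader`, `ParquetSaveMethod` and `MyExtractor` are hypothetical names, not part of the proposal):

```python
# Pseudocode sketch of the intended user-facing API.
converter = DataConverter(
    file_reader = MyExperimentReader(),     # hypothetical custom reader
    save_method = ParquetSaveMethod(),      # hypothetical save method
    extractors = [MyExtractor()],           # hypothetical extractor
    num_workers = 4,
)
converter(input_dir = "/path/to/raw/files", output_dir = "/path/to/output")
```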

Below is some pseudo-code on how we could structure the modules:

from typing import List, OrderedDict, Union, final
from abc import abstractmethod
import glob

class GraphNeTFileReader:

    @abstractmethod
    def __call__(self, file_path: str) -> OrderedDict:
        """ Contains the logic for opening a single file and applying the
            extractors to it. The result should be returned as an OrderedDict."""

    @property
    def accepted_file_extensions(self):
        return self._accepted_file_extensions

    @property
    def accepted_extractors(self):
        return self._accepted_extractors

    def find_files(self, path: Union[str, List[str]]) -> List[str]:
        """ A method for identifying input files in a directory given
            by the user. May be overridden in custom implementations.

            Will be called by DataConverter if a directory is given as
            input."""
        files = []
        for accepted_file_extension in self.accepted_file_extensions:
            files.extend(glob.glob(path + f'/*{accepted_file_extension}'))

        # Check that the files are OK.
        self.validate_files(files)
        return files

    @final
    def set_extractors(self, extractors: List[GraphNeTExtractor]) -> None:
        """ Called by DataConverter to set the extractors."""
        self._validate_extractors(extractors)
        self._extractors = extractors

    @final
    def _validate_extractors(self, extractors: List[GraphNeTExtractor]) -> None:
        # Assert that the given extractors are accepted by this reader.
        ...

    @final
    def validate_files(self, input_files: List[str]) -> None:
        """ Called by DataConverter to validate input file paths."""
        # Assert that the file extensions are accepted.
        ...
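To illustrate how little a new experiment would need to implement, here is a minimal, self-contained sketch of a reader for plain .csv files. It follows the contract above but is written standalone (the `CSVReader` name and the CSV format are hypothetical stand-ins for a real experiment's file format):

```python
import glob
import os
from collections import OrderedDict
from typing import List

class CSVReader:
    """Hypothetical reader for plain .csv files, following the proposed
    GraphNeTFileReader contract as a standalone sketch."""

    _accepted_file_extensions = [".csv"]

    def __call__(self, file_path: str) -> OrderedDict:
        # Open the file and return {column name: list of values}, standing in
        # for the output of the real extractors.
        with open(file_path) as f:
            header = f.readline().strip().split(",")
            rows = [line.strip().split(",") for line in f if line.strip()]
        return OrderedDict(
            (name, [row[i] for row in rows]) for i, name in enumerate(header)
        )

    def find_files(self, path: str) -> List[str]:
        # Same globbing logic as the proposed base class.
        files = []
        for extension in self._accepted_file_extensions:
            files.extend(glob.glob(os.path.join(path, f"*{extension}")))
        return files
```

A reader like this would only need to be paired with a save method to get the full conversion pipeline for free.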
import os

class GraphNeTFileSaveMethod:

    @abstractmethod
    def save_file(self, data: OrderedDict, output_file_path: str) -> None:
        """ Logic for saving `data` to a specific file format at
            `output_file_path`."""
        return

    @final
    def __call__(self, data: OrderedDict, file_name: str, out_dir: str) -> None:
        output_file_path = os.path.join(out_dir, file_name + self.file_extension)
        self.save_file(data = data, output_file_path = output_file_path)
        return

    @property
    def file_extension(self) -> str:
        return self._file_extension
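Similarly, a minimal save method could look like the following standalone sketch (the `JSONSaveMethod` name and JSON backend are hypothetical; in practice the backends would be the existing graphnet FileSaveMethods):

```python
import json
import os
from collections import OrderedDict

class JSONSaveMethod:
    """Hypothetical save method writing extractor output as JSON, following
    the proposed GraphNeTFileSaveMethod contract as a standalone sketch."""

    _file_extension = ".json"

    @property
    def file_extension(self) -> str:
        return self._file_extension

    def save_file(self, data: OrderedDict, output_file_path: str) -> None:
        # Format-specific saving logic lives here.
        with open(output_file_path, "w") as f:
            json.dump(data, f)

    def __call__(self, data: OrderedDict, file_name: str, out_dir: str) -> None:
        # The base name carries no extension; the save method appends its own.
        output_file_path = os.path.join(out_dir, file_name + self.file_extension)
        self.save_file(data=data, output_file_path=output_file_path)
```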

class DataConverter:

    def __init__(self,
                 file_reader: GraphNeTFileReader,
                 save_method: GraphNeTFileSaveMethod,
                 extractors: List[GraphNeTExtractor],
                 num_workers: int = 1) -> None:

        # Member Variable Assignment
        self._file_reader = file_reader
        self._save_method = save_method
        self._num_workers = num_workers

        # Set Extractors. Will throw error if extractors are incompatible
        # with reader.
        self._file_reader.set_extractors(extractors)

    @final
    def __call__(self, input_dir: Union[str, List[str]], output_dir: str) -> None:
        # Get the file reader to produce a list of input files
        # in the directory
        input_files = self._file_reader.find_files(input_dir)
        self._launch_jobs(input_files = input_files, output_dir = output_dir)

    @final
    def _launch_jobs(self,
                     input_files: Union[List[str], List[I3FileSets]],
                     output_dir: str) -> None:
        """ Multiprocessing logic.

            Spawns a worker pool,
            distributes the input files evenly across workers,
            declares event_no as a globally consistent variable across workers,
            and starts the jobs.

            Will call _process_file in parallel."""

    @final
    def _process_file(self, file_path: str, output_dir: str) -> None:
        """ This function is called in parallel."""
        # Read the file and apply the extractors
        data = self._file_reader(file_path = file_path)

        # Assign an event_no to each event in data
        data = self._assign_event_no(data = data)

        # Create the output file name
        output_file_name = self._generate_output_file_name(file_path = file_path)

        # Apply the save method
        self._save_method(data = data,
                          file_name = output_file_name,
                          out_dir = output_dir)
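The `_launch_jobs` / event_no logic could, for instance, precompute a contiguous event_no offset per input file, so that workers need no synchronization at run time. Below is a runnable sketch under that assumption (the file names, event counts, and function names are illustrative only; in the real converter the per-file event counts would come from the reader):

```python
from itertools import accumulate
from multiprocessing import Pool
from typing import List, Tuple

def process_file(task: Tuple[str, int, int]) -> List[int]:
    """Stand-in for DataConverter._process_file: the real version would read,
    extract, tag events with event_no, and save. This one just returns the
    event_no's it would assign."""
    file_path, first_event_no, n_events = task
    return [first_event_no + i for i in range(n_events)]

def launch_jobs(files_and_counts: List[Tuple[str, int]],
                num_workers: int = 2) -> List[List[int]]:
    """Assign each file a unique, contiguous block of event_no's up front,
    then process the files in parallel with no inter-worker coordination."""
    counts = [n for _, n in files_and_counts]
    # Cumulative sum shifted by one gives each file's starting event_no.
    offsets = [0] + list(accumulate(counts))[:-1]
    tasks = [(path, offset, n)
             for (path, n), offset in zip(files_and_counts, offsets)]
    with Pool(num_workers) as pool:
        return pool.map(process_file, tasks)
```

The design choice here is that event_no uniqueness is guaranteed by construction (disjoint ranges) rather than by a shared counter, which keeps the workers embarrassingly parallel.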