graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Should we change Dataset to adaptively create the correct subclass depending on the input path? #643

Open AMHermansen opened 10 months ago

AMHermansen commented 10 months ago

At the meeting today I mentioned that we could use the __new__ and __init_subclass__ hooks to automatically infer the desired Dataset subclass.

Currently we have two Datasets in use, SQLiteDataset and ParquetDataset, and they both take the same input arguments.

I think we could override __new__ and __init_subclass__ in Dataset so that __init_subclass__ builds a "subclass registry" that maps file extensions to the implemented datasets, and __new__ then looks up the extension of the input path in that registry to find the correct subclass to instantiate.

A "simple" illustration of what this would look like:

from typing import Iterable, Union

class A:
    _subclass_registry = {}
    def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None, *, path: str):
        self.path = path
        self.arg1 = arg1
        self.arg2 = arg2
        self.kwarg1 = kwarg1
        self.kwarg2 = kwarg2

    def __init_subclass__(cls, file_extensions: Union[str, Iterable[str]], **kwargs):
        if isinstance(file_extensions, str):
            file_extensions = [file_extensions]
        for ext in file_extensions:
            if ext in cls._subclass_registry:
                raise ValueError(f"Duplicate file extension: {ext}")
            A._subclass_registry[ext] = cls
        super().__init_subclass__(**kwargs)

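    def __new__(cls, *args, **kwargs):
        # Dispatch only when the base class is called directly. Otherwise
        # B(..., path="file.ext2") would return a C object that is not an
        # instance of B, and Python would then skip __init__ entirely.
        if cls is not A:
            return object.__new__(cls)
        path = kwargs["path"]
        # Text after the last dot, e.g. "file.ext1" -> "ext1".
        file_extension = path.split(".")[-1]
        subclass = cls._subclass_registry.get(file_extension, None)
        if subclass is None:
            raise ValueError(f"Unknown file extension: {file_extension}")
        return object.__new__(subclass)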

class B(A, file_extensions="ext1"):
    def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None, *, path: str):
        super().__init__(arg1, arg2, kwarg1=kwarg1, kwarg2=kwarg2, path=path)
        print(f"Created B instance with path: {self.path}")

class C(A, file_extensions=["ext2", "ext3"]):
    def __init__(self, arg1, arg2, kwarg1=None, kwarg2=None, *, path: str):
        super().__init__(arg1, arg2, kwarg1=kwarg1, kwarg2=kwarg2, path=path)
        print(f"Created C instance with path: {self.path}")

if __name__ == "__main__":
    a = A(1, 2, path="file.ext1")  # Creates object from class B
    b = A(3, 4, path="file.ext2")  # Creates object from class C
    c = A(5, 6, path="file.ext3")  # Creates object from class C
    print(f"{type(a)=}")
    print(f"{type(b)=}")
    print(f"{type(c)=}")
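
Applied to the two datasets mentioned above, the same pattern could look roughly like the sketch below. The extensions "db" and "parquet", the stripped-down constructor, and the class bodies are assumptions for illustration only, not graphnet's actual signatures. Compared with the A/B/C illustration, this variant also accepts path positionally, makes file_extensions optional so that further subclasses do not have to register anything, and only dispatches when Dataset itself is called.

```python
from pathlib import Path
from typing import Dict, Iterable, Type, Union


class Dataset:
    """Hypothetical stand-in for the real Dataset base class."""

    # Maps a file extension (without the dot) to the subclass that reads it.
    _subclass_registry: Dict[str, Type["Dataset"]] = {}

    def __init_subclass__(
        cls, file_extensions: Union[str, Iterable[str]] = (), **kwargs
    ):
        super().__init_subclass__(**kwargs)
        if isinstance(file_extensions, str):
            file_extensions = [file_extensions]
        for ext in file_extensions:
            if ext in Dataset._subclass_registry:
                raise ValueError(f"Duplicate file extension: {ext}")
            Dataset._subclass_registry[ext] = cls

    def __new__(cls, path, *args, **kwargs):
        # An explicit SQLiteDataset(...) or ParquetDataset(...) call keeps
        # its own type; only Dataset(...) is dispatched via the registry.
        if cls is not Dataset:
            return super().__new__(cls)
        ext = Path(path).suffix.lstrip(".")
        subclass = cls._subclass_registry.get(ext)
        if subclass is None:
            raise ValueError(f"Unknown file extension: {ext!r}")
        return super().__new__(subclass)

    def __init__(self, path, **kwargs):
        # Assumption: the real constructor takes many more arguments.
        self.path = str(path)
        self.kwargs = kwargs


class SQLiteDataset(Dataset, file_extensions="db"):
    pass


class ParquetDataset(Dataset, file_extensions="parquet"):
    pass


if __name__ == "__main__":
    ds = Dataset("events.db")
    print(type(ds).__name__)  # SQLiteDataset
```

Because the returned object is an instance of a Dataset subclass, Python still calls its __init__ with the original arguments, so existing subclass constructors would not need to change.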

Pros:

Cons: