graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Implement class to crop pulsemaps to maximum length #648

Open AMHermansen opened 7 months ago

AMHermansen commented 7 months ago

Is your feature request related to a problem? Please describe.

With #558, we now have better control over how a pulsemap is processed. The Kaggle competition made it apparent that many of the top-scoring models simply cropped each event to a fixed number of pulses, to limit the O(n^2) cost of their Self-Attention components in the number of pulses n.

Their primary selection method was simply to keep the first n pulses, but I believe it would be interesting to look into other ways of selecting pulses (random selection, sorting by charge, sorting by the probability of being a real signal pulse, farthest point sampling, etc.).
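Three of these strategies can be sketched as plain NumPy functions with the `(x, max_pulses) -> x` shape that the `cropping_method` argument in the proposal below would expect. This is a minimal sketch, not graphnet code: the function names are hypothetical, and it assumes each row of `x` is one pulse with `dom_x, dom_y, dom_z` in columns 0 to 2 and charge in column 3. A torch version would follow the same logic.

```python
import numpy as np


def crop_first_n(x: np.ndarray, max_pulses: int) -> np.ndarray:
    """Keep the first `max_pulses` rows (assumes rows are already time-ordered)."""
    return x[:max_pulses]


def crop_by_charge(x: np.ndarray, max_pulses: int, charge_col: int = 3) -> np.ndarray:
    """Keep the `max_pulses` highest-charge rows, in their original order."""
    if len(x) <= max_pulses:
        return x
    # Rank rows by descending charge, keep the top ones, then restore row order.
    top = np.argsort(-x[:, charge_col], kind="stable")[:max_pulses]
    return x[np.sort(top)]


def crop_farthest_point(
    x: np.ndarray, max_pulses: int, xyz_cols: tuple = (0, 1, 2)
) -> np.ndarray:
    """Greedy farthest point sampling on pulse positions, seeded at row 0."""
    if len(x) <= max_pulses:
        return x
    pos = x[:, list(xyz_cols)].astype(float)
    selected = [0]
    # Distance from every row to its nearest already-selected row.
    dist = np.linalg.norm(pos - pos[0], axis=1)
    dist[0] = -np.inf  # never pick a selected row again
    for _ in range(max_pulses - 1):
        nxt = int(np.argmax(dist))
        selected.append(nxt)
        dist = np.minimum(dist, np.linalg.norm(pos - pos[nxt], axis=1))
        # -inf also guards against re-picking rows when many pulses share
        # the same DOM position (all remaining distances are then 0).
        dist[nxt] = -np.inf
    return x[np.sort(selected)]
```

Sorting the kept indices preserves the original pulse order, which matters if later components assume time-ordered input. The farthest point sampling loop costs O(n·k) distance evaluations for n pulses and k kept pulses.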

Describe the solution you'd like

To avoid having to implement many Node Definitions, I think it might make sense to make a common class for all cropped nodes:

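from typing import Callable

import torch
from torch_geometric.data import Data

from graphnet.models.graphs.nodes import NodeDefinition

class CroppedNodes(NodeDefinition):
    def __init__(self, max_pulses: int, cropping_method: Callable) -> None:
        super().__init__()
        self.max_pulses = max_pulses
        self._cropping_method = cropping_method  # (x, max_pulses) -> cropped x

    def _construct_nodes(self, x: torch.Tensor) -> Data:
        # The injected method decides which pulses (rows of x) to keep.
        x = self._cropping_method(x, self.max_pulses)
        return Data(x=x)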

Such a structure would also make it easier to re-use the cropping methods in other node definitions (e.g. you may want to crop after calculating summary nodes per DOM, to make sure you do not get an event that triggered 5k DOMs).
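The per-DOM variant could look roughly like this: first collapse pulses into one summary row per DOM, then crop the DOMs rather than the pulses. This is a hypothetical sketch (the function names are illustrative, not graphnet's) that assumes a DOM is identified by its `(dom_x, dom_y, dom_z)` position and that charge sits in column 3; pulse count and total charge are example summary features only.

```python
import numpy as np


def dom_summary(x: np.ndarray, xyz_cols: tuple = (0, 1, 2), charge_col: int = 3) -> np.ndarray:
    """Collapse pulses into rows of [dom_x, dom_y, dom_z, n_pulses, total_charge]."""
    pos = x[:, list(xyz_cols)]
    # Assumption: a unique position identifies a unique DOM.
    doms, inverse = np.unique(pos, axis=0, return_inverse=True)
    inverse = inverse.reshape(-1)  # guard against version-dependent inverse shape
    n_pulses = np.bincount(inverse, minlength=len(doms))
    total_charge = np.bincount(inverse, weights=x[:, charge_col], minlength=len(doms))
    return np.column_stack([doms, n_pulses, total_charge])


def crop_doms(summary: np.ndarray, max_doms: int, charge_col: int = 4) -> np.ndarray:
    """Keep the `max_doms` DOMs with the largest total charge."""
    if len(summary) <= max_doms:
        return summary
    top = np.argsort(-summary[:, charge_col], kind="stable")[:max_doms]
    return summary[np.sort(top)]
```

Cropping at the DOM level bounds the node count of the graph directly, independently of how many pulses each DOM recorded.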

Describe alternatives you've considered

We could, of course, just implement each cropping algorithm as a subclass of a common CroppedNodes class, with the logic restricted to each subclass. But I think the cropping logic is general enough that there is merit in having it as a separate component.

RasmusOrsoe commented 6 months ago

Hey @AMHermansen!

I think it is a great idea to allow for such functionality in GraphDefinition. "Cropping" pulsemaps is essentially just sub-sampling of the available pulses. I would suggest adding this as an independent sub-module of GraphDefinition, so on the user side it could look like:

from graphnet.models.graphs import GraphDefinition
graph_definition = GraphDefinition(detector = detector,
                                   node_definition = node_definition,
                                   edge_definition = edge_definition,
                                   sampler = sampler)

In the forward pass of GraphDefinition, we could add a line like this early on (perhaps just after the basic checks):

if self.sampler is not None:
    subsample_idx = self.sampler(input_features = input_features,
                                 input_feature_names = input_feature_names)
    input_features = input_features[subsample_idx,:]

This way, the sampling is independent of what the node definition subsequently does with the pulses.
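Because NumPy accepts either integer row indices or a boolean mask in the `input_features[subsample_idx, :]` position, a sampler could use either convention. A small illustration of how the two relate, using the toy array from the docstring further down (the helper names are hypothetical):

```python
import numpy as np

input_features = np.array([[1, 2, 3], [5, 5, 5], [0, 0, 1]])

# Convention 1: integer row indices of the rows to keep.
idx = np.array([0, 1])
# Convention 2: one boolean per row, True meaning "keep".
mask = np.array([True, True, False])

by_index = input_features[idx, :]
by_mask = input_features[mask, :]


def mask_to_indices(mask: np.ndarray) -> np.ndarray:
    """Positions of the True entries."""
    return np.flatnonzero(mask)


def indices_to_mask(indices: np.ndarray, n_rows: int) -> np.ndarray:
    """Boolean mask with True at each listed row."""
    out = np.zeros(n_rows, dtype=bool)
    out[indices] = True
    return out
```

Integer indices can additionally express reordering or repeated rows, which a mask cannot, so a base class should pick one convention and validate against it.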

Here's a quick take on what the sampling module could look like:

from typing import List
from abc import abstractmethod

from graphnet.models import Model
from graphnet.utilities.decorators import final
import numpy as np

class Sampler(Model):
    """Base class for sub-sampling rows in single events."""

    def __init__(self) -> None:
        """Construct `Sampler`."""
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

    @final
    def forward(self,
                input_features: np.ndarray,
                input_feature_names: List[str]) -> np.ndarray:
        """Produce subsampling indices."""
        indices = np.asarray(
            self._create_subsample_indices(input_features=input_features,
                                           input_feature_names=input_feature_names)
        )
        self._validate_indices(indices=indices,
                               input_features=input_features)
        return indices

    def _validate_indices(self,
                          indices: np.ndarray,
                          input_features: np.ndarray) -> None:
        """Check that the output of the custom sampling method meets requirements."""
        if indices.ndim != 1 or not np.issubdtype(indices.dtype, np.integer):
            self.error("Subsampling indices must be a 1D sequence of integers. "
                       f"Got dtype {indices.dtype} with shape {indices.shape}.")
            raise TypeError("Invalid subsampling indices.")
        if indices.size > 0 and (indices.min() < 0
                                 or indices.max() >= len(input_features)):
            self.error("Subsampling indices must refer to existing rows "
                       f"(0 to {len(input_features) - 1}).")
            raise IndexError("Subsampling indices out of range.")
        if len(np.unique(indices)) != len(indices):
            self.error("Subsampling indices must not contain duplicates.")
            raise ValueError("Duplicate subsampling indices.")

    @abstractmethod
    def _create_subsample_indices(self,
                                  input_features: np.ndarray,
                                  input_feature_names: List[str]) -> List[int]:
        """Create a list of integers defining which rows in `input_features` are kept.

        Example:
            input_features = [[1, 2, 3],
                              [5, 5, 5],
                              [0, 0, 1]]
            input_feature_names = ['dom_x', 'dom_y', 'dom_z']

            Suppose we wrote logic that produced the following indices:
            indices = [0, 1]

            This would mean that the corresponding subsampled rows would be:

            input_features = [[1, 2, 3],
                              [5, 5, 5]]
        """
        raise NotImplementedError
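To make the contract testable without depending on graphnet, here is a framework-free stand-in; the names `SimpleSampler` and `TopChargeSampler` are hypothetical. It follows the integer-index convention of the docstring above, and the subclass shows why `input_feature_names` is passed in: it looks up the charge column by name instead of hard-coding a position. The validation rules (1D integers, in range, no duplicates) are assumptions about what a sampler should guarantee.

```python
from abc import ABC, abstractmethod
from typing import List

import numpy as np


class SimpleSampler(ABC):
    """Framework-free stand-in for the proposed `Sampler` base class."""

    def __call__(self, input_features: np.ndarray,
                 input_feature_names: List[str]) -> np.ndarray:
        idx = np.asarray(self._create_subsample_indices(input_features, input_feature_names))
        n_rows = len(input_features)
        if idx.ndim != 1 or not np.issubdtype(idx.dtype, np.integer):
            raise TypeError("Subsampling indices must be a 1D array of integers.")
        if idx.size and (idx.min() < 0 or idx.max() >= n_rows):
            raise IndexError("Subsampling indices out of range.")
        if len(np.unique(idx)) != len(idx):
            raise ValueError("Subsampling indices must be unique.")
        return idx

    @abstractmethod
    def _create_subsample_indices(self, input_features: np.ndarray,
                                  input_feature_names: List[str]) -> np.ndarray:
        """Return the integer indices of the rows to keep."""


class TopChargeSampler(SimpleSampler):
    """Keep the highest-charge pulses when an event exceeds `max_event_size`."""

    def __init__(self, max_event_size: int, charge_name: str = "charge") -> None:
        self._max_size = max_event_size
        self._charge_name = charge_name  # assumed feature name

    def _create_subsample_indices(self, input_features: np.ndarray,
                                  input_feature_names: List[str]) -> np.ndarray:
        n_rows = len(input_features)
        if n_rows <= self._max_size:
            return np.arange(n_rows)
        # Locate the charge column by name rather than by fixed position.
        col = input_feature_names.index(self._charge_name)
        top = np.argsort(-input_features[:, col], kind="stable")[: self._max_size]
        return np.sort(top)  # keep the original row order
```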

So a Sampler that randomly subsamples the pulses of events exceeding some size limit could look like this:

class RandomMaxSampler(Sampler):
    """Randomly subsample events exceeding a maximum length."""

    def __init__(self,
                 max_event_size: int,
                 seed: int = 42) -> None:
        """Randomly sample available pulses if event is larger than `max_event_size`.

        Args:
            max_event_size: The maximum number of pulses in the event.
                            Events with more pulses than this will be randomly sampled.
            seed: Seed used for random sampling. Defaults to 42.

        """
        super().__init__()
        self._max_size = max_event_size
        self._seed = seed
        # One generator for the sampler's lifetime: reproducible, yet
        # different events receive different random subsets.
        self._rng = np.random.default_rng(seed)

    def _create_subsample_indices(self,
                                  input_features: np.ndarray,
                                  input_feature_names: List[str]) -> List[int]:
        n_pulses = len(input_features)
        if n_pulses > self._max_size:
            # Draw row indices without replacement so no pulse is duplicated,
            # and sort them to preserve the original (e.g. time) ordering.
            indices = self._rng.choice(n_pulses, size=self._max_size, replace=False)
            indices = np.sort(indices)
        else:
            indices = np.arange(n_pulses)
        return indices.tolist()
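A detail worth settling is where the random generator lives. The standalone sketch below (plain NumPy, hypothetical function names, not graphnet code) contrasts re-seeding on every call with holding one `np.random.Generator` for the sampler's lifetime.

```python
import numpy as np


def sample_indices(n_rows: int, max_size: int, rng: np.random.Generator) -> np.ndarray:
    """Random, duplicate-free, order-preserving subset of row indices."""
    if n_rows <= max_size:
        return np.arange(n_rows)
    return np.sort(rng.choice(n_rows, size=max_size, replace=False))


def reseeded(n_rows: int, max_size: int, seed: int = 42) -> np.ndarray:
    """Pattern A: a fresh generator with the same seed on every call."""
    return sample_indices(n_rows, max_size, np.random.default_rng(seed))


# Pattern B: one generator shared across all calls.
shared_rng = np.random.default_rng(42)


def shared(n_rows: int, max_size: int) -> np.ndarray:
    return sample_indices(n_rows, max_size, shared_rng)
```

With re-seeding, every event of the same length keeps exactly the same row positions, which is probably not what random cropping intends; a single shared generator keeps the run reproducible while varying the subset from event to event.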