graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/

Use unsupervised learning to learn representations in data and Monte Carlo #527

Open asogaard opened 1 year ago

asogaard commented 1 year ago

Suggested steps:

AMHermansen commented 11 months ago

This is the current layout of the MAE that I've been working on. The general MAE architecture is based on Masked Autoencoders Are Scalable Vision Learners. [Diagram: GraphNeT_BEiT_Style_MAE.drawio]

The aggregation loss is currently computed by predicting logits for DOM PID and then duplicating the logits for each masked pulse. The motivation for this novel loss function is to have a permutation-invariant loss, since during prototyping it was observed that even with positional encodings the model wasn't able to distinguish nodes to a satisfactory level.
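
To make the idea concrete, here is a minimal sketch of such a permutation-invariant loss under my reading of the description above (function and variable names are illustrative, not the actual implementation):

import torch
import torch.nn.functional as F

def aggregation_loss(dom_pid_logits: torch.Tensor,
                     masked_pulse_dom_pids: torch.Tensor) -> torch.Tensor:
    """Sketch: duplicate one set of DOM-PID logits for every masked pulse.

    dom_pid_logits:        [num_doms] logits predicted from the aggregated
                           (graph-level) representation.
    masked_pulse_dom_pids: [num_masked_pulses] true DOM PID (class index)
                           for each masked pulse.
    """
    # The same logits are repeated for every masked pulse, so the loss only
    # depends on which DOMs were hit, not on the ordering of the pulses.
    logits = dom_pid_logits.unsqueeze(0).expand(masked_pulse_dom_pids.shape[0], -1)
    return F.cross_entropy(logits, masked_pulse_dom_pids)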

It is possible to add a GNN decoder, which will take in the features learned by the encoder and a mask token (learned parameters in the latent space of the encoder). The decoder will then attempt a more node-level reconstruction. It should also be possible to "tokenize" the PIDs for the masked DOMs and then only reconstruct the non-positional variables. The idea behind this is that the "predicted" time should very much depend on where the model suspects the DOM was located, but since it might confuse DOMs with one another, I think we might get better results by providing the true DOMs and then tasking the model with finding the time/charge for each DOM. An alternative loss function for this part could be the Chamfer distance; see Point-M2AE, Eq. 2.
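
For reference, a minimal sketch of a symmetric Chamfer-style distance between predicted and true pulses (this is the generic formulation, not necessarily the exact Point-M2AE variant):

import torch

def chamfer_distance(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
    """Symmetric Chamfer distance between two point sets.

    pred:   [N, D] predicted pulses/points.
    target: [M, D] true pulses/points.
    """
    # Pairwise squared distances between every predicted and true point.
    dists = torch.cdist(pred, target) ** 2  # [N, M]
    # Nearest-neighbour distance in both directions; summing the two terms
    # makes the loss independent of the ordering of the points.
    return dists.min(dim=1).values.mean() + dists.min(dim=0).values.mean()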

This diagram also illustrates why I think aggregation might be better suited as a separate module, since we might need to both extract the graph (to add mask tokens) and do the aggregation separately.

I'm open to inputs on the architecture. Currently my model is implemented using transformer encoders as the GNN encoder/decoder.

RasmusOrsoe commented 11 months ago

Hey @AMHermansen I've now taken a deeper look at this, and I must admit I still struggle to follow your diagram. I think your diagram depicts the usual two-stage approach for using autoencoders, where one first trains the encoder-decoder pair in an unsupervised setting and then fine-tunes one of the components of the encoder-decoder pair on a supervised learning task. Assuming this, I think the intended workflow is straightforward:

The first stage pre-trains the encoder-decoder pair according to some scheme defined in an AutoEncoding task, and the encoder and decoder are saved separately for later use.

from graphnet.models import AutoEncoder
from graphnet.models.gnn import DynEdge
from graphnet.models.detector import IceCube86
from graphnet.models.graphs import KNNGraph
from graphnet.models.task import AutoEncoding

# Define autoencoder
model = AutoEncoder(graph_definition = KNNGraph(),
                    detector = IceCube86(),
                    encoder = DynEdge(..),
                    decoder = DynEdge(..),
                    task = AutoEncoding()
                    )

# Train autoencoder pair (first stage)
model.fit(..)

# Save encoder and decoder separately
model.save_config(path = '..') # Saves the config for encoder and decoder separately
model.save_state_dict(path = '') # Saves the state_dict for encoder and decoder separately

In a second session, one can load in a pre-trained component of the encoder-decoder pair and fine-tune it on a physics task:

 # .. In a different session (2nd stage)

from graphnet.models.gnn import DynEdge
from graphnet.models import StandardModel
from graphnet.models.detector import IceCube86
from graphnet.models.graphs import KNNGraph
from graphnet.models.task import EnergyReconstruction

pretrained_model = DynEdge.from_config('encoder_config.yml')
pretrained_model.load_state_dict('encoder_state_dict')

model = StandardModel(graph_definition = KNNGraph(),
                    detector = IceCube86(),
                    gnn = pretrained_model,
                    task = EnergyReconstruction())

model.fit(..)

results = model.predict_as_dataframe(...)

# Save finetuned model
model.save_config(path = '..') 
model.save_state_dict(path = '')

I think the technical work required to achieve this syntax is reasonable. @AMHermansen Could you elaborate on how your current approach differs from this?

AMHermansen commented 11 months ago

Hey @AMHermansen I've now taken a deeper look at this, and I must admit I still struggle to follow your diagram. I think your diagram depicts the usual two-stage approach for using autoencoders, where one first trains the encoder-decoder pair in an unsupervised setting and then fine-tunes one of the components of the encoder-decoder pair on a supervised learning task. Assuming this, I think the intended workflow is straightforward:

I would first like to clarify that the diagram describes a masked autoencoder and not a regular autoencoder. The main difference between the two architectures is found in their general purpose. Masked autoencoders have been shown to provide promising self-supervised learning tasks in a multitude of machine learning areas (NLP and CV), and many state-of-the-art models for various benchmarks seem to use a masked autoencoder as a pretraining step; see e.g. SOTA. Autoencoders are primarily used to take an input and compress it down to a latent space smaller than the original space.

The first stage pre-trains the encoder-decoder pair according to some scheme defined in an AutoEncoding task, and the encoder and decoder are saved separately for later use.

from graphnet.models import AutoEncoder
from graphnet.models.gnn import DynEdge
from graphnet.models.detector import IceCube86
from graphnet.models.graphs import KNNGraph
from graphnet.models.task import AutoEncoding

# Define autoencoder
model = AutoEncoder(graph_definition = KNNGraph(),
                    detector = IceCube86(),
                    encoder = DynEdge(..),
                    decoder = DynEdge(..),
                    task = AutoEncoding()
                    )

# Train autoencoder pair (first stage)
model.fit(..)

# Save encoder and decoder separately
model.save_config(path = '..') # Saves the config for encoder and decoder separately
model.save_state_dict(path = '') # Saves the state_dict for encoder and decoder separately

I think this is somewhat close to what I'm currently working with. It is not clear to me where you intend for the masking procedure to take place in the above pseudo-code. I'm currently doing it in a specially defined NodeDefinition, since it feels more natural to have it as a preprocessing step.

A thing which is not clear in the syntax above is the requirement that the encoder module must have been instantiated with aggregation=None, since it needs to output a graph-like structure as input to the decoder. To get the summary loss, a new module is added which is responsible for aggregating the graph output by the encoder. Maybe I lack creativity, but I do not see any elegant way around this. The current architecture requires the ability to "intercept" the graph output by the encoder and then pass it through two different modules.

The only real alternatives I see are to either have GNN output a dictionary containing the final graph state and the aggregated state, which would then require that tasks "grab" the right entry from the output dict depending on what they expect as input, or to give GNN two methods, learn_features and aggregate, and change forward to:

class GNN:
    ...
    def forward(self, data: Data) -> Tensor:
        features = self.learn_features(data)
        output = self.aggregate(features)
        return output
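
For comparison, the dictionary-output alternative mentioned above could look something like this (again pseudo-code; the key names are illustrative):

class GNN:
    ...
    def forward(self, data: Data) -> Dict[str, Tensor]:
        features = self.learn_features(data)   # node-level graph state
        aggregated = self.aggregate(features)  # graph-level (aggregated) state
        # Tasks would then "grab" the entry they expect as input.
        return {"node_features": features, "aggregated": aggregated}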

In a second session, one can load in a pre-trained component of the encoder-decoder pair and fine-tune it on a physics task:

 # .. In a different session (2nd stage)

from graphnet.models.gnn import DynEdge
from graphnet.models import StandardModel
from graphnet.models.detector import IceCube86
from graphnet.models.graphs import KNNGraph
from graphnet.models.task import EnergyReconstruction

pretrained_model = DynEdge.from_config('encoder_config.yml')
pretrained_model.load_state_dict('encoder_state_dict')

model = StandardModel(graph_definition = KNNGraph(),
                    detector = IceCube86(),
                    gnn = pretrained_model,
                    task = EnergyReconstruction())

model.fit(..)

results = model.predict_as_dataframe(...)

# Save finetuned model
model.save_config(path = '..') 
model.save_state_dict(path = '')

Here I believe you'd run into another problem, since the pretrained_model would not have any kind of aggregation, making it unsuitable for most of the tasks within GraphNeT. You could instantiate it with an aggregation component, manually load the weights for the DynamicalEdgeConvolution layers, and randomly initialize the post-processing layers, but this seems like a rather crude way of doing it.

My current implementation is mostly done in native pytorch_lightning and isn't very modularized, since I'm only working with transformer-based architectures. To give a quick overview, the pseudo-code looks something like:

class MAE(Model):
    def __init__(self, encoder, decoder, aggregator, ...):
        self.encoder = encoder
        self.decoder = decoder
        self.aggregator = aggregator
        ...
    def encode(self, data: Data):
        masked_data = self._apply_masking(data)
        encoded_graph = self.encoder(masked_data)
        return encoded_graph

    def decode(self, encoded_graph: Data, data: Data):
        out = self.aggregator(encoded_graph)
        decode_graph = self._insert_mask_tokens(encoded_graph, data)  # Inserts mask tokens and restores original order.
        decode_graph += self.positional_encoder(decode_graph)  # Insert positional information into tokens, to make all mask-tokens slightly different
        decode_graph = self.decoder(decode_graph)
        decoded_tokens = self.select_mask_tokens(decode_graph, data)  # Selects the mask tokens.
        return out, decoded_tokens

    def compute_loss(self, out, decoded_tokens, data):  
        summary_loss = self.summary_task(out, data)  
        ae_loss = self.ae_loss(decoded_tokens, data)
        return self.summary_weight * summary_loss + ae_loss

    def training_step(self, data: Data):
        encoded_graph = self.encode(data)
        out, decoded_tokens = self.decode(encoded_graph, data)
        loss = self.compute_loss(out, decoded_tokens, data)
        return {"loss": loss}
    ...

I hope this clarifies my previous message.

RasmusOrsoe commented 11 months ago

@AMHermansen thanks for the details! Before I respond in full, could you help clear up one last thing:

To get the summary loss a new module is added, which is responsible for aggregation the graph outputted from the encoder.

What is the difference between the purposes of "aggregation loss" and "node level loss"? Are both quantities used in the pre-training stage, or does "node level loss" represent the pre-training stage, while "aggregation loss" is the loss calculated in the 2nd stage, where one fine-tunes the pre-trained component on a graph-level physics task?

AMHermansen commented 11 months ago

@AMHermansen thanks for the details! Before I respond in full, could you help clear up one last thing:

To get the summary loss a new module is added, which is responsible for aggregation the graph outputted from the encoder.

What is the difference between the purposes of "aggregation loss" and "node level loss"? Are both quantities used in the pre-training stage, or does "node level loss" represent the pre-training stage, while "aggregation loss" is the loss calculated in the 2nd stage, where one fine-tunes the pre-trained component on a graph-level physics task?

Yes, both the "aggregation loss" and the "node level loss" are used during pretraining.

The idea of using these two losses together is inspired by the BERT paper (which to my knowledge is the first time a masked autoencoder was used for pretraining) from natural language processing, where the model is given two sentences and some of the words in each sentence are masked. The model is then tasked with both predicting the masked words and predicting whether the two sentences are a "continuation" of one another.

RasmusOrsoe commented 11 months ago

Alright @AMHermansen - sorry for the delay on this. I now have some pseudo-code ready.

I arrived at this pseudo-code by starting out with a series of considerations. First, I think we should insist that the user experience is intuitive and modular, such that autoencoding functionality in graphnet plays well with existing modules, whether that be loss functions, model architectures, graph definitions, etc. Secondly, I think it's important that when we implement new machine learning paradigms, we do it in a way that is true to how they are usually perceived, hopefully increasing readability for new graphnet users. Thirdly, I think the implementation should aim to be flexible enough for users to apply autoencoding in graphnet to most use cases; the few it does not cover should be intuitive to implement.

These three considerations led me to answers to a couple of the points you raised above:

  1. The syntax should be similar to what I showed earlier.
  2. Masking should happen in a forward pass in the autoencoder because that is where people expect it to be.
  3. Masking should be modular; i.e. the implementation should allow for an argument for some kind of callable that delivers the masking.
  4. The implementation should allow for saving state_dicts and model_configs for the encoder/decoder pair independently.
  5. We should introduce a Model method that allows for a partial state_dict load.
  6. It is OK that autoencoders only work on autoencoding-tasks.

I think "the best approach" is the approach that achieves these goals with the smallest implementation overhead and that does not repeat code. After some experimentation, I found that a nice way of doing this would be to first introduce a small refactor of StandardModel that moves boiler-plate code into a subclass currently called EasyModel. This class makes no assumption on number of model architectures or what they are, and it has four abstract methods: compute_loss, forward, shared_step, validate_tasks. It delivers the easy syntax our users are relying on; model.fit, model.predict_as_dataframe, etc. Our current StandardModel becomes a subclass of EasyModel with a specific implementation of the abstract methods that is aimed for single-architecture training. Pseudo-code below:

from abc import abstractmethod
from collections import OrderedDict
from typing import Any, Callable, Dict, List, Optional, Tuple, Union, Type

import numpy as np
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from torch import Tensor
from torch.nn import ModuleList
from torch.optim import Adam
from torch.utils.data import DataLoader, SequentialSampler
from torch_geometric.data import Data
import pandas as pd
from pytorch_lightning.loggers import Logger as LightningLogger

from graphnet.training.callbacks import ProgressBar
from graphnet.models.graphs import GraphDefinition
from graphnet.models.gnn.gnn import GNN
from graphnet.models.model import Model
from graphnet.models.task import StandardLearnedTask

class EasyModel(Model):
    """A suggested model format for GraphNeT models.

    This class delivers simple user syntax for powerful deep learning
    techniques by chaining together the different elements of a complete
    model (detector read-in, GNN backbone, and task-specific read-outs).
    """

    def __init__(
        self,
        *,
        graph_definition: GraphDefinition,
        tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
        optimizer_class: Type[torch.optim.Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict] = None,
        scheduler_class: Optional[type] = None,
        scheduler_kwargs: Optional[Dict] = None,
        scheduler_config: Optional[Dict] = None,
    ) -> None:
        """Construct `StandardModel`."""
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

        # Check(s)
        if not isinstance(tasks, (list, tuple)):
            tasks = [tasks]

        assert isinstance(graph_definition, GraphDefinition)

        # Member variable(s)
        self._graph_definition = graph_definition
        self._tasks = ModuleList(tasks)
        self.validate_tasks()
        self._optimizer_class = optimizer_class
        self._optimizer_kwargs = optimizer_kwargs or dict()
        self._scheduler_class = scheduler_class
        self._scheduler_kwargs = scheduler_kwargs or dict()
        self._scheduler_config = scheduler_config or dict()

    @abstractmethod
    def compute_loss(
        self, preds: Tensor, data: List[Data], verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""

        raise NotImplementedError(f"Subclasses of `EasyModel` must implement this method`")

    @abstractmethod
    def forward(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Forward pass, chaining model components."""

        raise NotImplementedError(f"Subclasses of `EasyModel` must implement this method`")

    @abstractmethod
    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """

        raise NotImplementedError(f"Subclasses of `EasyModel` must implement this method`")

    @abstractmethod
    def validate_tasks(self) -> None:
        """Validate that the assigned tasks are supported by this model."""
        raise NotImplementedError("Subclasses of `EasyModel` must implement this method.")

    @staticmethod
    def _construct_trainer(
        max_epochs: int = 10,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> Trainer:
        if gpus:
            accelerator = "gpu"
            devices = gpus
        else:
            accelerator = "cpu"
            devices = 1

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            log_every_n_steps=log_every_n_steps,
            logger=logger,
            gradient_clip_val=gradient_clip_val,
            strategy=distribution_strategy,
            **trainer_kwargs,
        )

        return trainer

    def fit(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        *,
        max_epochs: int = 10,
        early_stopping_patience: int = 5,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        ckpt_path: Optional[str] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> None:
        """Fit `StandardModel` using `pytorch_lightning.Trainer`."""
        # Checks
        if callbacks is None:
            # We create the bare-minimum callbacks for you.
            callbacks = self._create_default_callbacks(
                val_dataloader=val_dataloader,
                early_stopping_patience=early_stopping_patience,
            )
            self.debug("No Callbacks specified. Default callbacks added.")
        else:
            # You are on your own!
            self.debug("Initializing training with user-provided callbacks.")
            pass
        self._print_callbacks(callbacks)
        has_early_stopping = self._contains_callback(callbacks, EarlyStopping)
        has_model_checkpoint = self._contains_callback(
            callbacks, ModelCheckpoint
        )

        if (has_early_stopping) & (has_model_checkpoint is False):
            self.warning(
                """No ModelCheckpoint found in callbacks. Best-fit model will not automatically be loaded after training!"""
            )

        self.train(mode=True)
        trainer = self._construct_trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            gradient_clip_val=gradient_clip_val,
            distribution_strategy=distribution_strategy,
            **trainer_kwargs,
        )

        try:
            trainer.fit(
                self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
            )
        except KeyboardInterrupt:
            self.warning("[ctrl+c] Exiting gracefully.")
            pass

        # Load weights from best-fit model after training if possible
        if has_early_stopping & has_model_checkpoint:
            for callback in callbacks:
                if isinstance(callback, ModelCheckpoint):
                    checkpoint_callback = callback
            self.load_state_dict(
                torch.load(checkpoint_callback.best_model_path)["state_dict"]
            )
            self.info("Best-fit weights from EarlyStopping loaded.")

    def _print_callbacks(self, callbacks: List[Callback]) -> None:
        callback_names = []
        for cbck in callbacks:
            callback_names.append(cbck.__class__.__name__)
        self.info(
            f"Training initiated with callbacks: {', '.join(callback_names)}"
        )

    def _contains_callback(
        self, callbacks: List[Callback], callback: Callback
    ) -> bool:
        """Check if `callback` is in `callbacks`."""
        for cbck in callbacks:
            if isinstance(cbck, callback):
                return True
        return False

    @property
    def target_labels(self) -> List[str]:
        """Return target label."""
        return [label for task in self._tasks for label in task._target_labels]

    @property
    def prediction_labels(self) -> List[str]:
        """Return prediction labels."""
        return [
            label for task in self._tasks for label in task._prediction_labels
        ]

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure the model's optimizer(s)."""
        optimizer = self._optimizer_class(
            self.parameters(), **self._optimizer_kwargs
        )
        config = {
            "optimizer": optimizer,
        }
        if self._scheduler_class is not None:
            scheduler = self._scheduler_class(
                optimizer, **self._scheduler_kwargs
            )
            config.update(
                {
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        **self._scheduler_config,
                    },
                }
            )
        return config

    def training_step(
        self, train_batch: Union[Data, List[Data]], batch_idx: int
    ) -> Tensor:
        """Perform training step."""
        if isinstance(train_batch, Data):
            train_batch = [train_batch]
        loss = self.shared_step(train_batch, batch_idx)
        self.log(
            "train_loss",
            loss,
            batch_size=self._get_batch_size(train_batch),
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        return loss

    def validation_step(
        self, val_batch: Union[Data, List[Data]], batch_idx: int
    ) -> Tensor:
        """Perform validation step."""
        if isinstance(val_batch, Data):
            val_batch = [val_batch]
        loss = self.shared_step(val_batch, batch_idx)
        self.log(
            "val_loss",
            loss,
            batch_size=self._get_batch_size(val_batch),
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        return loss

    def inference(self) -> None:
        """Activate inference mode."""
        for task in self._tasks:
            task.inference()

    def train(self, mode: bool = True) -> "Model":
        """Deactivate inference mode."""
        super().train(mode)
        if mode:
            for task in self._tasks:
                task.train_eval()
        return self

    def predict(
        self,
        dataloader: DataLoader,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
    ) -> List[Tensor]:
        """Return predictions for `dataloader`."""
        self.inference()
        self.train(mode=False)

        callbacks = self._create_default_callbacks(
            val_dataloader=None,
        )

        inference_trainer = self._construct_trainer(
            gpus=gpus,
            distribution_strategy=distribution_strategy,
            callbacks=callbacks,
        )

        predictions_list = inference_trainer.predict(self, dataloader)
        assert len(predictions_list), "Got no predictions"

        nb_outputs = len(predictions_list[0])
        predictions: List[Tensor] = [
            torch.cat([preds[ix] for preds in predictions_list], dim=0)
            for ix in range(nb_outputs)
        ]
        return predictions

    def predict_as_dataframe(
        self,
        dataloader: DataLoader,
        prediction_columns: Optional[List[str]] = None,
        *,
        additional_attributes: Optional[List[str]] = None,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
    ) -> pd.DataFrame:
        """Return predictions for `dataloader` as a DataFrame.

        Include `additional_attributes` as additional columns in the output
        DataFrame.
        """
        if prediction_columns is None:
            prediction_columns = self.prediction_labels

        if additional_attributes is None:
            additional_attributes = []
        assert isinstance(additional_attributes, list)

        if (
            not isinstance(dataloader.sampler, SequentialSampler)
            and additional_attributes
        ):
            print(dataloader.sampler)
            raise UserWarning(
                "DataLoader has a `sampler` that is not `SequentialSampler`, "
                "indicating that shuffling is enabled. Using "
                "`predict_as_dataframe` with `additional_attributes` assumes "
                "that the sequence of batches in `dataloader` are "
                "deterministic. Either call this method a `dataloader` which "
                "doesn't resample batches; or do not request "
                "`additional_attributes`."
            )
        self.info(f"Column names for predictions are: \n {prediction_columns}")
        predictions_torch = self.predict(
            dataloader=dataloader,
            gpus=gpus,
            distribution_strategy=distribution_strategy,
        )
        predictions = (
            torch.cat(predictions_torch, dim=1).detach().cpu().numpy()
        )
        assert len(prediction_columns) == predictions.shape[1], (
            f"Number of provided column names ({len(prediction_columns)}) and "
            f"number of output columns ({predictions.shape[1]}) don't match."
        )

        # Get additional attributes
        attributes: Dict[str, List[np.ndarray]] = OrderedDict(
            [(attr, []) for attr in additional_attributes]
        )
        for batch in dataloader:
            for attr in attributes:
                attribute = batch[attr]
                if isinstance(attribute, torch.Tensor):
                    attribute = attribute.detach().cpu().numpy()

                # Check if node level predictions
                # If true, additional attributes are repeated
                # to make dimensions fit
                if len(predictions) != len(dataloader.dataset):
                    if len(attribute) < np.sum(
                        batch.n_pulses.detach().cpu().numpy()
                    ):
                        attribute = np.repeat(
                            attribute, batch.n_pulses.detach().cpu().numpy()
                        )
                        try:
                            assert len(attribute) == len(batch.x)
                        except AssertionError:
                            self.warning_once(
                                "Could not automatically adjust length"
                                f"of additional attribute {attr} to match length of"
                                f"predictions. Make sure {attr} is a graph-level or"
                                "node-level attribute. Attribute skipped."
                            )
                            pass
                attributes[attr].extend(attribute)

        data = np.concatenate(
            [predictions]
            + [
                np.asarray(values)[:, np.newaxis]
                for values in attributes.values()
            ],
            axis=1,
        )

        results = pd.DataFrame(
            data, columns=prediction_columns + additional_attributes
        )
        return results

    def _create_default_callbacks(
        self,
        val_dataloader: DataLoader,
        early_stopping_patience: Optional[int] = None,
    ) -> List:
        """Create default callbacks.

        Used in cases where no callbacks are specified by the user in .fit
        """
        callbacks = [ProgressBar()]
        if val_dataloader is not None:
            assert early_stopping_patience is not None
            # Add Early Stopping
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=early_stopping_patience,
                )
            )
            # Add Model Check Point
            callbacks.append(
                ModelCheckpoint(
                    save_top_k=1,
                    monitor="val_loss",
                    mode="min",
                    filename=f"{self.backbone.__class__.__name__}"
                    + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
                )
            )
            self.info(
                f"EarlyStopping has been added with a patience of {early_stopping_patience}."
            )
        return callbacks

    def _add_early_stopping(
        self, val_dataloader: DataLoader, callbacks: List
    ) -> List:
        if val_dataloader is None:
            return callbacks
        has_early_stopping = False
        assert isinstance(callbacks, list)
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                has_early_stopping = True

        if not has_early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=5,
                )
            )
            self.warning_once(
                "Got validation dataloader but no EarlyStopping callback. An "
                "EarlyStopping callback has been added automatically with "
                "patience=5 and monitor = 'val_loss'."
            )
        return callbacks

We should probably consider whether @abstractmethod is suitable for these functions - that decorator forces the input arguments, and that will probably not be flexible enough for us. Other libraries simply assume the user knows that the function must be implemented without declaring it explicitly (e.g. torch.nn.Module.forward). But that's a detail.
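
For reference, a small sketch contrasting the two conventions (class names are illustrative):

from abc import ABC, abstractmethod
from torch import Tensor

class WithAbstractMethod(ABC):
    @abstractmethod
    def forward(self, x: Tensor) -> Tensor:
        """Declared explicitly; instantiating a subclass that does not
        override `forward` raises a TypeError."""

class WithConvention:
    def forward(self, x: Tensor) -> Tensor:
        # The torch.nn.Module-style convention: nothing is declared, and the
        # base implementation simply raises if it is ever called.
        raise NotImplementedError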

This freedom allows us to implement machine learning paradigms that have a different number of model components (e.g. an encoder/decoder pair) and to somewhat freely define forward passes, loss computations, etc.

Our usual StandardModel would look like so under this refactor:

class StandardModel(EasyModel):
    def __init__(
        self,
        backbone: GNN = None,
        gnn: Optional[GNN] = None,
        **easy_model_kwargs: Any
    ) -> None:
        """Construct `StandardModel`."""
        # Base class constructor
        super().__init__(**easy_model_kwargs)

        # deprecation warnings
        if (backbone is None) & (gnn is not None):
            backbone = gnn
            # Code continues after warning
            self.warning(
                """DeprecationWarning: Argument `gnn` will be deprecated in GraphNeT 2.0. Please use `backbone` instead."""
            )
        elif (backbone is None) & (gnn is None):
            # Code stops
            raise TypeError(
                "__init__() missing 1 required keyword-only argument: 'backbone'"
            )
        assert isinstance(backbone, GNN)

        # Member variable(s)
        self.backbone = backbone

    def compute_loss(
        self, preds: Tensor, data: List[Data], verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""
        data_merged = {}
        target_labels_merged = list(set(self.target_labels))
        for label in target_labels_merged:
            data_merged[label] = torch.cat([d[label] for d in data], dim=0)
        for task in self._tasks:
            if task._loss_weight is not None:
                data_merged[task._loss_weight] = torch.cat(
                    [d[task._loss_weight] for d in data], dim=0
                )

        losses = [
            task.compute_loss(pred, data_merged)
            for task, pred in zip(self._tasks, preds)
        ]
        if verbose:
            self.info(f"{losses}")
        assert all(
            loss.dim() == 0 for loss in losses
        ), "Please reduce loss for each task separately"
        return torch.sum(torch.stack(losses))

    def forward(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Forward pass, chaining model components."""
        if isinstance(data, Data):
            data = [data]
        x_list = []
        for d in data:
            x = self.backbone(d)
            x_list.append(x)
        x = torch.cat(x_list, dim=0)

        preds = [task(x) for task in self._tasks]
        return preds

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        preds = self(batch)
        loss = self.compute_loss(preds, batch)
        return loss

    def validate_tasks(self):
        accepted_tasks = (StandardLearnedTask,)
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)

A masked autoencoder could look like so:

class StandardMaskedAutoEncoder(EasyModel):
    def __init__(
        self,
        encoder: GNN = None,
        decoder: GNN = None,
        mask_func: Callable = None,
        **easy_model_kwargs: Any
    ) -> None:
        """Construct `StandardModel`."""
        # Base class constructor
        super().__init__(**easy_model_kwargs)

        self.encoder = encoder
        self.decoder = decoder
        self._mask_func = mask_func

    def forward(
        self, data: Union[Data, List[Data]]
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass, chaining model components."""
        if isinstance(data, Data):
            data = [data]
        y_list = []
        x_list = []
        encoded_x_list = []
        mask_list = []
        for d in data:
            #Mask input
            x_masked, mask = self._mask_func(d.x)

            # Encode mask input
            x_masked_encoded = self.encoder(x_masked)

            # Decode masked input
            y = self.decoder(x_masked_encoded)

            # Store raw input, mask and decoded input
            y_list.append(y)
            x_list.append(d.x)
            mask_list.append(mask)
            encoded_x_list.append(x_masked_encoded)
        y = torch.cat(y_list, dim=0)
        x = torch.cat(x_list, dim = 0)
        mask = torch.cat(mask_list, dim = 0)
        encoded_x = torch.cat(encoded_x_list, dim = 0)
        preds = [task(x = x, y = y, x_encoded = encoded_x, mask = mask) for task in self._tasks]

        assert all(pred.shape[0] == x[mask, :].shape[0] for pred in preds)
        # return masked predictions and ground truth
        return preds, x[mask,:]

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        preds, x = self(batch)
        loss = self.compute_loss(preds = preds, truth = x)
        return loss

    def compute_loss(
        self, preds: Tensor, truth: Tensor, verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""

        losses = [
            task.compute_loss(pred, truth)
            for task, pred in zip(self._tasks, preds)
        ]
        if verbose:
            self.info(f"{losses}")
        assert all(
            loss.dim() == 0 for loss in losses
        ), "Please reduce loss for each task separately"
        return torch.sum(torch.stack(losses))

    def validate_tasks(self):
        accepted_tasks = (AutoEncodingTask,)
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)

    def save_state_dict(self, path: str) -> None:
        """Save state dict for the entire model and the encoder/decoder pair separately."""
        super().save_state_dict(path)
        self.encoder.save_state_dict(path.replace('.pth', '_encoder.pth'))
        self.decoder.save_state_dict(path.replace('.pth', '_decoder.pth'))

    def save_config(self, path: str) -> None:
        """Save config for the entire model and the encoder/decoder pair separately."""
        super().save_config(path)
        self.encoder.save_config(path.replace('.yml', '_encoder.yml'))
        self.decoder.save_config(path.replace('.yml', '_decoder.yml'))
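
As a side note, the mask_func callable above could be something as simple as the following sketch (random masking is only an illustrative strategy, not the intended one):

import torch
from torch import Tensor
from typing import Tuple

def random_node_mask(x: Tensor, mask_fraction: float = 0.5) -> Tuple[Tensor, Tensor]:
    """Randomly mask out a fraction of the nodes (rows) in `x`.

    Returns the surviving rows and a boolean mask that is True for the
    rows that were removed, matching the `x_masked, mask` usage above.
    """
    mask = torch.rand(x.shape[0]) < mask_fraction  # True = node is masked out
    return x[~mask], mask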

Because the model inherits from EasyModel, point 1 is met. The masking is done in the forward pass; notice that this masking method is an input and is assumed to work on batched data (meets 2 and 3). The class attempts to save the encoder/decoder pair (along with the entire model) when the .save methods are called (meets 4). Notice also that the original input, encoded input, decoded input and the mask are all passed to the Task. This means that we need a special AutoEncodingTask that becomes a sandbox for tying these inputs into the final prediction. With all of that information available, I think most autoencoding tasks should be supported. Here's some pseudo-code for how that AutoEncodingTask could look:

class AutoEncodingTask(Task):

    @final
    def forward(
        self, x: Tensor, y: Tensor, encoding: Tensor, mask: Tensor
    ) -> Union[Tensor, Data]:
        """AutoEncoding Task. Ties together input, decoded input, encoded input and mask.

        Args:
            x: original input tensor to decoder
            y: output of decoder
            encoding: output of encoder
            mask: mask applied to input before being passed to encoder.

        Returns:
            ...
        """
        self._regularisation_loss = 0  # Reset
        x = self._forward(x = x, y = y, encoding = encoding, mask = mask)
        return self._transform_prediction(x)

    @abstractmethod
    def _forward(
        self, x: Tensor, y: Tensor, encoding: Tensor, mask: Tensor
    ) -> Tensor:
        """ Tie all the data together for a final prediction."""

Finally, to address 5, we can introduce a method Model.partially_load_state_dict that will load in the entries in the state_dict that exist in the Model. That would allow us to use pretrained weights in fresh models that might contain new, additional layers.
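
One possible sketch of such a method, written here as a free function for illustration (the final API and name may differ):

import torch
from torch.nn import Module

def partially_load_state_dict(model: Module, state_dict_path: str) -> None:
    """Load only the entries of a saved state_dict that match the model.

    Entries whose name or shape do not exist in `model` are skipped, so
    pretrained weights can be reused in models with new, additional layers.
    """
    saved = torch.load(state_dict_path)
    current = model.state_dict()
    compatible = {
        key: value
        for key, value in saved.items()
        if key in current and value.shape == current[key].shape
    }
    current.update(compatible)
    model.load_state_dict(current)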

While there are still a couple of minor details to iron out, I do think this approach would suit your use case. @AMHermansen could you perhaps take a look and let me know what you think?

AMHermansen commented 10 months ago

Sorry for the wait.

Alright @AMHermansen - sorry for the delay on this. I now have some pseudo-code ready.

I arrived at this pseudo-code by starting out with a series of considerations. First, I think we should insist that the user experience is intuitive and modular, such that autoencoding functionality in graphnet plays well with existing modules, whether that be loss functions, model architectures, graph definitions, etc. Secondly, I think it's important that when we implement new machine learning paradigms, we do it in a way that is true to how they are usually perceived, hopefully increasing readability for new graphnet users. Thirdly, I think the implementation should aim to be flexible enough for users to apply autoencoding in graphnet to most use cases; the few it does not cover should be intuitive to implement.

These three considerations led me to answers to a couple of the points you raised above:

1. The syntax should be similar to what I showed earlier.

2. Masking should happen in a forward pass in the autoencoder because that is where people expect it to be.

3. Masking should be modular; i.e. the implementation should allow for an argument for some kind of callable that delivers the masking.

4. The implementation should allow for saving `state_dicts` and `model_configs` for the encoder/decoder pair independently.

5. We should introduce a `Model` method that allows for a partial `state_dict` load.

6. It is OK that autoencoders only work on autoencoding-tasks.

I think this sounds reasonable, though I don't see why you allow the MAE to work with unique AE tasks but don't let it rely on a unique GraphDefinition.

Just to iron it out so we are on the same page: my current implementation splits the masking procedure into two steps. The first step is to select which nodes should be masked; this is done in preprocessing in a unique NodeDefinition, and then in the forward pass of the model the nodes that were previously selected are removed. This makes the logic much simpler, since you don't have to deal with the torch_geometric sparse tensor while performing logic on an event-level basis. The only real alternative I see to this is writing a for loop over all the events in the batch. The second step is when the nodes are removed and MaskTokens are later reintroduced. Both of these are done in the forward pass of the MAE. A sketch of the two-step split follows below.
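
To illustrate the two-step split (the names below are not graphnet's actual NodeDefinition API, just a sketch of the idea):

import torch
from torch import Tensor
from typing import Tuple

def flag_nodes_for_masking(node_features: Tensor, mask_fraction: float = 0.5) -> Tensor:
    """Step 1 (preprocessing, NodeDefinition-like): append a column that
    flags which nodes should later be masked."""
    flags = (torch.rand(node_features.shape[0], 1) < mask_fraction).float()
    return torch.cat([node_features, flags], dim=1)

def drop_flagged_nodes(x: Tensor) -> Tuple[Tensor, Tensor]:
    """Step 2 (forward pass): remove the previously flagged nodes, returning
    the surviving nodes (without the flag column) and the boolean mask."""
    mask = x[:, -1].bool()  # True = node was selected for masking
    return x[~mask, :-1], mask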

I think "the best approach" is the approach that achieves these goals with the smallest implementation overhead and that does not repeat code. After some experimentation, I found that a nice way of doing this would be to first introduce a small refactor of StandardModel that moves boiler-plate code into a subclass currently called EasyModel. This class makes no assumption on number of model architectures or what they are, and it has four abstract methods: compute_loss, forward, shared_step, validate_tasks. It delivers the easy syntax our users are relying on; model.fit, model.predict_as_dataframe, etc. Our current StandardModel becomes a subclass of EasyModel with a specific implementation of the abstract methods that is aimed for single-architecture training. Pseudo-code below:

I'm not personally using the Model.fit or the Model.predict(_as_dataframe) methods, so I'm not sure how my model would work with those, and I don't personally plan on figuring out how to implement such logic.

The refactor is probably well motivated, but maybe it is better done in a separate issue, which this implementation would then rely on?

A masked autoencoder could look like so:

class StandardMaskedAutoEncoder(EasyModel):
    def __init__(
        self,
        encoder: GNN = None,
        decoder: GNN = None,
        mask_func: Callable = None,
        **easy_model_kwargs: Any
    ) -> None:
        """Construct `StandardModel`."""
        # Base class constructor
        super().__init__(**easy_model_kwargs)

        self.encoder = encoder
        self.decoder = decoder
        self._mask_func = mask_func

    def forward(
        self, data: Union[Data, List[Data]]
    ) -> Tuple[Tensor, Tensor]:
        """Forward pass, chaining model components."""
        if isinstance(data, Data):
            data = [data]
        y_list = []
        x_list = []
        encoded_x_list = []
        mask_list = []
        for d in data:
            #Mask input
            x_masked, mask = self._mask_func(d.x)

            # Encode mask input
            x_masked_encoded = self.encoder(x_masked)

            # Decode masked input
            y = self.decoder(x_masked_encoded)

            # Store raw input, mask and decoded input
            y_list.append(y)
            x_list.append(d.x)
            mask_list.append(mask)
            encoded_x_list.append(x_masked_encoded)
        y = torch.cat(y_list, dim=0)
        x = torch.cat(x_list, dim = 0)
        mask = torch.cat(mask_list, dim = 0)
        encoded_x = torch.cat(encoded_x_list, dim = 0)
        preds = [task(x = x, y = y, x_encoded = encoded_x, mask = mask) for task in self._tasks]

        assert all(pred.shape[0] == x[mask, :].shape[0] for pred in preds)
        # return masked predictions and ground truth
        return preds, x[mask,:]

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        preds, x = self(batch)
        loss = self.compute_loss(preds = preds, truth = x)
        return loss

    def compute_loss(
        self, preds: Tensor, truth: Tensor, verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""

        losses = [
            task.compute_loss(pred, truth)
            for task, pred in zip(self._tasks, preds)
        ]
        if verbose:
            self.info(f"{losses}")
        assert all(
            loss.dim() == 0 for loss in losses
        ), "Please reduce loss for each task separately"
        return torch.sum(torch.stack(losses))

    def validate_tasks(self):
        accepted_tasks = (AutoEncodingTask,)
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)

    def save_state_dict(self, path: str) -> None:
        """Save state dict for the entire model and the encoder/decoder pair separately."""
        super().save_state_dict(path)
        self.encoder.save_state_dict(path.replace('.pth', '_encoder.pth'))
        self.decoder.save_state_dict(path.replace('.pth', '_decoder.pth'))

    def save_config(self, path: str) -> None:
        """Save config for the entire model and the encoder/decoder pair separately."""
        super().save_config(path)
        self.encoder.save_config(path.replace('.yml', '_encoder.yml'))
        self.decoder.save_config(path.replace('.yml', '_decoder.yml'))

The implementation you suggest above currently doesn't re-introduce mask tokens after the masked input has been encoded. This makes it difficult to make node-level predictions, because the model doesn't necessarily know how many nodes were removed and thus doesn't know how many predictions to make.
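
For illustration, re-introducing mask tokens so the decoder sees the original number of nodes could look roughly like this (names are hypothetical):

import torch
from torch import Tensor

def insert_mask_tokens(encoded: Tensor, mask: Tensor, mask_token: Tensor) -> Tensor:
    """Restore the original node count after encoding.

    encoded:    [N_kept, D] encoder output for the surviving nodes.
    mask:       [N_total] boolean, True where a node was masked out.
    mask_token: [D] learned embedding standing in for the masked nodes.
    """
    n_total, dim = mask.shape[0], encoded.shape[1]
    out = mask_token.expand(n_total, dim).clone()
    out[~mask] = encoded  # surviving nodes keep their encodings, in order
    return out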

I like how the state_dicts / configs are saved separately, but maybe it should also have either a custom way of loading multiple inputs, or save its own config/state_dict, in case you want to use the entire MAE for other purposes.

It might also make sense to look into how ModelCheckpoints work in Lightning, to see if it is possible to have an elegant way of using those for downstream purposes.

Because the model inherits from EasyModel, point 1 is met. The masking is done in the forward pass; notice that this masking method is an input and is assumed to work on batched data (meets 2 and 3). The class attempts to save the encoder/decoder pair (along with the entire model) when the .save methods are called (meets 4). Notice also that the original input, encoded input, decoded input and the mask are all passed to the Task. This means that we need a special AutoEncodingTask that becomes a sandbox for tying these inputs into the final prediction. With all of that information available, I think most autoencoding tasks should be supported. Here's some pseudo-code for how that AutoEncodingTask could look:

class AutoEncodingTask(Task):

    @final
    def forward(
        self, x: Tensor, y: Tensor, encoding: Tensor, mask: Tensor
    ) -> Union[Tensor, Data]:
        """AutoEncoding Task. Ties together input, decoded input, encoded input and mask.

        Args:
            x: original input tensor to decoder
            y: output of decoder
            encoding: output of encoder
            mask: mask applied to input before being passed to encoder.

        Returns:
            ...
        """
        self._regularisation_loss = 0  # Reset
        x = self._forward(x = x, y = y, encoding = encoding, mask = mask)
        return self._transform_prediction(x)

    @abstractmethod
    def _forward(
        self, x: Tensor, y: Tensor, encoding: Tensor, mask: Tensor
    ) -> Tensor:
        """ Tie all the data together for a final prediction."""

Finally, to address 5. , we can introduce a method Model.partially_load_state_dict that will load in the entries in state_dict that exists in Model. That would allow us to use pretrained weights in fresh models that might contain new, additional layers.

Wouldn't you want to also pass data to the AutoEncoding task, to make sure it has access to the original data?