Open asogaard opened 1 year ago
This is the current layout of the MAE that I've been working on. The general MAE architecture is based on Masked Autoencoders Are Scalable Vision Learners.
The aggregation loss is currently computed by predicting logits for DOM PID and then duplicating the logits for each masked pulse. The motivation for this novel loss function is to have a permutation-invariant loss, since during prototyping it was observed that even with positional encodings the model wasn't able to distinguish nodes to a satisfactory level.
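A minimal sketch of how such an aggregation loss could be computed, written in plain Python so it runs without any framework. It assumes that "DOM PID" can be treated as an integer class index and that the loss is a cross-entropy; both are assumptions made for illustration, and `aggregation_loss` and `log_softmax` are hypothetical names rather than anything from the actual implementation.

```python
import math
from typing import List, Sequence


def log_softmax(logits: Sequence[float]) -> List[float]:
    """Numerically stable log-softmax over one logit vector."""
    peak = max(logits)
    log_norm = peak + math.log(sum(math.exp(v - peak) for v in logits))
    return [v - log_norm for v in logits]


def aggregation_loss(
    event_logits: Sequence[float], masked_dom_ids: Sequence[int]
) -> float:
    """Mean cross-entropy of one event-level logit vector against the
    DOM index of every masked pulse.

    Reusing the same logits for each masked pulse is equivalent to
    duplicating them, so the order of `masked_dom_ids` does not matter.
    """
    if not masked_dom_ids:
        return 0.0
    log_probs = log_softmax(event_logits)
    return -sum(log_probs[dom] for dom in masked_dom_ids) / len(masked_dom_ids)
```

Because every masked pulse is scored against the same logits, shuffling the masked pulses leaves the loss unchanged, which is the permutation invariance described above.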
It is possible to add a GNN decoder, which takes in the features learned by the encoder and a mask token (learned parameters in the latent space of the encoder). The decoder then attempts a more node-level reconstruction. It should also be possible to "tokenize" the PIDs for the masked DOMs and then reconstruct only the non-positional variables. The idea behind this is that the "predicted" time should depend strongly on where the model suspects the DOM is located, but since it might confuse DOMs with one another, I think we might get better results by providing the true DOMs and then tasking it with finding the time/charge for each DOM. An alternative loss function for this part could be the Chamfer distance; see Point-M2AE eq. 2.
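For readers unfamiliar with the Chamfer distance, here is a small sketch of one common form: symmetric, squared L2, averaged per set. The exact normalisation used in Point-M2AE eq. 2 may differ, and `chamfer_distance` is a hypothetical helper written for illustration only.

```python
from typing import Sequence, Tuple

Point = Tuple[float, ...]


def squared_distance(a: Point, b: Point) -> float:
    """Squared Euclidean distance between two points."""
    return sum((ai - bi) ** 2 for ai, bi in zip(a, b))


def chamfer_distance(pred: Sequence[Point], truth: Sequence[Point]) -> float:
    """Symmetric Chamfer distance between two non-empty point sets.

    Each point is matched to its nearest neighbour in the other set,
    so neither the ordering nor the sizes of the sets need to agree.
    """
    if not pred or not truth:
        raise ValueError("both point sets must be non-empty")
    pred_to_truth = sum(
        min(squared_distance(p, t) for t in truth) for p in pred
    ) / len(pred)
    truth_to_pred = sum(
        min(squared_distance(t, p) for p in pred) for t in truth
    ) / len(truth)
    return pred_to_truth + truth_to_pred
```

Like the aggregation loss, this quantity does not depend on the order of the reconstructed pulses, which is why it is a candidate when the model cannot reliably tell nodes apart.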
This diagram also illustrates why I think Aggregation might be better suited as a separate module, since we might need to extract both the graph (to add mask-tokens) and then do aggregation separately.
I'm open to inputs on the architecture. Currently my model is implemented using encoder-transformers as the GNN-encoder/decoder.
Hey @AMHermansen, I've now taken a deeper look at this, and I must admit I still struggle to follow your diagram. I think that your diagram depicts the usual two-stage approach for using autoencoders, where one first trains the encoder-decoder pair in an unsupervised setting and then fine-tunes one of the components of the encoder-decoder pair on a supervised learning task. Assuming this, I think the intended workflow is straightforward:
The first stage pre-trains the encoder-decoder pair according to some scheme defined in an AutoEncoding task, and the encoder and decoder are saved separately for later use.
from graphnet.models import AutoEncoder
from graphnet.models.gnn import DynEdge
from graphnet.models.detector import IceCube86
from graphnet.models.graphs import KNNGraph
from graphnet.models.task import AutoEncoding

# Define autoencoder
model = AutoEncoder(
    graph_definition=KNNGraph(),
    detector=IceCube86(),
    encoder=DynEdge(...),
    decoder=DynEdge(...),
    task=AutoEncoding(),
)

# Train autoencoder pair (first stage)
model.fit(...)

# Save encoder and decoder separately
model.save_config(path='..')  # Saves the config for encoder and decoder separately
model.save_state_dict(path='')  # Saves the state_dict for encoder and decoder separately
In a second session, one can load a pre-trained component of the encoder-decoder pair and fine-tune it on a physics task:
# .. In a different session (2nd stage)
from graphnet.models.gnn import DynEdge
from graphnet.models import StandardModel
from graphnet.models.detector import IceCube86
from graphnet.models.graphs import KNNGraph
from graphnet.models.task import EnergyReconstruction

# Load the pre-trained encoder
pretrained_model = DynEdge.from_config('encoder_config.yml')
pretrained_model.load_state_dict('encoder_state_dict')

model = StandardModel(
    graph_definition=KNNGraph(),
    detector=IceCube86(),
    gnn=pretrained_model,
    task=EnergyReconstruction(),
)
model.fit(...)
results = model.predict_as_dataframe(...)

# Save fine-tuned model
model.save_config(path='..')
model.save_state_dict(path='')
I think the technical work required to achieve this syntax is reasonable. @AMHermansen Could you elaborate on how your current approach differs from this?
> Hey @AMHermansen, I've now taken a deeper look at this, and I must admit I still struggle to follow your diagram. I think that your diagram depicts the usual two-stage approach for using autoencoders, where one first trains the encoder-decoder pair in an unsupervised setting and then fine-tunes one of the components of the encoder-decoder pair on a supervised learning task. Assuming this, I think the intended workflow is straightforward:

I would first like to clarify that the diagram describes a masked autoencoder and not a regular autoencoder. The main difference between the two architectures lies in their general purpose. Masked autoencoders have been shown to provide promising self-supervised learning tasks in a multitude of machine learning areas (NLP and CV), and many state-of-the-art models for various benchmarks seem to use a masked autoencoder as a pretraining step. See e.g. SOTA. Autoencoders are primarily used to take an input and compress it down to a latent space smaller than the original space.

> The first stage pre-trains the encoder-decoder pair according to some scheme defined in an AutoEncoding task, and the encoder and decoder are saved separately for later use.

I think this is somewhat close to what I'm currently working with. It is not clear to me where you intend the masking procedure to take place in the pseudo-code above. I'm currently doing it in a specially defined NodeDefinition, since it feels more natural to have it as a preprocessing step.
A thing which is not clear in the syntax above is the requirement that the encoder module must have been instantiated with aggregation=None, since it needs to output a graph-like structure as input to the decoder. To get the summary loss, a new module is added, which is responsible for aggregating the graph output by the encoder.
Maybe I lack creativity, but I do not see any elegant way around this. The current architecture requires the ability to "intercept" the graph output by the encoder and then pass it through two different modules.
The only real alternatives I see are either to have GNN output a dictionary containing the final graph state and the aggregated state, which would then require that tasks "grab" the right entry from the output dict depending on what they expect as input, or to give GNN two methods, learn_features and aggregate, and then change forward to:
class GNN:
    ...

    def forward(self, data: Data):
        # Node-level features first, graph-level aggregation second
        features = self.learn_features(data)
        output = self.aggregate(features)
        return output
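To make the second alternative concrete, here is a toy, framework-free sketch of a model whose forward pass is split into `learn_features` and `aggregate`. `TwoStageGNN` and `mean_pool` are hypothetical stand-ins, and the "message passing" is a placeholder; the point is only to show how a caller can intercept the node-level output before aggregation.

```python
from typing import Callable, List, Optional

NodeFeatures = List[List[float]]  # one feature vector per node


def mean_pool(nodes: NodeFeatures) -> List[float]:
    """Average node features column-wise into one graph-level vector."""
    return [sum(col) / len(nodes) for col in zip(*nodes)]


class TwoStageGNN:
    """Toy stand-in for a GNN whose forward pass is split in two steps."""

    def __init__(
        self,
        aggregation: Optional[Callable[[NodeFeatures], List[float]]] = mean_pool,
    ) -> None:
        self._aggregation = aggregation

    def learn_features(self, nodes: NodeFeatures) -> NodeFeatures:
        # Placeholder for message passing: every feature is simply doubled.
        return [[2.0 * v for v in node] for node in nodes]

    def aggregate(self, features: NodeFeatures):
        # With aggregation=None the node-level features pass through unchanged.
        if self._aggregation is None:
            return features
        return self._aggregation(features)

    def forward(self, nodes: NodeFeatures):
        return self.aggregate(self.learn_features(nodes))
```

With this split, a masked autoencoder could call `learn_features` once, hand the node-level result to the decoder, and separately pass it to `aggregate` for the summary loss, while ordinary graph-level tasks keep calling `forward`.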
> In a second session, one can load in a pre-trained component of the encoder-decoder pair and fine tune it on a physics task.

Here I believe you'd run into another problem, since the pretrained_model would not have any kind of aggregation, making it unsuitable for most of the tasks within GraphNeT. You could instantiate it with an aggregation component, manually load the weights for the DynamicalEdgeConvolution, and randomly initialise the post-processing layers, but this seems like a rather crude way of doing it.
My current implementation is mostly done in native pytorch_lightning and isn't very modularized, since I'm only working with transformer-based architectures. To give a quick overview, the pseudo-code would look something like:
class MAE(Model):
    def __init__(self, encoder, decoder, aggregator, ...):
        self.encoder = encoder
        self.decoder = decoder
        self.aggregator = aggregator
        ...

    def encode(self, data: Data):
        masked_data = self._apply_masking(data)
        # Encode the masked graph, not the original one
        encoded_graph = self.encoder(masked_data)
        return encoded_graph

    def decode(self, encoded_graph: Data, data: Data):
        # Graph-level summary used for the aggregation loss
        out = self.aggregator(encoded_graph)
        # Insert mask tokens and restore the original order
        decode_graph = self._insert_mask_tokens(encoded_graph, data)
        # Add positional information so that all mask tokens differ slightly
        decode_graph += self.positional_encoder(decode_graph)
        decode_graph = self.decoder(decode_graph)
        # Select the mask tokens
        decoded_tokens = self.select_mask_tokens(decode_graph, data)
        return out, decoded_tokens

    def compute_loss(self, out, decoded_tokens, data):
        summary_loss = self.summary_task(out, data)
        ae_loss = self.ae_loss(decoded_tokens, data)
        return self.summary_weight * summary_loss + ae_loss

    def training_step(self, data: Data):
        encoded_graph = self.encode(data)
        out, decoded_tokens = self.decode(encoded_graph, data)
        loss = self.compute_loss(out, decoded_tokens, data)
        return {"loss": loss}

    ...
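The two helper steps in `decode` above can be sketched in isolation. The version below assumes, as in the MAE paper, that the encoder only sees the unmasked pulses, so its output is shorter than the original sequence; the pseudo-code above does not state this explicitly. The function names mirror the pseudo-code, but the signatures are hypothetical.

```python
from typing import List, Sequence

Token = List[float]


def insert_mask_tokens(
    encoded_visible: Sequence[Token],
    mask: Sequence[bool],
    mask_token: Token,
) -> List[Token]:
    """Rebuild the full-length sequence in its original order.

    `mask[i]` is True where pulse i was hidden from the encoder. Visible
    positions receive the encoder outputs in order; masked positions
    receive a copy of the shared mask token.
    """
    n_visible = sum(1 for m in mask if not m)
    if n_visible != len(encoded_visible):
        raise ValueError("number of unmasked positions must match encoder output")
    visible = iter(encoded_visible)
    return [list(mask_token) if m else list(next(visible)) for m in mask]


def select_mask_tokens(
    tokens: Sequence[Token], mask: Sequence[bool]
) -> List[Token]:
    """Keep only the positions that were masked."""
    return [tok for tok, m in zip(tokens, mask) if m]
```

Since every masked position starts from the same token, the positional encoding added afterwards is what lets the decoder tell them apart.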
I hope this clarifies my previous message.
@AMHermansen thanks for the details! Before I respond in full, could you help clear up one last thing:
> To get the summary loss, a new module is added, which is responsible for aggregating the graph output by the encoder.

What is the difference between the purposes of "aggregation loss" and "node level loss"? Are both quantities used in the pre-training stage, or does "node level loss" represent the pre-training stage, with "aggregation loss" being the loss calculated in the second stage, where one fine-tunes the pre-trained component on a graph-level physics task?
Yes, both the "aggregation loss" and the "node level loss" are used during pretraining.
The idea of using these two losses together is inspired by the BERT paper from natural language processing (which to my knowledge is the first time a masked autoencoder was used for pretraining), where the model is given two sentences and some of the words in each sentence are masked. The model is then tasked both with predicting the masked words and with predicting whether the two sentences are a "continuation" of one another.
Alright @AMHermansen - sorry for the delay on this. I now have some pseudo-code ready.
I arrived at this pseudo-code by starting out with a series of considerations. First, I think we should insist that the user experience is intuitive and modular, such that autoencoding functionality in graphnet plays well with existing modules, whether that be loss functions, model architectures, graph definitions etc. Secondly, I think it's important that when we implement new machine learning paradigms, we do it in a way that is true to how they are usually perceived, hopefully increasing the readability for new graphnet users. Thirdly, I think the implementations should aim to be flexible enough for users to apply autoencoding in graphnet to most use cases, and the few that it does not cover should be intuitive to implement.
These three considerations lead me to answers to a couple of the points you raised above:
1. The syntax should be similar to what I showed earlier.
2. Masking should happen in a forward pass in the autoencoder because that is where people expect it to be.
3. Masking should be modular; i.e. the implementation should allow for an argument for some kind of callable that delivers the masking.
4. The implementation should allow for saving state_dicts and model_configs for the encoder/decoder pair independently.
5. We should introduce a Model method that allows for a partial state_dict load.
6. It is OK that autoencoders only work on autoencoding-tasks.

I think "the best approach" is the approach that achieves these goals with the smallest implementation overhead and that does not repeat code. After some experimentation, I found that a nice way of doing this would be to first introduce a small refactor of StandardModel that moves boilerplate code into a new class, currently called EasyModel. This class makes no assumption about the number of model architectures or what they are, and it has four abstract methods: compute_loss, forward, shared_step and validate_tasks. It delivers the easy syntax our users are relying on: model.fit, model.predict_as_dataframe, etc. Our current StandardModel becomes a subclass of EasyModel with a specific implementation of the abstract methods that is aimed at single-architecture training. Pseudo-code below:
from abc import abstractmethod
from collections import OrderedDict
from typing import Any, Dict, List, Optional, Type, Union

import numpy as np
import pandas as pd
import torch
from pytorch_lightning import Callback, Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint
from pytorch_lightning.loggers import Logger as LightningLogger
from torch import Tensor
from torch.nn import ModuleList
from torch.optim import Adam
from torch.utils.data import DataLoader, SequentialSampler
from torch_geometric.data import Data

from graphnet.training.callbacks import ProgressBar
from graphnet.models.graphs import GraphDefinition
from graphnet.models.gnn.gnn import GNN
from graphnet.models.model import Model
from graphnet.models.task import StandardLearnedTask


class EasyModel(Model):
    """A suggested model format for GraphNeT models.

    This class delivers simple user syntax for powerful deep learning
    techniques by chaining together the different elements of a complete
    model (detector read-in, GNN backbone, and task-specific read-outs).
    """

    def __init__(
        self,
        *,
        graph_definition: GraphDefinition,
        tasks: Union[StandardLearnedTask, List[StandardLearnedTask]],
        optimizer_class: Type[torch.optim.Optimizer] = Adam,
        optimizer_kwargs: Optional[Dict] = None,
        scheduler_class: Optional[type] = None,
        scheduler_kwargs: Optional[Dict] = None,
        scheduler_config: Optional[Dict] = None,
    ) -> None:
        """Construct `EasyModel`."""
        # Base class constructor
        super().__init__(name=__name__, class_name=self.__class__.__name__)

        # Check(s)
        if not isinstance(tasks, (list, tuple)):
            tasks = [tasks]
        assert isinstance(graph_definition, GraphDefinition)

        # Member variable(s)
        self._graph_definition = graph_definition
        self._tasks = ModuleList(tasks)
        self._optimizer_class = optimizer_class
        self._optimizer_kwargs = optimizer_kwargs or dict()
        self._scheduler_class = scheduler_class
        self._scheduler_kwargs = scheduler_kwargs or dict()
        self._scheduler_config = scheduler_config or dict()

        # Validate only after `self._tasks` has been assigned
        self.validate_tasks()

    @abstractmethod
    def compute_loss(
        self, preds: Tensor, data: List[Data], verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""
        raise NotImplementedError(
            "Subclasses of `EasyModel` must implement this method"
        )

    @abstractmethod
    def forward(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Forward pass, chaining model components."""
        raise NotImplementedError(
            "Subclasses of `EasyModel` must implement this method"
        )

    @abstractmethod
    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        raise NotImplementedError(
            "Subclasses of `EasyModel` must implement this method"
        )

    @abstractmethod
    def validate_tasks(self) -> None:
        """Check that `self._tasks` are compatible with the model."""
        raise NotImplementedError(
            "Subclasses of `EasyModel` must implement this method"
        )
    @staticmethod
    def _construct_trainer(
        max_epochs: int = 10,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> Trainer:
        if gpus:
            accelerator = "gpu"
            devices = gpus
        else:
            accelerator = "cpu"
            devices = 1

        trainer = Trainer(
            accelerator=accelerator,
            devices=devices,
            max_epochs=max_epochs,
            callbacks=callbacks,
            log_every_n_steps=log_every_n_steps,
            logger=logger,
            gradient_clip_val=gradient_clip_val,
            strategy=distribution_strategy,
            **trainer_kwargs,
        )
        return trainer
    def fit(
        self,
        train_dataloader: DataLoader,
        val_dataloader: Optional[DataLoader] = None,
        *,
        max_epochs: int = 10,
        early_stopping_patience: int = 5,
        gpus: Optional[Union[List[int], int]] = None,
        callbacks: Optional[List[Callback]] = None,
        ckpt_path: Optional[str] = None,
        logger: Optional[LightningLogger] = None,
        log_every_n_steps: int = 1,
        gradient_clip_val: Optional[float] = None,
        distribution_strategy: Optional[str] = "ddp",
        **trainer_kwargs: Any,
    ) -> None:
        """Fit the model using `pytorch_lightning.Trainer`."""
        # Checks
        if callbacks is None:
            # We create the bare-minimum callbacks for you.
            callbacks = self._create_default_callbacks(
                val_dataloader=val_dataloader,
                early_stopping_patience=early_stopping_patience,
            )
            self.debug("No Callbacks specified. Default callbacks added.")
        else:
            # You are on your own!
            self.debug("Initializing training with user-provided callbacks.")
        self._print_callbacks(callbacks)
        has_early_stopping = self._contains_callback(callbacks, EarlyStopping)
        has_model_checkpoint = self._contains_callback(
            callbacks, ModelCheckpoint
        )
        if has_early_stopping and not has_model_checkpoint:
            self.warning(
                "No ModelCheckpoint found in callbacks. Best-fit model will "
                "not automatically be loaded after training!"
            )

        self.train(mode=True)
        trainer = self._construct_trainer(
            max_epochs=max_epochs,
            gpus=gpus,
            callbacks=callbacks,
            logger=logger,
            log_every_n_steps=log_every_n_steps,
            gradient_clip_val=gradient_clip_val,
            distribution_strategy=distribution_strategy,
            **trainer_kwargs,
        )

        try:
            trainer.fit(
                self, train_dataloader, val_dataloader, ckpt_path=ckpt_path
            )
        except KeyboardInterrupt:
            self.warning("[ctrl+c] Exiting gracefully.")

        # Load weights from best-fit model after training if possible
        if has_early_stopping and has_model_checkpoint:
            for callback in callbacks:
                if isinstance(callback, ModelCheckpoint):
                    checkpoint_callback = callback
            self.load_state_dict(
                torch.load(checkpoint_callback.best_model_path)["state_dict"]
            )
            self.info("Best-fit weights from EarlyStopping loaded.")
    def _print_callbacks(self, callbacks: List[Callback]) -> None:
        callback_names = []
        for cbck in callbacks:
            callback_names.append(cbck.__class__.__name__)
        self.info(
            f"Training initiated with callbacks: {', '.join(callback_names)}"
        )

    def _contains_callback(
        self, callbacks: List[Callback], callback: Type[Callback]
    ) -> bool:
        """Check if an instance of `callback` is in `callbacks`."""
        for cbck in callbacks:
            if isinstance(cbck, callback):
                return True
        return False

    @property
    def target_labels(self) -> List[str]:
        """Return target label."""
        return [label for task in self._tasks for label in task._target_labels]

    @property
    def prediction_labels(self) -> List[str]:
        """Return prediction labels."""
        return [
            label for task in self._tasks for label in task._prediction_labels
        ]

    def configure_optimizers(self) -> Dict[str, Any]:
        """Configure the model's optimizer(s)."""
        optimizer = self._optimizer_class(
            self.parameters(), **self._optimizer_kwargs
        )
        config = {
            "optimizer": optimizer,
        }
        if self._scheduler_class is not None:
            scheduler = self._scheduler_class(
                optimizer, **self._scheduler_kwargs
            )
            config.update(
                {
                    "lr_scheduler": {
                        "scheduler": scheduler,
                        **self._scheduler_config,
                    },
                }
            )
        return config
    def training_step(
        self, train_batch: Union[Data, List[Data]], batch_idx: int
    ) -> Tensor:
        """Perform training step."""
        if isinstance(train_batch, Data):
            train_batch = [train_batch]
        loss = self.shared_step(train_batch, batch_idx)
        self.log(
            "train_loss",
            loss,
            batch_size=self._get_batch_size(train_batch),
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        return loss

    def validation_step(
        self, val_batch: Union[Data, List[Data]], batch_idx: int
    ) -> Tensor:
        """Perform validation step."""
        if isinstance(val_batch, Data):
            val_batch = [val_batch]
        loss = self.shared_step(val_batch, batch_idx)
        self.log(
            "val_loss",
            loss,
            batch_size=self._get_batch_size(val_batch),
            prog_bar=True,
            on_epoch=True,
            on_step=False,
            sync_dist=True,
        )
        return loss

    def inference(self) -> None:
        """Activate inference mode."""
        for task in self._tasks:
            task.inference()

    def train(self, mode: bool = True) -> "Model":
        """Deactivate inference mode."""
        super().train(mode)
        if mode:
            for task in self._tasks:
                task.train_eval()
        return self
    def predict(
        self,
        dataloader: DataLoader,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
    ) -> List[Tensor]:
        """Return predictions for `dataloader`."""
        self.inference()
        self.train(mode=False)

        callbacks = self._create_default_callbacks(
            val_dataloader=None,
        )

        inference_trainer = self._construct_trainer(
            gpus=gpus,
            distribution_strategy=distribution_strategy,
            callbacks=callbacks,
        )

        predictions_list = inference_trainer.predict(self, dataloader)
        assert len(predictions_list), "Got no predictions"

        # Concatenate the per-batch outputs of each task
        nb_outputs = len(predictions_list[0])
        predictions: List[Tensor] = [
            torch.cat([preds[ix] for preds in predictions_list], dim=0)
            for ix in range(nb_outputs)
        ]
        return predictions
    def predict_as_dataframe(
        self,
        dataloader: DataLoader,
        prediction_columns: Optional[List[str]] = None,
        *,
        additional_attributes: Optional[List[str]] = None,
        gpus: Optional[Union[List[int], int]] = None,
        distribution_strategy: Optional[str] = "auto",
    ) -> pd.DataFrame:
        """Return predictions for `dataloader` as a DataFrame.

        Include `additional_attributes` as additional columns in the output
        DataFrame.
        """
        if prediction_columns is None:
            prediction_columns = self.prediction_labels

        if additional_attributes is None:
            additional_attributes = []
        assert isinstance(additional_attributes, list)

        if (
            not isinstance(dataloader.sampler, SequentialSampler)
            and additional_attributes
        ):
            print(dataloader.sampler)
            raise UserWarning(
                "DataLoader has a `sampler` that is not `SequentialSampler`, "
                "indicating that shuffling is enabled. Using "
                "`predict_as_dataframe` with `additional_attributes` assumes "
                "that the sequence of batches in `dataloader` is "
                "deterministic. Either call this method with a `dataloader` "
                "which doesn't resample batches; or do not request "
                "`additional_attributes`."
            )
        self.info(f"Column names for predictions are: \n {prediction_columns}")
        predictions_torch = self.predict(
            dataloader=dataloader,
            gpus=gpus,
            distribution_strategy=distribution_strategy,
        )
        predictions = (
            torch.cat(predictions_torch, dim=1).detach().cpu().numpy()
        )
        assert len(prediction_columns) == predictions.shape[1], (
            f"Number of provided column names ({len(prediction_columns)}) and "
            f"number of output columns ({predictions.shape[1]}) don't match."
        )

        # Get additional attributes
        attributes: Dict[str, List[np.ndarray]] = OrderedDict(
            [(attr, []) for attr in additional_attributes]
        )
        for batch in dataloader:
            for attr in attributes:
                attribute = batch[attr]
                if isinstance(attribute, torch.Tensor):
                    attribute = attribute.detach().cpu().numpy()

                # Check if node level predictions
                # If true, additional attributes are repeated
                # to make dimensions fit
                if len(predictions) != len(dataloader.dataset):
                    if len(attribute) < np.sum(
                        batch.n_pulses.detach().cpu().numpy()
                    ):
                        attribute = np.repeat(
                            attribute, batch.n_pulses.detach().cpu().numpy()
                        )
                        try:
                            assert len(attribute) == len(batch.x)
                        except AssertionError:
                            self.warning_once(
                                "Could not automatically adjust length "
                                f"of additional attribute {attr} to match length of "
                                f"predictions. Make sure {attr} is a graph-level or "
                                "node-level attribute. Attribute skipped."
                            )
                attributes[attr].extend(attribute)

        data = np.concatenate(
            [predictions]
            + [
                np.asarray(values)[:, np.newaxis]
                for values in attributes.values()
            ],
            axis=1,
        )

        results = pd.DataFrame(
            data, columns=prediction_columns + additional_attributes
        )
        return results
    def _create_default_callbacks(
        self,
        val_dataloader: DataLoader,
        early_stopping_patience: Optional[int] = None,
    ) -> List:
        """Create default callbacks.

        Used in cases where no callbacks are specified by the user in .fit
        """
        callbacks = [ProgressBar()]
        if val_dataloader is not None:
            assert early_stopping_patience is not None
            # Add Early Stopping
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=early_stopping_patience,
                )
            )
            # Add Model Check Point
            # `EasyModel` has no `backbone` attribute, so the checkpoint is
            # named after the model class instead.
            callbacks.append(
                ModelCheckpoint(
                    save_top_k=1,
                    monitor="val_loss",
                    mode="min",
                    filename=f"{self.__class__.__name__}"
                    + "-{epoch}-{val_loss:.2f}-{train_loss:.2f}",
                )
            )
            self.info(
                f"EarlyStopping has been added with a patience of {early_stopping_patience}."
            )
        return callbacks

    def _add_early_stopping(
        self, val_dataloader: DataLoader, callbacks: List
    ) -> List:
        if val_dataloader is None:
            return callbacks
        has_early_stopping = False
        assert isinstance(callbacks, list)
        for callback in callbacks:
            if isinstance(callback, EarlyStopping):
                has_early_stopping = True

        if not has_early_stopping:
            callbacks.append(
                EarlyStopping(
                    monitor="val_loss",
                    patience=5,
                )
            )
            self.warning_once(
                "Got validation dataloader but no EarlyStopping callback. An "
                "EarlyStopping callback has been added automatically with "
                "patience=5 and monitor = 'val_loss'."
            )
        return callbacks
We should probably consider whether @abstractmethod is suitable for these functions. The subclasses below implement them with different input arguments, and the decorator may not be flexible enough for us. Other libraries simply assume that the user knows the function must be implemented, without declaring that explicitly (e.g. torch.nn.Module.forward). But that's a detail.
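For reference when weighing this, here is a small standalone sketch of what Python's standard `abc` machinery enforces. It only takes effect on classes that derive from `ABC` (or use `ABCMeta`); whether graphnet's `Model` does so is not shown in this thread. In that setting it blocks instantiation until the method is overridden, but it does not compare argument lists. The class names below are hypothetical.

```python
from abc import ABC, abstractmethod


class Base(ABC):
    @abstractmethod
    def compute_loss(self, preds, data):
        """Subclasses must override this."""


class Incomplete(Base):
    # Does not override `compute_loss`, so it cannot be instantiated.
    pass


class DifferentSignature(Base):
    # Overrides with other argument names; `abc` does not object.
    def compute_loss(self, preds, truth, verbose=False):
        return sum((p - t) ** 2 for p, t in zip(preds, truth))


def can_instantiate(cls) -> bool:
    """Return True if `cls()` succeeds, False if it raises TypeError."""
    try:
        cls()
    except TypeError:
        return False
    return True
```

So the decorator would catch a forgotten override early, while still leaving each subclass free to choose its own arguments; static type checkers may flag the mismatch, but the runtime does not.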
This freedom allows us to implement machine learning paradigms that have a different number of model components (e.g. an encoder/decoder pair) and to somewhat freely define forward passes, loss computations etc.
Our usual StandardModel would look like so under this refactor:
class StandardModel(EasyModel):
    def __init__(
        self,
        *,
        backbone: Optional[GNN] = None,
        gnn: Optional[GNN] = None,
        **easy_model_kwargs: Any
    ) -> None:
        """Construct `StandardModel`."""
        # Base class constructor
        super().__init__(**easy_model_kwargs)

        # Deprecation warnings
        if backbone is None and gnn is not None:
            backbone = gnn
            # Code continues after warning
            self.warning(
                "DeprecationWarning: Argument `gnn` will be deprecated in "
                "GraphNeT 2.0. Please use `backbone` instead."
            )
        elif backbone is None and gnn is None:
            # Code stops
            raise TypeError(
                "__init__() missing 1 required keyword-only argument: 'backbone'"
            )
        assert isinstance(backbone, GNN)

        # Member variable(s)
        self.backbone = backbone

    def compute_loss(
        self, preds: Tensor, data: List[Data], verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""
        data_merged = {}
        target_labels_merged = list(set(self.target_labels))
        for label in target_labels_merged:
            data_merged[label] = torch.cat([d[label] for d in data], dim=0)
        for task in self._tasks:
            if task._loss_weight is not None:
                data_merged[task._loss_weight] = torch.cat(
                    [d[task._loss_weight] for d in data], dim=0
                )

        losses = [
            task.compute_loss(pred, data_merged)
            for task, pred in zip(self._tasks, preds)
        ]
        if verbose:
            self.info(f"{losses}")
        assert all(
            loss.dim() == 0 for loss in losses
        ), "Please reduce loss for each task separately"
        return torch.sum(torch.stack(losses))

    def forward(
        self, data: Union[Data, List[Data]]
    ) -> List[Union[Tensor, Data]]:
        """Forward pass, chaining model components."""
        if isinstance(data, Data):
            data = [data]
        x_list = []
        for d in data:
            x = self.backbone(d)
            x_list.append(x)
        x = torch.cat(x_list, dim=0)

        preds = [task(x) for task in self._tasks]
        return preds

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        preds = self(batch)
        loss = self.compute_loss(preds, batch)
        return loss

    def validate_tasks(self) -> None:
        """Check that all tasks are `StandardLearnedTask` instances."""
        accepted_tasks = (StandardLearnedTask,)
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)
A masked autoencoder could look like so:
from typing import Callable, Tuple


class StandardMaskedAutoEncoder(EasyModel):
    def __init__(
        self,
        encoder: GNN = None,
        decoder: GNN = None,
        mask_func: Callable = None,
        **easy_model_kwargs: Any
    ) -> None:
        """Construct `StandardMaskedAutoEncoder`."""
        # Base class constructor
        super().__init__(**easy_model_kwargs)
        self.encoder = encoder
        self.decoder = decoder
        self._mask_func = mask_func

    def forward(
        self, data: Union[Data, List[Data]]
    ) -> Tuple[List[Tensor], Tensor]:
        """Forward pass, chaining model components."""
        if isinstance(data, Data):
            data = [data]
        y_list = []
        x_list = []
        encoded_x_list = []
        mask_list = []
        for d in data:
            # Mask input
            x_masked, mask = self._mask_func(d.x)
            # Encode masked input
            x_masked_encoded = self.encoder(x_masked)
            # Decode encoded input
            y = self.decoder(x_masked_encoded)
            # Store raw input, mask, encoding and decoded input
            y_list.append(y)
            x_list.append(d.x)
            mask_list.append(mask)
            encoded_x_list.append(x_masked_encoded)
        y = torch.cat(y_list, dim=0)
        x = torch.cat(x_list, dim=0)
        mask = torch.cat(mask_list, dim=0)
        encoded_x = torch.cat(encoded_x_list, dim=0)

        preds = [
            task(x=x, y=y, encoding=encoded_x, mask=mask)
            for task in self._tasks
        ]
        truth = x[mask, :]
        assert all(pred.shape[0] == truth.shape[0] for pred in preds)
        # Return masked predictions and ground truth
        return preds, truth

    def shared_step(self, batch: List[Data], batch_idx: int) -> Tensor:
        """Perform shared step.

        Applies the forward pass and the following loss calculation, shared
        between the training and validation step.
        """
        preds, truth = self(batch)
        loss = self.compute_loss(preds=preds, truth=truth)
        return loss

    def compute_loss(
        self, preds: List[Tensor], truth: Tensor, verbose: bool = False
    ) -> Tensor:
        """Compute and sum losses across tasks."""
        losses = [
            task.compute_loss(pred, truth)
            for task, pred in zip(self._tasks, preds)
        ]
        if verbose:
            self.info(f"{losses}")
        assert all(
            loss.dim() == 0 for loss in losses
        ), "Please reduce loss for each task separately"
        return torch.sum(torch.stack(losses))

    def validate_tasks(self) -> None:
        """Check that all tasks are `AutoEncodingTask` instances."""
        accepted_tasks = (AutoEncodingTask,)
        for task in self._tasks:
            assert isinstance(task, accepted_tasks)

    def save_state_dict(self, path: str) -> None:
        """Save state dict for entire model and decoder/encoder pair separately."""
        super().save_state_dict(path)
        self.encoder.save_state_dict(path.replace('.pth', '_encoder.pth'))
        self.decoder.save_state_dict(path.replace('.pth', '_decoder.pth'))

    def save_config(self, path: str) -> None:
        """Save config for entire model and decoder/encoder pair separately."""
        super().save_config(path)
        self.encoder.save_config(path.replace('.yml', '_encoder.yml'))
        self.decoder.save_config(path.replace('.yml', '_decoder.yml'))
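The `mask_func` argument in the class above is left unspecified. As an illustration of the contract it would need to satisfy (return the masked input together with a boolean mask over the original rows), here is a framework-free sketch operating on lists instead of tensors. `random_row_mask`, `mask_ratio` and `seed` are hypothetical, and whether masked rows are dropped or replaced by a placeholder is a design choice; this version drops them.

```python
import random
from typing import List, Optional, Sequence, Tuple

Row = List[float]


def random_row_mask(
    x: Sequence[Row],
    mask_ratio: float = 0.5,
    seed: Optional[int] = None,
) -> Tuple[List[Row], List[bool]]:
    """Hide a random subset of rows.

    Returns the rows that remain visible and a boolean mask over the
    original rows (True means the row was masked out).
    """
    if not 0.0 <= mask_ratio < 1.0:
        raise ValueError("mask_ratio must be in [0, 1)")
    n_masked = int(len(x) * mask_ratio)
    rng = random.Random(seed)
    masked_idx = set(rng.sample(range(len(x)), n_masked))
    mask = [i in masked_idx for i in range(len(x))]
    visible = [list(row) for row, m in zip(x, mask) if not m]
    return visible, mask
```

Any callable with the same return convention could be swapped in, which is the modularity asked for in point 3.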
Because the model inherits from EasyModel, 1. is met.
The masking is done in the forward pass; notice that this masking method is an input and that it is assumed to work on batched data (meets 2. and 3.). The class attempts to save the encoder/decoder pair (along with the entire model) when the .save methods are called (meets 4.). Notice also that the original input, encoded input, decoded input and the mask are all passed to the Task. This means that we need a special AutoEncodingTask that becomes a sandbox for tying these inputs into the final prediction. With all of that information available, I think most autoencoding tasks should be supported. Here's some pseudo-code for how that AutoEncodingTask could look:
from typing import final


class AutoEncodingTask(Task):
    @final
    def forward(
        self, x: Tensor, y: Tensor, encoding: Tensor, mask: Tensor
    ) -> Union[Tensor, Data]:
        """AutoEncoding Task. Ties together input, decoded input, encoded input and mask.

        Args:
            x: Original input tensor, before masking.
            y: Output of the decoder.
            encoding: Output of the encoder.
            mask: Mask applied to the input before it is passed to the encoder.

        Returns:
            ...
        """
        self._regularisation_loss = 0  # Reset
        x = self._forward(x=x, y=y, encoding=encoding, mask=mask)
        return self._transform_prediction(x)

    @abstractmethod
    def _forward(
        self, x: Tensor, y: Tensor, encoding: Tensor, mask: Tensor
    ) -> Tensor:
        """Tie all the data together for a final prediction."""
Finally, to address 5., we can introduce a method Model.partially_load_state_dict that will load the entries in state_dict that exist in Model. That would allow us to use pretrained weights in fresh models that might contain new, additional layers.
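A partial load along these lines could be sketched in plain PyTorch as below. This is only an illustration under assumptions: the free function `partially_load_state_dict` and the toy `nn.Sequential` models are hypothetical stand-ins, not graphnet code, and the shape check is an extra safeguard that the proposal above does not mention. The only library behaviour relied on is `load_state_dict(..., strict=False)`.

```python
from typing import Dict, List

import torch
from torch import nn


def partially_load_state_dict(
    model: nn.Module, state_dict: Dict[str, torch.Tensor]
) -> List[str]:
    """Load only the entries of `state_dict` that fit `model` (hypothetical helper).

    Returns the names of the model entries that were left untouched.
    """
    own_state = model.state_dict()
    # Keep entries whose name exists in the model and whose shape matches.
    compatible = {
        name: tensor
        for name, tensor in state_dict.items()
        if name in own_state and own_state[name].shape == tensor.shape
    }
    # strict=False makes PyTorch accept a state dict that lacks some keys.
    model.load_state_dict(compatible, strict=False)
    return [name for name in own_state if name not in compatible]


# Usage: a fresh model with one additional layer re-uses pretrained weights.
torch.manual_seed(0)
pretrained = nn.Sequential(nn.Linear(4, 8), nn.ReLU())
fresh = nn.Sequential(nn.Linear(4, 8), nn.ReLU(), nn.Linear(8, 2))
skipped = partially_load_state_dict(fresh, pretrained.state_dict())
```

Here `skipped` lists the parameters of the new layer, which keep their fresh initialisation, while the shared first layer takes the pretrained weights.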
While there are still a couple of minor details to iron out, I do think this approach would suit your use-case. @AMHermansen could you perhaps take a look and let me know what you think?
Sorry for the wait.
Alright @AMHermansen - sorry for the delay on this. I now have some pseudo-code ready.
I arrived at this pseudo-code by starting out with a series of considerations. First, I think we should insist that the user experience is intuitive and modular, such that autoencoding functionality in graphnet plays well with existing modules, whether those are loss functions, model architectures, graph definitions, etc. Secondly, I think it's important that when we implement new machine learning paradigms, we do it in a way that is true to how they are usually perceived, hopefully increasing the readability for new graphnet users. Thirdly, I think the implementations should aim to be flexible enough for users to apply autoencoding in graphnet to most use cases, and the few cases they do not cover should be intuitive to implement.
These three considerations lead me to answers to a couple of the points you raised above:
1. The syntax should be similar to what I showed earlier.
2. Masking should happen in a forward pass in the autoencoder because that is where people expect it to be.
3. Masking should be modular; i.e. the implementation should allow for an argument for some kind of callable that delivers the masking.
4. The implementation should allow for saving `state_dicts` and `model_configs` for the encoder/decoder pair independently.
5. We should introduce a `Model` method that allows for a partial `state_dict` load.
6. It is OK that autoencoders only work on autoencoding-tasks.
I think this sounds reasonable, though I don't see why you allow the MAE to work with unique AE-tasks, but not have it rely on a unique GraphDefinition.
Just to iron it out, so we are on the same page: my current implementation splits the masking procedure into two steps. The first step is to select which nodes should be masked. This is done in preprocessing in a unique NodeDefinition, and then in the forward pass of the model the nodes that were previously selected are removed. This makes the logic much simpler, since you don't have to deal with the torch_geometric sparse_tensor while performing logic on an event-level basis. The only real alternative I see to this is writing a for loop over all the events in the batch.
The second step is when the nodes are removed, and later MaskTokens are reintroduced. Both of these are done in the forward pass of the MAE.
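The split described here could look roughly like the following sketch, using plain tensors in place of graphnet and torch_geometric objects. The function names, the `mask_ratio` argument and the toy event sizes are assumptions made for illustration only; the actual NodeDefinition-based implementation is not shown in this thread.

```python
from typing import Tuple

import torch


def select_nodes_to_mask(
    n_nodes: int, mask_ratio: float, generator: torch.Generator
) -> torch.Tensor:
    """Preprocessing, one event at a time: flag the nodes to be masked."""
    n_masked = int(n_nodes * mask_ratio)
    flags = torch.zeros(n_nodes, dtype=torch.bool)
    flags[torch.randperm(n_nodes, generator=generator)[:n_masked]] = True
    return flags


def remove_masked_nodes(
    x: torch.Tensor, batch: torch.Tensor, flags: torch.Tensor
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Forward pass, whole batch: drop flagged nodes with one boolean index."""
    keep = ~flags
    return x[keep], batch[keep]


# Usage: two events with 4 and 6 pulses, half of each masked.
gen = torch.Generator().manual_seed(0)
event_sizes = [4, 6]
# The flags are computed per event, as they would be in preprocessing.
flags = torch.cat([select_nodes_to_mask(n, 0.5, gen) for n in event_sizes])
x = torch.randn(sum(event_sizes), 3, generator=gen)
batch = torch.repeat_interleave(
    torch.arange(len(event_sizes)), torch.tensor(event_sizes)
)
x_visible, batch_visible = remove_masked_nodes(x, batch, flags)
```

Because the per-event decision is already stored as a boolean flag per node, the forward pass needs no event-level logic on the batched tensors.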
I think "the best approach" is the approach that achieves these goals with the smallest implementation overhead and that does not repeat code. After some experimentation, I found that a nice way of doing this would be to first introduce a small refactor of StandardModel that moves boiler-plate code into a subclass currently called EasyModel. This class makes no assumption on the number of model architectures or what they are, and it has four abstract methods: compute_loss, forward, shared_step, validate_tasks. It delivers the easy syntax our users are relying on; model.fit, model.predict_as_dataframe, etc. Our current StandardModel becomes a subclass of EasyModel with a specific implementation of the abstract methods that is aimed at single-architecture training. Pseudo-code below:
I'm not personally using the Model.fit nor the Model.predict(_as_dataframe) methods, so I'm not sure how my model would work with those, and I don't personally plan on figuring out how to implement such logic.
The refactor is probably well motivated, but maybe it is better done in a separate Issue then, which this implementation would rely on?
The implementation you suggest above currently doesn't re-introduce mask-tokens after the masked-input has been encoded. This makes it difficult to make node-level predictions because the model doesn't necessarily know how many nodes were removed, and thus it doesn't know how many predictions to make.
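To make the point concrete, re-introducing mask tokens between encoder and decoder could be sketched as below. The class name `MaskTokenInserter` and its arguments are hypothetical; the sketch assumes a boolean mask over all original nodes (True meaning the node was removed before encoding) and a single learned token in the encoder's latent space.

```python
import torch
from torch import nn


class MaskTokenInserter(nn.Module):
    """Hypothetical helper: rebuild a full-length node tensor for the decoder."""

    def __init__(self, latent_dim: int) -> None:
        super().__init__()
        # One learned vector shared by all masked positions.
        self.mask_token = nn.Parameter(torch.zeros(latent_dim))

    def forward(
        self, encoded_visible: torch.Tensor, mask: torch.Tensor
    ) -> torch.Tensor:
        """`mask` is boolean over all original nodes; True means "was masked"."""
        # Start with the mask token in every row, then overwrite visible rows.
        full = self.mask_token.repeat(mask.shape[0], 1)
        full[~mask] = encoded_visible
        return full


# Usage: 5 original nodes, 2 of which were masked before encoding.
inserter = MaskTokenInserter(latent_dim=8)
mask = torch.tensor([True, False, False, True, False])
encoded_visible = torch.randn(3, 8)
decoder_input = inserter(encoded_visible, mask)
```

The decoder then sees one row per original node, so it knows how many node-level predictions to make.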
I like how state_dicts / configs are saved separately, but maybe it should also have either a custom way of loading multiple inputs, or save its own config/state_dict, in case you want to use the entire MAE for other purposes.
It might also make sense to look into how ModelCheckpoints work in Lightning to see if it is possible to have an elegant way of using those for down-stream purposes.
Wouldn't you also want to pass data to the AutoEncoding task, to make sure it has access to the original data?
Suggested steps: