graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Make training scripts simpler #262

Closed: RasmusOrsoe closed this issue 1 year ago

RasmusOrsoe commented 2 years ago

In /examples/train_model.py, the MWE (minimal working example) of a training script is 150 lines of code. If we introduce a default configuration for each model (model.config) and a Session class, we could simplify this greatly. Here is a bit of conceptual code:

from graphnet.models.training.utils import TrainingSession, InferenceSession, make_dataloaders, save_results
from graphnet.models.gnn import DynEdge
from graphnet.components.loss_functions import LogCoshLoss
from graphnet.models.task.reconstruction import EnergyReconstruction
from graphnet.models.detector.icecube import IceCubeDeepCore

# Main function definition
def main(model_config=None):

    archive = "/groups/icecube/asogaard/gnn/results/"
    selection = None

    model = DynEdge()
    # Configuration: use the model's defaults unless a full config is passed in
    if model_config is None:
        model.config['target'] = 'energy'
        model.config['loss_function'] = LogCoshLoss
        model.config['task'] = EnergyReconstruction
        model.config['data_path'] = "/groups/icecube/asogaard/data/sqlite/dev_lvl7_robustness_muon_neutrino_0000/data/dev_lvl7_robustness_muon_neutrino_0000.db"
        model.config['pulsemap'] = 'SRTTWOfflinePulsesDC'
        model.config['detector'] = IceCubeDeepCore
        model.config['selection'] = selection
    else:
        model.config = model_config

    # Name the run after the configured target (only known once the config is set)
    run_name = "dynedge_{}_example".format(model.config['target'])

    # Make DataLoaders
    train_dataloader, valid_dataloader, test_dataloader = make_dataloaders(model)

    # Setup Sessions
    training_session = TrainingSession(model, archive, run_name)
    inference_session = InferenceSession(model, archive, run_name)

    # Training session
    trained_model = training_session.start(train_dataloader=train_dataloader, validation_dataloader=valid_dataloader)

    # Inference session (archive and run_name were already passed to the session above)
    if trained_model.converged:
        validation_results = inference_session.start(model=trained_model, dataloader=valid_dataloader, tag='validation')
        test_results = inference_session.start(model=trained_model, dataloader=test_dataloader, tag='test')

        # Save results to CSV (only when there are results to save)
        save_results(validation_results, trained_model, run_name, archive)
        save_results(test_results, trained_model, run_name, archive)

# Main function call
if __name__ == "__main__":
    main()
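The "default configuration plus Session" idea can be sketched without graphnet at all. The snippet below is a minimal, hypothetical illustration, assuming a model exposes a `default_config()` dict that callers override key by key, and a `TrainingSession` that reads every training setting it needs (here just `epochs` and `learning_rate`) from `model.config`. `ToyModel`, `default_config` and this `TrainingSession` are illustrative names, not graphnet APIs, and the "model" is a one-weight linear fit so the example runs with the standard library only.

```python
from typing import Any, Dict, List, Optional, Tuple

Sample = Tuple[float, float]  # (input, target) pair


class ToyModel:
    """Hypothetical stand-in for a model that carries its own default config."""

    @staticmethod
    def default_config() -> Dict[str, Any]:
        return {"epochs": 200, "learning_rate": 0.05}

    def __init__(self, overrides: Optional[Dict[str, Any]] = None) -> None:
        # Start from the defaults and override individual keys only.
        self.config: Dict[str, Any] = {**self.default_config(), **(overrides or {})}
        self.weight = 0.0

    def forward(self, x: float) -> float:
        return self.weight * x


class TrainingSession:
    """Hypothetical session that reads all training settings from model.config."""

    def __init__(self, model: ToyModel) -> None:
        self.model = model

    def start(self, train_data: List[Sample]) -> ToyModel:
        lr = self.model.config["learning_rate"]
        for _ in range(self.model.config["epochs"]):
            # Full-batch gradient descent on the mean squared error.
            grad = sum(
                2 * (self.model.forward(x) - y) * x for x, y in train_data
            ) / len(train_data)
            self.model.weight -= lr * grad
        return self.model


# Usage: override one setting and keep the other default.
model = ToyModel({"epochs": 100})
trained = TrainingSession(model).start([(1.0, 3.0), (2.0, 6.0), (3.0, 9.0)])
print(trained.config)            # {'epochs': 100, 'learning_rate': 0.05}
print(round(trained.weight, 3))  # 3.0
```

Merging overrides into the defaults, rather than replacing `model.config` wholesale as in the `else` branch above, means a caller only has to spell out what differs from the default.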
asogaard commented 2 years ago

Yesss! 👏 I think this is a brilliant initiative that ties in with #226, #175, and more. And thanks for the conceptual code — I think it is the right thing to do to start from "what would we ideally like the code to look like" and then work backwards from there.

RasmusOrsoe commented 2 years ago

Cool! I think it could become quite elegant if we use the model as a "vehicle" for auxiliary data such as batch_size, epochs, etc. This way, we also ensure that this data is actually saved with the model.
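One way to read the "vehicle" idea: keep training settings such as `batch_size` and `epochs` on the model object and serialize them in the same file as the learned parameters, so a saved model always records how it was trained. This is a sketch under stated assumptions: the weights are a plain dict and everything is written as JSON for brevity, whereas a real PyTorch model would more likely store its `state_dict()` via `torch.save`. `TrainingConfig`, `TrainableModel`, `save` and `load` are hypothetical names.

```python
import json
import os
import tempfile
from dataclasses import asdict, dataclass, field
from typing import Dict


@dataclass
class TrainingConfig:
    """Auxiliary settings that travel with the model (hypothetical names)."""
    batch_size: int = 128
    epochs: int = 50
    learning_rate: float = 1e-3


@dataclass
class TrainableModel:
    config: TrainingConfig = field(default_factory=TrainingConfig)
    weights: Dict[str, float] = field(default_factory=dict)

    def save(self, path: str) -> None:
        # Config and weights go into one file, so they cannot drift apart.
        with open(path, "w") as f:
            json.dump({"config": asdict(self.config), "weights": self.weights}, f)

    @classmethod
    def load(cls, path: str) -> "TrainableModel":
        with open(path) as f:
            payload = json.load(f)
        return cls(config=TrainingConfig(**payload["config"]), weights=payload["weights"])


# Usage: the reloaded model still knows how it was trained.
model = TrainableModel(TrainingConfig(batch_size=512, epochs=10), {"w": 0.7})
with tempfile.TemporaryDirectory() as tmp:
    path = os.path.join(tmp, "model.json")
    model.save(path)
    restored = TrainableModel.load(path)
print(restored.config)  # TrainingConfig(batch_size=512, epochs=10, learning_rate=0.001)
```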

asogaard commented 1 year ago

@RasmusOrsoe, do I remember correctly that you had some ideas for this (scikit-learn-ish)?

RasmusOrsoe commented 1 year ago

@asogaard Yes. I haven't had the time to actually begin working on it, but I wanted to do something such that an MWE of training could be boiled down to:

import graphnet
# User-friendly, pre-configured model class that reconstructs/classifies several
# targets ("on the shelf"); the training loop, prediction loop, etc. are all
# pre-defined via new code
from graphnet.shelf import DynEdge
from graphnet.plotting import resolution_plots

dataset = graphnet.dataset.low_energy('path_to_download_dataset')  # à la https://pytorch.org/vision/stable/datasets.html

trained_model = DynEdge.fit(dataset.train)

my_predictions = trained_model.predict(dataset.test)

my_resolution_figure = resolution_plots(my_predictions)
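The "scikit-learn-ish" shape follows the scikit-learn estimator convention, where `fit` returns the fitted estimator itself, so fitting and then calling `predict` on the returned object chains naturally. A minimal stand-in, assuming a toy one-feature least-squares model in place of DynEdge and a list of `(x, y)` pairs in place of a graphnet dataset; `ShelfModel` is a hypothetical name, not part of graphnet:

```python
from typing import List, Optional, Sequence, Tuple


class ShelfModel:
    """Hypothetical pre-configured model with a scikit-learn-style API."""

    def __init__(self) -> None:
        self.slope_: Optional[float] = None
        self.intercept_: Optional[float] = None

    def fit(self, data: Sequence[Tuple[float, float]]) -> "ShelfModel":
        # Closed-form ordinary least squares for y = slope * x + intercept.
        n = len(data)
        mean_x = sum(x for x, _ in data) / n
        mean_y = sum(y for _, y in data) / n
        cov = sum((x - mean_x) * (y - mean_y) for x, y in data)
        var = sum((x - mean_x) ** 2 for x, _ in data)
        self.slope_ = cov / var
        self.intercept_ = mean_y - self.slope_ * mean_x
        return self  # returning self enables trained = ShelfModel().fit(train)

    def predict(self, xs: Sequence[float]) -> List[float]:
        if self.slope_ is None or self.intercept_ is None:
            raise RuntimeError("call fit() before predict()")
        return [self.slope_ * x + self.intercept_ for x in xs]


train = [(0.0, 1.0), (1.0, 3.0), (2.0, 5.0)]
trained_model = ShelfModel().fit(train)
print(trained_model.predict([3.0, 4.0]))  # [7.0, 9.0]
```

Keeping the fitted state on an instance, rather than on the class, is what lets several independently trained models coexist in one session.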
asogaard commented 1 year ago

Alright, I can start having a look!

RasmusOrsoe commented 1 year ago

Please feel free to! I would love to chip in, but I'm afraid it will be some time before I can give this a go.