graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Jammy flow integration #728

Open RasmusOrsoe opened 1 month ago

RasmusOrsoe commented 1 month ago

This PR adds support for normalizing flows via the jammy_flows package, and therefore supersedes #649. The benefit of using jammy_flows is that it provides many different normalizing flows, and that we avoid maintaining that code ourselves :-).

The package is not listed as a direct dependency but is used as an optional support package.
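Because jammy_flows is optional, code paths that need it have to check for its presence first (this PR adds `has_jammy_flows_package()` in `graphnet.utils.imports` for exactly that). A minimal sketch of such a guard, using `importlib.util.find_spec`; the real helper's implementation may differ:

```python
from importlib.util import find_spec


def has_jammy_flows_package() -> bool:
    """Return True if `jammy_flows` is importable.

    Illustrative sketch only; the actual helper lives in
    `graphnet.utils.imports` and may be implemented differently.
    """
    return find_spec("jammy_flows") is not None


if has_jammy_flows_package():
    pass  # safe to import jammy_flows and build flow-based tasks
else:
    pass  # skip flow features or raise an informative error
```

A guard like this lets the rest of the library import cleanly for users who chose not to install the extra.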

Specifically, this PR introduces the following major changes:

  1. StandardFlowTask now uses jammy_flows to construct PDFs of any kind that the package supports. These PDFs can be either conditional or unconditional. Conditional flows can be conditioned on latent model output, event-level information, or pulse-level information.
  2. A new model class, NormalizingFlow, is added, which works with StandardFlowTask. Usage is similar to StandardModel.
  3. An example of training a conditional flow is added under examples/04_training/07_train_normalizing_flow.py.
  4. has_jammy_flows_package() is added under graphnet.utils.imports to check whether the package is installed, and is used in a few places to make sure that the code also runs for users who choose not to install jammy_flows.
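The idea behind a conditional flow task can be illustrated without graphnet or jammy_flows: a flow defines a PDF by transforming a simple base density, and "conditional" means the transform's parameters depend on a context variable (here, a toy stand-in for latent model output or event-level information). A minimal 1-D affine-flow sketch in plain Python, with all names and the toy conditioner purely illustrative:

```python
import math


def base_logpdf(z: float) -> float:
    """Log-density of a standard normal base distribution."""
    return -0.5 * z * z - 0.5 * math.log(2.0 * math.pi)


def conditional_affine_logpdf(x: float, context: float) -> float:
    """Log p(x | context) for an affine flow x = mu(c) + sigma(c) * z.

    The 'conditioner' mapping context -> (mu, log_sigma) is a fixed toy
    function here; in a real flow it would be a learned neural network.
    """
    mu = 0.5 * context           # toy conditioner output
    log_sigma = 0.1 * context    # toy conditioner output
    sigma = math.exp(log_sigma)
    z = (x - mu) / sigma         # invert the affine transform
    # Change of variables: log p(x) = log p_base(z) - log|dx/dz|
    return base_logpdf(z) - log_sigma


# The resulting density integrates to 1 for any context value;
# check crudely with a Riemann sum over a wide grid:
ctx = 1.0
dx = 0.01
total = sum(
    math.exp(conditional_affine_logpdf(-10.0 + i * dx, ctx)) * dx
    for i in range(2000)
)
# total is approximately 1.0
```

jammy_flows plays the role of the transform and conditioner here, with many flow families beyond this single affine layer; the same change-of-variables logic underlies all of them.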

Minor changes:

  1. repeat_labels is added as an argument to GraphDefinition. If True, event-level information (e.g. energy) is repeated row-wise to match the number of pulses in the event. This feature was added in this PR to make it possible to build flows that learn pulse-level PDFs conditioned on event-level information.
  2. The installation matrix is updated with a note on installing jammy_flows.
  3. The GitHub workflows are adjusted to run with jammy_flows installed.
  4. **kwargs for Trainer is added to the predict methods in EasySyntax, allowing the same level of control over Trainer arguments as .fit has.
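The row-wise repetition behind repeat_labels (minor change 1) can be sketched in plain Python. This is an illustration of the behavior only, not graphnet's implementation; the function name and inputs are hypothetical:

```python
def repeat_event_labels(labels, n_pulses_per_event):
    """Repeat each event-level label once per pulse in that event.

    labels:              one value per event (e.g. the event energy)
    n_pulses_per_event:  number of pulses recorded for each event
    Returns a flat, pulse-level list aligned with the pulse rows, so
    that a pulse-level flow can be conditioned on event-level labels.
    """
    out = []
    for label, n in zip(labels, n_pulses_per_event):
        out.extend([label] * n)
    return out


# Two events: 3 pulses with energy 10.0, then 2 pulses with energy 42.0
pulse_level = repeat_event_labels([10.0, 42.0], [3, 2])
# → [10.0, 10.0, 10.0, 42.0, 42.0]
```

After this repetition, every pulse row carries its event's label, which is what allows a pulse-level PDF to be conditioned on event-level information.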