graphnet-team / graphnet

A deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/

01_train_model.py fails at upgrade data #385

Closed: MoustHolmes closed this issue 1 year ago

MoustHolmes commented 1 year ago

I have tried making train_model.py run on Upgrade data, but it fails at the sanity check, where it seems the graphs are being read incorrectly. My model config is as follows:

arguments:
  coarsening: null
  detector:
    ModelConfig:
      arguments:
        graph_builder:
          ModelConfig:
            arguments: {columns: null, nb_nearest_neighbours: 8}
            class_name: KNNGraphBuilder
        scalers: null
      class_name: IceCubeUpgrade
  gnn:
    ModelConfig:
      arguments:
        add_global_variables_after_pooling: false
        dynedge_layer_sizes: null
        features_subset: null
        global_pooling_schemes: [min, max, mean, sum]
        nb_inputs: 7
        nb_neighbours: 8
        post_processing_layer_sizes: null
        readout_layer_sizes: null
      class_name: DynEdge
  optimizer_class: '!class torch.optim.adam Adam'
  optimizer_kwargs: {eps: 0.001, lr: 1e-05}
  scheduler_class: '!class torch.optim.lr_scheduler ReduceLROnPlateau'
  scheduler_config: {frequency: 1, monitor: val_loss}
  scheduler_kwargs: {patience: 5}
  tasks:
  - ModelConfig:
      arguments:
        hidden_size: 128
        loss_function:
          ModelConfig:
            arguments: {}
            class_name: LogCoshLoss
        loss_weight: null
        target_labels: energy
        transform_inference: null
        transform_prediction_and_target: '!lambda x: torch.log10(x)'
        transform_support: null
        transform_target: null
      class_name: EnergyReconstruction
class_name: StandardModel

My dataset config is as follows:

path: /groups/icecube/petersen/GraphNetDatabaseRepository/nmo_analysis/data/140028_upgrade_NuMu/merged_140021.db
pulsemaps:
  - SplitInIcePulses_GraphSage_Pulses
features:
  - dom_x
  - dom_y
  - dom_z
  - dom_time
  - charge
  - rde
  - pmt_area
truth:
  - energy
  - position_x
  - position_y
  - position_z
  - azimuth
  - zenith
  - pid
  - elasticity
  - sim_type
  - interaction_type
index_column: event_no
truth_table: truth
seed: 21
selection:
  test: event_no % 5 == 0
  validation: event_no % 5 == 1
  train: event_no % 5 > 1
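
For reference, a rough sketch of how these two config files are typically consumed; the file names here are hypothetical, and the load/from_config helpers are assumed to match the graphnet config utilities of this version, so check your installation's docs:

from graphnet.data.dataset import Dataset
from graphnet.models import Model
from graphnet.utilities.config import DatasetConfig, ModelConfig

# Load the YAML configs shown above (hypothetical file names).
model_config = ModelConfig.load("model_config_upgrade.yml")
dataset_config = DatasetConfig.load("dataset_config_upgrade.yml")

# trust=True is needed because the model config contains !lambda / !class tags.
model = Model.from_config(model_config, trust=True)

# With a dict-valued `selection`, this should yield one Dataset per selection key.
datasets = Dataset.from_config(dataset_config)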

The error I get:

Traceback (most recent call last):
  File "train_upgrade_model.py", line 159, in <module>
    main(
  File "train_upgrade_model.py", line 92, in main
    model.fit(
  File "/lustre/hpc/icecube/moust/work/graphnet/src/graphnet/models/model.py", line 80, in fit
    trainer.fit(
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 696, in fit
    self._call_and_handle_interrupt(
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 648, in _call_and_handle_interrupt
    return self.strategy.launcher.launch(trainer_fn, *args, trainer=self, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 93, in launch
    return function(*args, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 735, in _fit_impl
    results = self._run(model, ckpt_path=self.ckpt_path)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1166, in _run
    results = self._run_stage()
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1252, in _run_stage
    return self._run_train()
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1274, in _run_train
    self._run_sanity_check()
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1343, in _run_sanity_check
    val_loop.run()
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 155, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 200, in run
    self.advance(*args, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 143, in advance
    output = self._evaluation_step(**kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 240, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1704, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/strategies/ddp.py", line 358, in validation_step
    return self.model(*args, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 965, in forward
    output = self.module(*inputs, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/pytorch_lightning/overrides/base.py", line 90, in forward
    return self.module.validation_step(*inputs, **kwargs)
  File "/lustre/hpc/icecube/moust/work/graphnet/src/graphnet/models/standard_model.py", line 122, in validation_step
    loss = self.shared_step(val_batch, batch_idx)
  File "/lustre/hpc/icecube/moust/work/graphnet/src/graphnet/models/standard_model.py", line 102, in shared_step
    preds = self(batch)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lustre/hpc/icecube/moust/work/graphnet/src/graphnet/models/standard_model.py", line 91, in forward
    data = self._detector(data)
  File "/groups/icecube/moust/miniconda3/envs/graphnet/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "/lustre/hpc/icecube/moust/work/graphnet/src/graphnet/models/detector/detector.py", line 47, in forward
    assert data.x.size()[1] == self.nb_inputs, (
AssertionError: ('Got graph data with incompatible size, ', 'torch.Size([414, 7]) vs. 14 expected')


asogaard commented 1 year ago

In your DatasetConfig, make sure you're using the IceCube-Upgrade set of features in graphnet.data.constants.
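
For anyone hitting the same assertion: the 7 vs. 14 mismatch comes from the DatasetConfig above listing only the seven IceCube-86 features, while the IceCubeUpgrade detector expects the full Upgrade feature set. A minimal sketch for pulling the matching feature and truth names, assuming the FEATURES and TRUTH constant classes in graphnet.data.constants expose an UPGRADE attribute:

from graphnet.data.constants import FEATURES, TRUTH

# Feature/truth column names matching the IceCubeUpgrade detector.
upgrade_features = FEATURES.UPGRADE  # should cover the 14 inputs the detector asserts on
upgrade_truth = TRUTH.UPGRADE

print(upgrade_features)  # paste these under `features:` in the DatasetConfig
print(upgrade_truth)     # paste these under `truth:` in the DatasetConfig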

MoustHolmes commented 1 year ago

Thank you, that worked! Maybe we should include a model and dataset config for Upgrade in the examples.
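
A rough sketch of what such an Upgrade dataset config could look like when built programmatically, assuming graphnet.utilities.config.DatasetConfig accepts the same keys as the YAML above and supports dump(); the path and selections are copied from this issue:

from graphnet.data.constants import FEATURES, TRUTH
from graphnet.utilities.config import DatasetConfig

dataset_config = DatasetConfig(
    path="/groups/icecube/petersen/GraphNetDatabaseRepository/nmo_analysis/data/140028_upgrade_NuMu/merged_140021.db",
    pulsemaps=["SplitInIcePulses_GraphSage_Pulses"],
    features=FEATURES.UPGRADE,  # full Upgrade feature set instead of the 7 IceCube-86 features
    truth=TRUTH.UPGRADE,        # or the explicit truth list from the config above
    index_column="event_no",
    truth_table="truth",
    seed=21,
    selection={
        "test": "event_no % 5 == 0",
        "validation": "event_no % 5 == 1",
        "train": "event_no % 5 > 1",
    },
)

dataset_config.dump("dataset_config_upgrade.yml")  # assumed dump() helper; adjust to your graphnet version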