graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

FourierEncoder not flexible enough for different detectors #753

Open mobra7 opened 1 month ago

mobra7 commented 1 month ago

Describe the bug

The FourierEncoder module assumes that the input data has the feature layout (x, y, z, time, charge, auxiliary). However, the IceCube86 detector, for example, provides data in the layout (x, y, z, time, charge, ..., auxiliary). Because the module hard-codes the column indices it slices (e.g. x[:, :, 5] for the auxiliary feature), this raises an IndexError.
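The failure mode can be reproduced without graphnet or torch. The toy sketch below is not the library code. A plain list stands in for the nn.Embedding table, and the pulse values, including the extra columns, are invented for illustration. It shows that a fixed column index reads a different feature once the layout changes, and that an out-of-range value then fails the table lookup. F.embedding reports this as "index out of range in self".

```python
from typing import List, Sequence

# Toy stand-in for the auxiliary embedding table: one row per allowed aux
# value (0 or 1). In the real module this is a torch.nn.Embedding; a plain
# list is used here so the sketch has no dependencies.
AUX_TABLE: List[List[float]] = [[0.0, 1.0], [1.0, 0.0]]

AUX_COLUMN = 5  # hard-coded, mirroring x[:, :, 5] in the traceback


def embed_aux(pulse: Sequence[float]) -> List[float]:
    """Look up the aux embedding using a fixed column index."""
    return AUX_TABLE[int(pulse[AUX_COLUMN])]


# Six-column layout the encoder assumes: (x, y, z, time, charge, auxiliary).
assumed_pulse = [10.0, -5.0, 100.0, 9800.0, 1.2, 1.0]

# Longer layout with two extra columns before auxiliary (made-up values).
longer_pulse = [10.0, -5.0, 100.0, 9800.0, 1.2, 3.0, 0.5, 1.0]

print(embed_aux(assumed_pulse))  # reads the real aux flag -> [1.0, 0.0]
try:
    embed_aux(longer_pulse)  # column 5 is now an extra feature with value 3
except IndexError as err:
    print("IndexError:", err)
```

If the misread value happens to fall inside the table's range, the lookup succeeds silently with the wrong feature, so the crash is the better of the two outcomes.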

To Reproduce

Run the examples/04_training/06_train_icemix_model.py script on any data whose feature layout differs from the assumed one (e.g. IceCube86).

Expected behavior

The module should work for any feature layout, not only (x, y, z, time, charge, auxiliary).

Full traceback

Traceback (most recent call last):
  File "/ptmp/mpp/mbranden/graphnet/playground/icemix_pretrain.py", line 225, in <module>
    main(
  File "/ptmp/mpp/mbranden/graphnet/playground/icemix_pretrain.py", line 176, in main
    model.fit(
  File "/ptmp/mpp/mbranden/graphnet/src/graphnet/models/easy_model.py", line 169, in fit
    trainer.fit(
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 538, in fit
    call._call_and_handle_interrupt(
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 46, in _call_and_handle_interrupt
    return trainer.strategy.launcher.launch(trainer_fn, *args, trainer=trainer, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/strategies/launchers/subprocess_script.py", line 105, in launch
    return function(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 574, in _fit_impl
    self._run(model, ckpt_path=ckpt_path)
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 981, in _run
    results = self._run_stage()
              ^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1023, in _run_stage
    self._run_sanity_check()
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/trainer.py", line 1052, in _run_sanity_check
    val_loop.run()
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/loops/utilities.py", line 178, in _decorator
    return loop_run(self, *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 135, in run
    self._evaluation_step(batch, batch_idx, dataloader_idx, dataloader_iter)
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/loops/evaluation_loop.py", line 396, in _evaluation_step
    output = call._call_strategy_hook(trainer, hook_name, *step_args)
             ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/trainer/call.py", line 319, in _call_strategy_hook
    output = fn(*args, **kwargs)
             ^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 410, in validation_step
    return self._forward_redirection(self.model, self.lightning_module, "validation_step", *args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 640, in __call__
    wrapper_output = wrapper_module(*args, **kwargs)
                     ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1523, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/parallel/distributed.py", line 1359, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/pytorch_lightning/strategies/strategy.py", line 633, in wrapped_forward
    out = method(*_args, **_kwargs)
          ^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/graphnet/src/graphnet/models/easy_model.py", line 264, in validation_step
    loss = self.shared_step(val_batch, batch_idx)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/graphnet/src/graphnet/models/standard_model.py", line 117, in shared_step
    preds = self(batch)
            ^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/graphnet/src/graphnet/models/standard_model.py", line 104, in forward
    x = self.backbone(d)
        ^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/graphnet/src/graphnet/models/gnn/icemix.py", line 132, in forward
    x = self.fourier_ext(x0, seq_length)
        ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/graphnet/src/graphnet/models/components/embedding.py", line 129, in forward
    embeddings.append(self.aux_emb(x[:, :, 5].long()))  # Auxiliary
                      ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/modules/sparse.py", line 163, in forward
    return F.embedding(
           ^^^^^^^^^^^^
  File "/ptmp/mpp/mbranden/.conda/envs/graphnet2/lib/python3.11/site-packages/torch/nn/functional.py", line 2237, in embedding
    return torch.embedding(weight, input, padding_idx, scale_grad_by_freq, sparse)
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
IndexError: index out of range in self

pweigel commented 1 month ago

Are you trying to add more features to the encoder, or do you just want to encode the existing variables (position, time, charge, aux)? It might make sense to store a mapping from each variable to its index in the feature dimension. This mapping could be part of IceMixNodes and be passed to the DeepIce model and to any downstream modules that need it.
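The index-map idea could look roughly like the sketch below. It is an assumption, not graphnet's API: build_feature_map, REQUIRED_ROLES and all column names are hypothetical, and in practice the map would presumably be produced by IceMixNodes and handed to DeepIce and FourierEncoder as suggested above. It resolves each role the encoder needs to a column index once, and fails with a readable error when a role is missing, rather than with an IndexError deep inside the forward pass.

```python
from typing import Dict, Sequence

# Roles the encoder needs; these role names are hypothetical.
REQUIRED_ROLES = ("x", "y", "z", "time", "charge", "auxiliary")


def build_feature_map(
    feature_names: Sequence[str], role_to_feature: Dict[str, str]
) -> Dict[str, int]:
    """Return {role: column index} for every role the encoder needs.

    feature_names:   column names in the order the node definition emits them.
    role_to_feature: which column name plays each role for this detector.
    """
    position = {name: i for i, name in enumerate(feature_names)}
    # A role is missing if it is unassigned or points at an absent column.
    missing = [
        role
        for role in REQUIRED_ROLES
        if role_to_feature.get(role) not in position
    ]
    if missing:
        raise KeyError(f"no column found for roles: {missing}")
    return {role: position[role_to_feature[role]] for role in REQUIRED_ROLES}


# Hypothetical column names for two layouts.
short = ["x", "y", "z", "t", "q", "aux"]
long_ = ["x", "y", "z", "t", "q", "extra_a", "extra_b", "aux"]
roles = {"x": "x", "y": "y", "z": "z", "time": "t", "charge": "q",
         "auxiliary": "aux"}

print(build_feature_map(short, roles)["auxiliary"])  # 5
print(build_feature_map(long_, roles)["auxiliary"])  # 7
```

Inside the encoder, a hard-coded slice such as x[:, :, 5] would then become x[:, :, fmap["auxiliary"]], with fmap computed once at construction time so the forward pass does no name lookups.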