graphnet-team / graphnet

A Deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/

`predict_as_dataframe: ValueError` #654

Closed RasmusOrsoe closed 4 months ago

RasmusOrsoe commented 5 months ago

Describe the bug When dataloaders are constructed with a collate function that removes events from batches, as our default collate function does, the automated inference of node-level and graph-level attributes fails.

To Reproduce Steps to reproduce the behavior:

  1. Pick any training example
  2. Overwrite the standard collate function with one that removes at least one event from the dataset (see the sketch after this list).
  3. Run model.predict_as_dataframe(dataloader)
  4. See error
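
For step 2, a minimal sketch of a collate function that drops events, assuming PyTorch Geometric `Data` objects carrying the `n_pulses` attribute that graphnet graphs have; the name `filtering_collate_fn`, the single-pulse cut, and the `dataset`/`batch_size` values are illustrative, not the library's exact default:

from torch.utils.data import DataLoader
from torch_geometric.data import Batch

def filtering_collate_fn(graphs):
    # Drop any event with a single pulse, so a batch can contain
    # fewer events than the underlying dataset holds.
    graphs = [g for g in graphs if g.n_pulses > 1]
    return Batch.from_data_list(graphs)

dataloader = DataLoader(dataset, batch_size=128, collate_fn=filtering_collate_fn)

Once even a single event is filtered out this way, `len(dataloader.dataset)` no longer matches the number of events the model actually sees, which is what trips the check described below.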

Expected behavior The automated inference of whether the quantities in "additional_attributes" are node-level or graph-level variables should work without error.

Full traceback

Traceback (most recent call last):
  File "/lustre/hpc/icecube/stuttard/workspace/fridge/processing/samples/oscNext/selection/level6_GNN/L6_dynedge_predict.py", line 49, in <module>
    model.predict(
  File "/lustre/hpc/icecube/stuttard/workspace/fridge/processing/samples/oscNext/selection/level6_GNN/L6_dynedge_model.py", line 1153, in predict
    results_val = self.model.predict_as_dataframe(val_dataloader, **predict_kw)
  File "/lustre/hpc/icecube/stuttard/workspace/icetray/oscNext/ext/graphnet/src/graphnet/models/standard_model.py", line 458, in predict_as_dataframe
    data = np.concatenate(
  File "<__array_function__ internals>", line 200, in concatenate
ValueError: all the input array dimensions except for the concatenation axis must match exactly, but along dimension 0, the array at index 0 has size 999 and the array at index 1 has size 12459

Additional context

I believe the source of the error can be found in the following lines of code in predict_as_dataframe:

for batch in dataloader:
    for attr in attributes:
        attribute = batch[attr]
        if isinstance(attribute, torch.Tensor):
            attribute = attribute.detach().cpu().numpy()

        # Check if node level predictions
        # If true, additional attributes are repeated
        # to make dimensions fit
        if len(predictions) != len(dataloader.dataset):
            if len(attribute) < np.sum(
                batch.n_pulses.detach().cpu().numpy()
            ):
                attribute = np.repeat(
                    attribute, batch.n_pulses.detach().cpu().numpy()
                )
                try:
                    assert len(attribute) == len(batch.x)
                except AssertionError:
                    self.warning_once(
                        "Could not automatically adjust length"
                        f"of additional attribute {attr} to match length of"
                        f"predictions. Make sure {attr} is a graph-level or"
                        "node-level attribute. Attribute skipped."
                    )
                    pass
        attributes[attr].extend(attribute)

Specifically, the check `if len(predictions) != len(dataloader.dataset)` is troublesome here: if a `collate_fn` has removed one or more events, this check returns `True` even for graph-level predictions. Graph-level attributes are then repeated per pulse while the predictions are not, which matches the 999 vs. 12459 shape mismatch in the traceback above.
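
One possible direction for a fix, as a sketch only (not the library's actual code): decide node- versus graph-level once, by counting the graphs and pulses that actually come out of the dataloader via PyTorch Geometric's `Batch.num_graphs`, instead of consulting `len(dataloader.dataset)`. This assumes the inference dataloader runs with `shuffle=False`, so it can be iterated twice consistently, and that `predictions` and `attributes` are the accumulators already used in `predict_as_dataframe`:

import numpy as np
import torch

# Count what the collate_fn actually yields, so the node-level test
# is insensitive to dropped events.
n_graphs_total = 0
n_nodes_total = 0
for batch in dataloader:
    n_graphs_total += batch.num_graphs
    n_nodes_total += int(batch.n_pulses.sum())

node_level = len(predictions) == n_nodes_total

for batch in dataloader:
    pulses = batch.n_pulses.detach().cpu().numpy()
    for attr in attributes:
        attribute = batch[attr]
        if isinstance(attribute, torch.Tensor):
            attribute = attribute.detach().cpu().numpy()
        if node_level and len(attribute) == batch.num_graphs:
            # A graph-level attribute next to node-level predictions:
            # repeat it once per pulse so the lengths line up.
            attribute = np.repeat(attribute, pulses)
        attributes[attr].extend(attribute)

Because the counts are taken from the batches themselves, the node-level test stays correct no matter how many events the `collate_fn` drops.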