graphnet-team / graphnet

A deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0
85 stars 86 forks

Added `extra_repr` and `__repr__` methods to `Model` and `ModelConfig` classes #665

Closed: samadpls closed this pull request 4 months ago

samadpls commented 4 months ago

Implemented `extra_repr` and `__repr__` for the `GraphDefinition` and `ModelConfig` classes. Resolves issue #650. Running the following code:

from graphnet.models.graphs.graph_definition import GraphDefinition
from graphnet.models.detector import IceCube86

graph_def = GraphDefinition(detector=IceCube86())
print(graph_def)

Output:

GraphDefinition(
  GraphDefinition(
  {
      'arguments': {
          'detector': ModelConfig(
  {

  }
  ),
          'node_definition': ModelConfig(
  {
      'input_feature_names': None,
  }
  ),
          'edge_definition': None,
          'input_feature_names': None,
          'dtype': torch.float32,
          'perturbation_dict': None,
          'seed': None,
          'add_inactive_sensors': False,
          'sensor_mask': None,
          'string_mask': None,
          'sort_by': None,
      },
  })
  (_detector): IceCube86()
  (_node_definition): NodesAsPulses()
)
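For context, `torch.nn.Module.__repr__` already splices whatever `extra_repr` returns into the printed module tree; that is the hook these changes build on. A minimal standalone sketch (plain torch, no graphnet; the `Scaled` class is illustrative):

```python
import torch.nn as nn


class Scaled(nn.Module):
    """Toy module whose repr shows a config value via extra_repr."""

    def __init__(self, factor: float):
        super().__init__()
        self.factor = factor
        self.linear = nn.Linear(4, 4)

    def extra_repr(self) -> str:
        # The returned string is inserted at the top of this module's repr,
        # before the child-module listing.
        return f"factor={self.factor}"


print(Scaled(2.0))
```

This prints a tree of the form `Scaled(factor=2.0, (linear): Linear(...))`, which is the same mechanism producing the nested output above.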
samadpls commented 4 months ago

Hi @AMHermansen, Regarding:

> Could you also include how a Model object which contains learnable parameters is printed? I think one of the subclasses of Task would be nice, so we can see how this extra_repr works with the components from torch?

Here is the output from running `test_model_config.py`:

StandardModel(
  (_graph_definition): KNNGraph(
    KNNGraph(
    {
        'arguments': {
            'detector': ModelConfig(
    {

    }
    ),
            'node_definition': ModelConfig(
    {
        'input_feature_names': None,
    }
    ),
            'input_feature_names': ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area'],
            'dtype': torch.float32,
            'perturbation_dict': None,
            'seed': None,
            'nb_nearest_neighbours': 8,
            'columns': [0, 1, 2],
        },
    })
    (_detector): IceCubeDeepCore()
    (_edge_definition): KNNEdges()
    (_node_definition): NodesAsPulses()
  )
  (backbone): DynEdge(
    (_activation): LeakyReLU(negative_slope=0.01)
    (_conv_layers): ModuleList(
      (0): DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=38, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=128, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
      (1-3): 3 x DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=512, out_features=336, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=336, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
    )
    (_post_processing): Sequential(
      (0): Linear(in_features=1043, out_features=336, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=336, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (_readout): Sequential(
      (0): Linear(in_features=1024, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (_tasks): ModuleList(
    (0): EnergyReconstruction(
      (_loss_function): LogCoshLoss()
      (_affine): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
samadpls commented 4 months ago

Thanks for your suggestion. I think the new change with more details is great. I updated the code to match your recommendations.

AMHermansen commented 4 months ago

Hello @samadpls I just had a look at where the unit tests were failing. The cause is that the lambda functions used to construct the model are anonymous functions living at different memory addresses, so their reprs can never compare equal.
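As a standalone illustration of why those reprs can never match (plain Python, no graphnet needed):

```python
# Two textually identical lambdas are still distinct function objects,
# and the default repr of a function embeds its memory address.
f = lambda x: x + 1
g = lambda x: x + 1

assert f(1) == g(1)        # same behaviour
assert repr(f) != repr(g)  # but different reprs, so string comparison fails
```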

I think the cleanest way around this might be to introduce a boolean flag, as a member variable, which toggles between the verbose print style and the "raw" PyTorch print style. This could be done with something like:

class Model(...):
    verbose_print = True

    ...

    def extra_repr(self):
        return self._extra_repr() if self.verbose_print else ""

And then move your extra_repr to _extra_repr. Then for the troublesome unit test (tests/utilities/test_model_config.py), you should toggle verbose_print off, i.e. insert:

model.verbose_print = False
constructed_model.verbose_print = False

above the assert statement `assert repr(model) == repr(constructed_model)`.

samadpls commented 4 months ago

Hi @AMHermansen, thanks for the suggestion. However, the test keeps failing even after making these changes unless I modify the assert from

assert repr(model) == repr(constructed_model)

to

assert constructed_model.extra_repr() == model.extra_repr()
AMHermansen commented 4 months ago

> Hi @AMHermansen, thanks for the suggestion. However, the test keeps failing even after making these changes unless I modify the assert from
>
> assert repr(model) == repr(constructed_model)
>
> to
>
> assert constructed_model.extra_repr() == model.extra_repr()

I believe this new assert would pass trivially, since both extra_repr calls return the empty string. I suspect the root cause is that the verbose_print flag is not being applied recursively, so only the outermost Model object has its flag changed. I think the easiest fix is a utility function that recursively sets the verbose_print flag. Something like the following:

class Model(...):
    verbose_print = True

    def set_verbose_print_recursively(self, verbose_print: bool) -> None:
        # nn.Module.modules() yields this module and all submodules
        # recursively (including self), so one call covers the whole tree.
        for module in self.modules():
            if isinstance(module, Model):
                module.verbose_print = verbose_print

Should work.
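A quick self-contained check that this recursive toggle reaches nested submodules (plain torch; the `Verbose`, `Root`, and `Leaf` class names are illustrative, not graphnet API):

```python
import torch.nn as nn


class Verbose(nn.Module):
    verbose_print = True

    def set_verbose_print_recursively(self, verbose_print: bool) -> None:
        # modules() yields this module and every nested submodule.
        for module in self.modules():
            if isinstance(module, Verbose):
                module.verbose_print = verbose_print

    def extra_repr(self) -> str:
        return "verbose" if self.verbose_print else ""


class Leaf(Verbose):
    pass


class Root(Verbose):
    def __init__(self):
        super().__init__()
        self.leaf = Leaf()


root = Root()
root.set_verbose_print_recursively(False)
# Both levels were toggled, so neither prints the verbose marker.
assert root.verbose_print is False
assert root.leaf.verbose_print is False
assert "verbose" not in repr(root)
```

With this in place the original `assert repr(model) == repr(constructed_model)` comparison can be kept, since both sides fall back to the raw PyTorch repr.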

AMHermansen commented 4 months ago

Looks good, I'll merge it :raised_hands:

samadpls commented 4 months ago

thanks :)