graphnet-team / graphnet

A deep learning library for neutrino telescopes
https://graphnet-team.github.io/graphnet/
Apache License 2.0

Implement more descriptive extra_repr methods for GraphNeT Models #650

Closed: AMHermansen closed this issue 7 months ago

AMHermansen commented 9 months ago

Is your feature request related to a problem? Please describe.
Currently, printing a GraphNeT Model falls back to the bare-bones PyTorch torch.nn.Module.__repr__, which only lists the layers (submodules) of the module. Since every GraphNeT Model contains a config_dict describing how it was constructed, we can print a much more detailed description of the object. This is especially useful for models with no learnable parameters, such as Detector and GraphDefinition.

For example, printing a GraphDefinition object yields GraphDefinition(), regardless of how the GraphDefinition was constructed.

Describe the solution you'd like
Implement an extra_repr method in Model that provides a more detailed description of the object.
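For context, torch.nn.Module.__repr__ places whatever extra_repr() returns inside the parentheses after the class name, which is why a module with no children and an empty extra_repr prints as just ClassName(). The minimal sketch below illustrates that hook. ConfiguredModule and PlainModule are hypothetical classes, and the stored _arguments dict only stands in for GraphNeT's actual config.

```python
import torch
from torch import nn


class PlainModule(nn.Module):
    """A module with no parameters, no children and no extra_repr."""


class ConfiguredModule(nn.Module):
    """Hypothetical stand-in for a Model that remembers its constructor arguments."""

    def __init__(self, **kwargs):
        super().__init__()
        # Keep the constructor arguments (analogous to, but not the same as, GraphNeT's config).
        self._arguments = dict(kwargs)

    def extra_repr(self) -> str:
        # One "key=value" pair per argument; nn.Module.__repr__ puts this
        # string inside the parentheses after the class name.
        return ", ".join(f"{key}={value!r}" for key, value in self._arguments.items())


print(PlainModule())  # PlainModule()
print(ConfiguredModule(seed=42, dtype=torch.float32))  # ConfiguredModule(seed=42, dtype=torch.float32)
```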

samadpls commented 8 months ago

Hey @AMHermansen, here is my implementation. Please let me know your thoughts. The following output was produced by running the extra_repr implementation:

GraphDefinition(
  GraphDefinition(
      class_name=GraphDefinition
      arguments={'detector': ModelConfig(class_name='IceCube86', arguments={}), 'node_definition': ModelConfig(class_name='NodesAsPulses', arguments={'input_feature_names': None}), 'edge_definition': None, 'input_feature_names': None, 'dtype': torch.float32, 'perturbation_dict': None, 'seed': None, 'add_inactive_sensors': False, 'sensor_mask': None, 'string_mask': None, 'sort_by': None}
  )
  (_detector): IceCube86()
  (_node_definition): NodesAsPulses()
)

The implementation used to generate the output is as follows:

def extra_repr(self) -> str:
    """Provide a more detailed description of the object print.

    Returns:
        str: A string representation containing detailed information about the object.
    """
    full_str = f"{self.__class__.__name__}(\n"
    for item, value in self._config.__dict__.items():
        if isinstance(value, Model):
            full_str += self._predindent_args(value)
        else:
            full_str += f"    {item}={value}\n"
    full_str += ")"
    return full_str

def _predindent_args(self, model: "Model") -> str:
    """Indent nested model arguments.

    Args:
        model (Model): The nested model.

    Returns:
        str: Indented string representation of the nested model's arguments.
    """
    indented_str = model.extra_repr().replace("\n", "\n    ")
    return f"    {indented_str}\n"

Here is the code I used to create and print the object:

from graphnet.models.graphs.graph_definition import GraphDefinition
from graphnet.models.detector import IceCube86

graph_def = GraphDefinition(detector=IceCube86())
print(graph_def)
AMHermansen commented 7 months ago

Hello @samadpls, sorry for the delay. It's great that you've taken the time to figure out a way to print the information contained in the _config field. I think this is a very good starting point, and with a bit of refinement it could make it visible at a glance how the object was created.

One thing I notice in the concrete example is that GraphDefinition is printed three times, and one of those comes from torch.nn.Module.__repr__. Perhaps it is sufficient to print the class name and the arguments? In particular, I think the class_name=GraphDefinition line is redundant, since the header above it already tells us we're dealing with a GraphDefinition object.
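A minimal sketch of that suggestion, assuming extra_repr returns only the arguments: torch.nn.Module.__repr__ already prints the class name on the first line and appends each registered child module as "(name): repr", so the class name then appears only once. GraphDefinitionSketch and Detector are hypothetical stand-ins, not the GraphNeT classes.

```python
from typing import Optional

from torch import nn


class Detector(nn.Module):
    """Hypothetical parameter-free child module."""


class GraphDefinitionSketch(nn.Module):
    """Hypothetical stand-in for GraphDefinition; not the GraphNeT class."""

    def __init__(self, detector: nn.Module, seed: Optional[int] = None):
        super().__init__()
        self._detector = detector  # registered as a child module by nn.Module
        self._seed = seed

    def extra_repr(self) -> str:
        # Arguments only: the class name and child modules are added by
        # nn.Module.__repr__ itself.
        return f"seed={self._seed}"


print(GraphDefinitionSketch(Detector(), seed=7))
# GraphDefinitionSketch(
#   seed=7
#   (_detector): Detector()
# )
```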

As it is now, the arguments line is very long and could be printed more neatly, especially for the ModelConfig objects. Perhaps the best approach would be to write a good __repr__ for the ModelConfig class; the extra_repr in Model could then just print its _config member variable?
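One possible shape for such a __repr__, sketched with a hypothetical ConfigSketch dataclass that mirrors the class_name/arguments pair in the output above (it is not GraphNeT's ModelConfig): each config renders as the constructor call that would rebuild it, and nested configs render recursively through !r.

```python
from dataclasses import dataclass, field
from typing import Any, Dict


@dataclass(repr=False)
class ConfigSketch:
    """Hypothetical stand-in for a ModelConfig: a class name plus its arguments."""

    class_name: str
    arguments: Dict[str, Any] = field(default_factory=dict)

    def __repr__(self) -> str:
        # Render like a constructor call; nested ConfigSketch values reuse
        # this __repr__ via !r, so they print in the same compact form.
        args = ", ".join(f"{key}={value!r}" for key, value in self.arguments.items())
        return f"{self.class_name}({args})"


config = ConfigSketch(
    "GraphDefinition",
    {
        "detector": ConfigSketch("IceCube86"),
        "node_definition": ConfigSketch("NodesAsPulses", {"input_feature_names": None}),
        "seed": None,
    },
)
print(config)
# GraphDefinition(detector=IceCube86(), node_definition=NodesAsPulses(input_feature_names=None), seed=None)
```

Whether None-valued or default arguments should also be omitted to shorten the line further is a separate design choice.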

Could you also show how a Model that contains learnable parameters is printed? One of the subclasses of Task would be a good example, so we can see how this extra_repr interacts with the torch components.
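To give a rough idea of what such output could look like, here is a sketch using a hypothetical TaskSketch module (not GraphNeT's Task) that holds an nn.Linear layer. The extra_repr line comes first, followed by the child modules, for which PyTorch also prints their own extra_repr (e.g. Linear's in_features, out_features and bias).

```python
from torch import nn


class TaskSketch(nn.Module):
    """Hypothetical stand-in for a Task with one learnable layer."""

    def __init__(self, hidden_size: int, target_labels: str):
        super().__init__()
        self._target_labels = target_labels
        self._affine = nn.Linear(hidden_size, 1)  # learnable weight and bias

    def extra_repr(self) -> str:
        # Printed before the child modules by nn.Module.__repr__.
        return f"hidden_size={self._affine.in_features}, target_labels={self._target_labels!r}"


print(TaskSketch(hidden_size=16, target_labels="energy"))
# TaskSketch(
#   hidden_size=16, target_labels='energy'
#   (_affine): Linear(in_features=16, out_features=1, bias=True)
# )
```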

Let me know what you think!