Hey @AMHermansen, here is my implementation. Please let me know your thoughts. The following output was produced by running the extra_repr implementation:
GraphDefinition(
  GraphDefinition(
   class_name=GraphDefinition
   arguments={'detector': ModelConfig(class_name='IceCube86', arguments={}), 'node_definition': ModelConfig(class_name='NodesAsPulses', arguments={'input_feature_names': None}), 'edge_definition': None, 'input_feature_names': None, 'dtype': torch.float32, 'perturbation_dict': None, 'seed': None, 'add_inactive_sensors': False, 'sensor_mask': None, 'string_mask': None, 'sort_by': None}
  )
  (_detector): IceCube86()
  (_node_definition): NodesAsPulses()
)
The implementation used to generate the output is as follows:
def extra_repr(self) -> str:
    """Provide a more detailed description of the object print.

    Returns:
        str: A string representation containing detailed information about the object.
    """
    full_str = f"{self.__class__.__name__}(\n"
    for item, value in self._config.__dict__.items():
        if isinstance(value, Model):
            full_str += self._predindent_args(value)
        else:
            full_str += f" {item}={value}\n"
    full_str += ")"
    return full_str

def _predindent_args(self, model: Model) -> str:
    """Indent nested model arguments.

    Args:
        model (Model): The nested model.

    Returns:
        str: Indented string representation of the nested model's arguments.
    """
    indented_str = model.extra_repr().replace("\n", "\n ")
    return f" {indented_str}\n"
Here is the code I used to create the object:
from graphnet.models.graphs.graph_definition import GraphDefinition
from graphnet.models.detector import IceCube86
graph_def = GraphDefinition(detector=IceCube86())
print(graph_def)
Hello @samadpls, sorry for the delay. I think it is great that you've taken the time to figure out a way to print the information contained in the _config field. This is a very good starting point, and with a bit of refinement it could make it visible at a glance how the object was created.
One thing I notice about how the model is printed in the concrete example is that GraphDefinition is printed three times, and one of those comes from torch.nn.Module.__repr__. Perhaps it is sufficient to print the class name and the arguments? In particular, I think the line class_name=GraphDefinition is redundant, since the lines above already tell us that we're dealing with a GraphDefinition object.
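To see where the three copies come from, here is a simplified, pure-Python imitation of the layout that torch.nn.Module.__repr__ produces. The helpers add_indent and module_repr are hypothetical and only mimic torch's behaviour as I understand it, not its exact code: the extra_repr() lines come first, each child module follows as (name): repr(child), and everything is wrapped in the class name.

```python
from typing import Dict


def add_indent(text: str, spaces: int) -> str:
    """Indent every line after the first one by `spaces` spaces."""
    lines = text.split("\n")
    if len(lines) == 1:
        return text
    pad = " " * spaces
    return lines[0] + "\n" + "\n".join(pad + line for line in lines[1:])


def module_repr(name: str, extra_repr: str, children: Dict[str, str]) -> str:
    """Wrap extra_repr() lines and child reprs in the class name (simplified)."""
    extra_lines = extra_repr.split("\n") if extra_repr else []
    child_lines = [f"({key}): {add_indent(text, 2)}" for key, text in children.items()]
    lines = extra_lines + child_lines
    out = name + "("
    if lines:
        if len(extra_lines) == 1 and not child_lines:
            # A single extra line and no children gives a one-liner.
            out += extra_lines[0]
        else:
            # Otherwise every entry goes on its own line, indented by two spaces.
            out += "\n  " + "\n  ".join(lines) + "\n"
    return out + ")"


children = {"_detector": "IceCube86()", "_node_definition": "NodesAsPulses()"}

# extra_repr() that repeats the class name and adds a class_name line.
current = "GraphDefinition(\n class_name=GraphDefinition\n arguments={...}\n)"
print(module_repr("GraphDefinition", current, children))

# extra_repr() that only lists the arguments.
slim = "detector=IceCube86()\nnode_definition=NodesAsPulses()"
print(module_repr("GraphDefinition", slim, children))
```

Since the outer wrapper already prints the class name, an extra_repr that only lists the arguments leaves GraphDefinition appearing once, followed by the arguments and then the (_detector) and (_node_definition) children.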
As you're doing it right now, the arguments line is very long and could be printed in a neater way, especially for the ModelConfig objects. Perhaps the best approach would be to write a good __repr__ for the ModelConfig class. Then the extra_repr in Model could just print its _config member variable?
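A minimal sketch of that suggestion, assuming ModelConfig holds a class_name and an arguments dict as in the printed output above. The ModelConfig and Model classes below are simplified hypothetical stand-ins (the real ones carry more machinery, and Model subclasses torch.nn.Module), and argument_lines is a hypothetical helper. Each argument goes on its own line, nested configs reuse the same __repr__, and extra_repr prints only the arguments so the class name is not repeated.

```python
from typing import Any, Dict, List


class ModelConfig:
    """Hypothetical stand-in for graphnet's ModelConfig (class_name + arguments)."""

    def __init__(self, class_name: str, arguments: Dict[str, Any]) -> None:
        self.class_name = class_name
        self.arguments = arguments

    def argument_lines(self) -> List[str]:
        """One `name=value` entry per argument; nested configs use their own repr."""
        return [f"{key}={value!r}" for key, value in self.arguments.items()]

    def __repr__(self) -> str:
        if not self.arguments:
            # Configs without arguments collapse to a one-liner, e.g. IceCube86().
            return f"{self.class_name}()"
        body = "\n".join(self.argument_lines())
        # Indent every line, including continuation lines from nested configs.
        indented = "\n".join("    " + line for line in body.split("\n"))
        return f"{self.class_name}(\n{indented}\n)"


class Model:
    """Hypothetical minimal Model; graphnet's real Model subclasses torch.nn.Module."""

    def __init__(self, config: ModelConfig) -> None:
        self._config = config

    def extra_repr(self) -> str:
        # Only the arguments: torch.nn.Module.__repr__ already prints the class name.
        return "\n".join(self._config.argument_lines())


config = ModelConfig(
    "GraphDefinition",
    {
        "detector": ModelConfig("IceCube86", {}),
        "node_definition": ModelConfig("NodesAsPulses", {"input_feature_names": None}),
        "edge_definition": None,
    },
)
print(Model(config).extra_repr())
```

Putting one argument per line keeps a long arguments dict readable, and because nested ModelConfig objects are rendered by the same __repr__, the node and edge definitions indent naturally under their argument names.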
Could you also include how a Model object that contains learnable parameters is printed? I think one of the subclasses of Task would be nice, so we can see how this extra_repr works together with the components from torch.
Let me know what you think!
Is your feature request related to a problem? Please describe.
Currently, when printing a GraphNeT Model we default to the bare-bones PyTorch module.__repr__, which only prints the layers in a module. Since all GraphNeT Models contain a config_dict with information about how they were constructed, we can provide a much more detailed description when the object is printed. This is especially useful for models containing no learnable parameters, like Detector and GraphDefinition. For example, printing a GraphDefinition object gives GraphDefinition(), regardless of how the GraphDefinition was constructed.
Describe the solution you'd like
Implement an extra_repr method in Model which provides a more detailed description of the object.