Hi @AMHermansen, regarding:
> Could you also include how a `Model` object which contains learnable parameters is printed? I think one of the subclasses of `Task` would be nice, so we can see how this `extra_repr` works with the components from torch.
Here is the output from running `test_model_config.py`:
```
StandardModel(
  (_graph_definition): KNNGraph(
    KNNGraph(
    {
      'arguments': {
        'detector': ModelConfig(
        {
        }
        ),
        'node_definition': ModelConfig(
        {
          'input_feature_names': None,
        }
        ),
        'input_feature_names': ['dom_x', 'dom_y', 'dom_z', 'dom_time', 'charge', 'rde', 'pmt_area'],
        'dtype': torch.float32,
        'perturbation_dict': None,
        'seed': None,
        'nb_nearest_neighbours': 8,
        'columns': [0, 1, 2],
      },
    })
    (_detector): IceCubeDeepCore()
    (_edge_definition): KNNEdges()
    (_node_definition): NodesAsPulses()
  )
  (backbone): DynEdge(
    (_activation): LeakyReLU(negative_slope=0.01)
    (_conv_layers): ModuleList(
      (0): DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=38, out_features=128, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=128, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
      (1-3): 3 x DynEdgeConv(nn=Sequential(
        (0): Linear(in_features=512, out_features=336, bias=True)
        (1): LeakyReLU(negative_slope=0.01)
        (2): Linear(in_features=336, out_features=256, bias=True)
        (3): LeakyReLU(negative_slope=0.01)
      ))
    )
    (_post_processing): Sequential(
      (0): Linear(in_features=1043, out_features=336, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
      (2): Linear(in_features=336, out_features=256, bias=True)
      (3): LeakyReLU(negative_slope=0.01)
    )
    (_readout): Sequential(
      (0): Linear(in_features=1024, out_features=128, bias=True)
      (1): LeakyReLU(negative_slope=0.01)
    )
  )
  (_tasks): ModuleList(
    (0): EnergyReconstruction(
      (_loss_function): LogCoshLoss()
      (_affine): Linear(in_features=128, out_features=1, bias=True)
    )
  )
)
```
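In torch's default `nn.Module.__repr__`, whatever `extra_repr()` returns is printed first and the registered child modules follow, which is consistent with the layout above, where the `KNNGraph` config dump precedes `(_detector)`, `(_edge_definition)` and `(_node_definition)`. A minimal sketch with a hypothetical `ToyTask` (a stand-in, not graphnet's `Task`) holding one learnable layer:

```python
import torch.nn as nn


class ToyTask(nn.Module):
    """Hypothetical stand-in for a Task subclass with learnable parameters."""

    def __init__(self, hidden_size: int, target: str) -> None:
        super().__init__()
        self._target = target  # plain attribute, shown only via extra_repr
        self._affine = nn.Linear(hidden_size, 1)  # child module, printed by torch

    def extra_repr(self) -> str:
        # torch places these lines before the child-module lines
        return f"target={self._target!r}"


task = ToyTask(hidden_size=128, target="energy")
print(task)
# ToyTask(
#   target='energy'
#   (_affine): Linear(in_features=128, out_features=1, bias=True)
# )
```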
Thanks for your suggestion. I think the new change with more details is great. I updated the code to match your recommendations.
Hello @samadpls, I just had a look at why the unit tests were failing. The cause is that the lambda functions used to construct the model are anonymous functions at different memory addresses, and since a function's repr includes its address, the repr of the original model and that of the model rebuilt from its config differ.
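The effect can be reproduced in plain Python: two lambdas with identical bodies are still distinct objects, and the default function repr embeds the object's address, so any repr that contains them differs. A sketch with a hypothetical `make_arguments` helper standing in for the model's constructor arguments:

```python
def make_arguments():
    # Hypothetical model arguments that include an anonymous function,
    # e.g. a scaling function passed to a task
    return {"nb_nearest_neighbours": 8, "transform": lambda x: x / 1000.0}


original = make_arguments()
reconstructed = make_arguments()

# Each call creates a new function object whose repr embeds its address,
# e.g. <function make_arguments.<locals>.<lambda> at 0x...>
print(repr(original["transform"]))
print(repr(original) == repr(reconstructed))  # False: the addresses differ
print(original["transform"](2000.0) == reconstructed["transform"](2000.0))  # True
```

The two dictionaries behave identically, yet their reprs can never match while both objects are alive.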
I think the cleanest way around this might be to introduce a boolean flag, as a member variable, which toggles between the verbose print style and the "raw" PyTorch print style. This could be done with something like:
```python
class Model(...):
    verbose_print = True
    ...

    def extra_repr(self):
        return self._extra_repr() if self.verbose_print else ""
```
And then move your `extra_repr` to `_extra_repr`.
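A minimal runnable sketch of this toggle, assuming the `Model` stand-in simply subclasses `torch.nn.Module` (the real base classes are elided above) and that `_extra_repr` returns a one-line description; `ToyModel` and its `_arguments` dict are hypothetical. Returning an empty string from `extra_repr` makes torch fall back to its plain layout:

```python
import torch.nn as nn


class Model(nn.Module):
    """Stand-in for graphnet's Model; here it just subclasses nn.Module."""

    verbose_print = True  # class-level default, can be overridden per instance

    def _extra_repr(self) -> str:
        # Hypothetical verbose description; the PR prints the full model config
        return f"arguments={self._arguments!r}"

    def extra_repr(self) -> str:
        # An empty string makes torch print only its "raw" layout
        return self._extra_repr() if self.verbose_print else ""


class ToyModel(Model):
    def __init__(self, nb_nearest_neighbours: int) -> None:
        super().__init__()
        self._arguments = {"nb_nearest_neighbours": nb_nearest_neighbours}


model = ToyModel(nb_nearest_neighbours=8)
print(model)  # ToyModel(arguments={'nb_nearest_neighbours': 8})
model.verbose_print = False
print(model)  # ToyModel()
```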
Then, for the troublesome unit test (`tests/utilities/test_model_config.py`), you should toggle `verbose_print` off, i.e. insert `model.verbose_print = False` and `constructed_model.verbose_print = False` above the assert statement `assert repr(model) == repr(constructed_model)`.
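For a flat model this is enough, as a hypothetical sketch shows: each model receives its own lambda, so with the flag on the reprs differ only by the lambda's address, and with it off only torch's own layout, including the learnable `Linear` layer, remains. `Model` and `ToyModel` here are simplified stand-ins, and `constructed_model` is simply built a second time rather than from a config:

```python
import torch.nn as nn


class Model(nn.Module):
    """Stand-in Model with the verbose_print toggle (hypothetical simplification)."""

    verbose_print = True

    def extra_repr(self) -> str:
        return repr(self._arguments) if self.verbose_print else ""


class ToyModel(Model):
    def __init__(self, transform) -> None:
        super().__init__()
        self._arguments = {"transform": transform}  # holds a lambda
        self._affine = nn.Linear(4, 1)  # learnable component, printed by torch


model = ToyModel(lambda x: x)
constructed_model = ToyModel(lambda x: x)
assert repr(model) != repr(constructed_model)  # lambda addresses differ

model.verbose_print = False
constructed_model.verbose_print = False
assert repr(model) == repr(constructed_model)  # only torch's layout remains
```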
Hi @AMHermansen, thanks for the suggestion. However, the test keeps failing even after making these changes, unless I modify the assert from `assert repr(model) == repr(constructed_model)` to `assert constructed_model.extra_repr() == model.extra_repr()`.
I believe this new assert passes trivially, since both `extra_repr` calls return the empty string.
I suspect the root cause is that the `verbose_print` flag is not applied recursively, so only the outermost `Model` object has its flag changed, while nested `Model` submodules keep printing verbosely.
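This can be reproduced with a hypothetical two-level stand-in, where `ToyGraph` plays the role of a nested `Model` such as a graph definition built with a lambda. Switching off only the outer objects makes the `extra_repr()` comparison trivially true, while the full reprs still differ because the inner model keeps printing its lambda:

```python
import torch.nn as nn


class Model(nn.Module):
    verbose_print = True

    def extra_repr(self) -> str:
        return repr(self._arguments) if self.verbose_print else ""


class ToyGraph(Model):
    """Hypothetical nested Model, e.g. a graph definition built with a lambda."""

    def __init__(self, transform) -> None:
        super().__init__()
        self._arguments = {"transform": transform}


class ToyStandardModel(Model):
    def __init__(self, transform) -> None:
        super().__init__()
        self._arguments = {}
        self._graph_definition = ToyGraph(transform)  # nested Model


model = ToyStandardModel(lambda x: x)
constructed_model = ToyStandardModel(lambda x: x)

# Only the outermost objects are switched off
model.verbose_print = False
constructed_model.verbose_print = False

assert model.extra_repr() == constructed_model.extra_repr() == ""  # trivially true
assert repr(model) != repr(constructed_model)  # nested ToyGraph still prints its lambda
```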
I think the easiest way to fix this is to implement a utility method that recursively sets the `verbose_print` flag, something like:
```python
class Model(...):
    verbose_print = True

    def set_verbose_print_recursively(self, verbose_print: bool):
        # modules() is a method; it yields self and every nested submodule
        for module in self.modules():
            if isinstance(module, Model):
                module.verbose_print = verbose_print
        self.verbose_print = verbose_print
```
This should work.
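A runnable version of that helper on hypothetical stand-ins (`ToyGraph`, `ToyStandardModel`). Since `torch.nn.Module.modules()` is a method that yields the module itself followed by all nested submodules, it must be called with parentheses, and the loop already covers `self`, so the trailing `self.verbose_print = verbose_print` assignment is redundant but harmless and is dropped here. The `isinstance` check leaves plain torch layers such as `nn.Linear` untouched:

```python
import torch.nn as nn


class Model(nn.Module):
    verbose_print = True

    def extra_repr(self) -> str:
        return repr(self._arguments) if self.verbose_print else ""

    def set_verbose_print_recursively(self, verbose_print: bool) -> None:
        # modules() yields self first, then every nested submodule
        for module in self.modules():
            if isinstance(module, Model):  # skip plain torch layers
                module.verbose_print = verbose_print


class ToyGraph(Model):
    def __init__(self, transform) -> None:
        super().__init__()
        self._arguments = {"transform": transform}


class ToyStandardModel(Model):
    def __init__(self, transform) -> None:
        super().__init__()
        self._arguments = {}
        self._graph_definition = ToyGraph(transform)  # nested Model
        self._affine = nn.Linear(4, 1)  # plain torch layer


model = ToyStandardModel(lambda x: x)
constructed_model = ToyStandardModel(lambda x: x)
model.set_verbose_print_recursively(False)
constructed_model.set_verbose_print_recursively(False)
assert repr(model) == repr(constructed_model)  # the original test assert now holds
```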
Looks good, I'll merge it :raised_hands:
thanks :)
Implemented `extra_repr` and `__repr__` for the `GraphDefinition` and `ModelConfig` classes.
Resolved issue #650.
While running the following code:
Output: