waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, TensorFlow, and Keras.
MIT License

Recursively detect framework when multiple inheritance of nn.Module is used #58

Open: Damming opened this issue 4 years ago

Damming commented 4 years ago

Hi,

I found that the function detect_framework() in graph.py only detects a model whose class inherits directly from nn.Module (the usual PyTorch style). A class that subclasses another nn.Module subclass is not recognized, so I wrote this helper:
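To see why a direct-parent check misses such models, here is a minimal sketch that runs without any deep-learning framework installed. It assumes the detection looks only at the immediate bases (value.__class__.__bases__), as described above. The stand-in class Module and its __module__ string are hypothetical, chosen to mimic how torch.nn.Module reports the package it comes from.

```python
# Hypothetical stand-in for torch.nn.Module: only its __module__ string
# matters for framework detection, so we set it by hand.
class Module:
    pass

Module.__module__ = "torch.nn.modules.module"

class BaseNet(Module):   # inherits directly from the "framework" class
    pass

class MyNet(BaseNet):    # inherits from it only indirectly
    pass

def direct_parent_modules(cls):
    """Modules of the immediate base classes only (what __bases__ sees)."""
    return [c.__module__ for c in cls.__bases__]

def all_ancestor_modules(cls):
    """Modules of every class in the method resolution order."""
    return [c.__module__ for c in cls.__mro__]

print(direct_parent_modules(BaseNet))  # ['torch.nn.modules.module']
print(direct_parent_modules(MyNet))    # only BaseNet's module, so no "torch" entry
print(any(m.startswith("torch") for m in all_ancestor_modules(MyNet)))  # True
```

The key point is that BaseNet is defined in user code, so its own __module__ never starts with "torch"; only walking further up the hierarchy reaches the framework class.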

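```python
def find_root_class(outer_class, all_classes):
    # Collect the direct bases of outer_class, then recurse into its
    # first base until that first base is object.
    if outer_class.__bases__[0] is object:
        return all_classes + outer_class.__bases__
    else:
        return find_root_class(outer_class.__bases__[0], all_classes + outer_class.__bases__)
```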

With that helper, the original detect_framework() could become:

```python
def detect_framework(value):
    classes = find_root_class(value.__class__, ())
    for c in classes:
        if c.__module__.startswith("torch"):
            return "torch"
        elif c.__module__.startswith("tensorflow"):
            return "tensorflow"
```

Hope this is useful.

maxfrei750 commented 4 years ago

A more elegant solution has already been implemented in 294f8732b271cbdd6310c55bdf5ce855cbf61c75. However, it has not been merged yet (and unfortunately, it probably will not be merged any time soon). Nevertheless, you can already use the fixed version if you install hiddenlayer from git instead of from pip.
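The contents of that commit are not reproduced here. As a sketch of one standard-library approach (not necessarily what the commit does), inspect.getmro() returns a class's full method resolution order. That covers indirect subclasses and every base under multiple inheritance, whereas find_root_class above recurses only into __bases__[0]: for a class like MixedNet(Mixin, BaseNet) it stops after Mixin and never reaches the framework class. The stand-in classes below are hypothetical, so the example runs without PyTorch.

```python
import inspect

def detect_framework(value):
    """Return "torch" or "tensorflow" if any ancestor class of value's type
    comes from that package, else None (a sketch, not hiddenlayer's code)."""
    # getmro() walks the whole hierarchy, including every base of a class
    # that uses multiple inheritance, not just the first one.
    for c in inspect.getmro(type(value)):
        if c.__module__.startswith("torch"):
            return "torch"
        if c.__module__.startswith("tensorflow"):
            return "tensorflow"
    return None

# Hypothetical stand-ins so the example runs without PyTorch installed.
class Module:
    pass

Module.__module__ = "torch.nn.modules.module"

class Mixin:
    pass

class BaseNet(Module):
    pass

class MyNet(BaseNet):            # indirect subclass
    pass

class MixedNet(Mixin, BaseNet):  # framework class reached via a non-first base
    pass

print(detect_framework(MyNet()))     # torch
print(detect_framework(MixedNet()))  # torch
print(detect_framework(Mixin()))     # None
```

Because the method resolution order is computed by Python itself, this avoids hand-written recursion and handles any inheritance shape Python accepts.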