waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, Tensorflow, and Keras.
MIT License
1.79k stars 266 forks

Support dict input? #44

Open zhangjing1997 opened 5 years ago

zhangjing1997 commented 5 years ago

I get a runtime error when I try to plot the graph of a model whose input x is a dict:

import hiddenlayer as hl

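x = getBuff(0)  # project-specific helper that returns a dict of input tensors
hl.build_graph(net, x)  # net: a PyTorch model whose forward() takes that dict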

So I'm wondering whether this plotting tool could support dict input.

Ridhwanluthra commented 5 years ago

This is a limitation of the backend hiddenlayer uses, torch.jit.get_trace_graph, which does not support dict input. But more and more models take dicts these days, and hl is great for analysis, so maybe we should find a way to add this capability to the package. I can help with that.
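A possible workaround, sketched under the assumption that the model's forward() takes a single dict of tensors: wrap the model in a module that accepts the tensors as positional arguments and rebuilds the dict internally, so the tracer only ever sees tensors. `ToyDictModel`, `DictInputWrapper`, and the `img`/`mask` keys are hypothetical names used for illustration, not part of hiddenlayer.

```python
import torch
import torch.nn as nn


class ToyDictModel(nn.Module):
    """Hypothetical stand-in for a model whose forward() takes a dict of tensors."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 8, kernel_size=3, padding=1)

    def forward(self, batch):
        # Reads two entries from the input dict.
        return self.conv(batch["img"]) * batch["mask"]


class DictInputWrapper(nn.Module):
    """Exposes a dict-input model through plain positional tensor arguments."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, img, mask):
        # Rebuild the dict the wrapped model expects from the positional tensors.
        return self.model({"img": img, "mask": mask})


model = ToyDictModel().eval()
inputs = {"img": torch.randn(1, 3, 32, 32), "mask": torch.ones(1, 1, 32, 32)}
wrapper = DictInputWrapper(model).eval()

# The tracer only receives a tuple of tensors; the dict is built inside forward().
traced = torch.jit.trace(wrapper, (inputs["img"], inputs["mask"]))
out = traced(inputs["img"], inputs["mask"])
```

With hiddenlayer, the idea would be to pass the wrapper plus a tuple of tensors, e.g. `hl.build_graph(wrapper, (inputs["img"], inputs["mask"]))`. I haven't checked that against every hiddenlayer/PyTorch version, so treat that call as an assumption.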

SnowRipple commented 5 years ago

I have the same problem with dict input:

File "/home/snow_ripple/workspace/01_detection/mmdet/apis/inference.py", line 96, in _inference_single
    hl_graph = hl.build_graph(model, **data)
TypeError: build_graph() got an unexpected keyword argument 'img'

Is there a workaround?
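For the `**data` case, build_graph() does not accept the model's keyword arguments (hence the TypeError), so one option is a wrapper that turns positional tensors back into keyword arguments. Detection-style `data` dicts often also hold non-tensor entries (image metadata, flags) that tracing cannot take as inputs, so this sketch freezes those as constants. `KwargsModelWrapper`, `ToyKwargsModel`, and the `img`/`img_meta`/`rescale` keys are hypothetical, chosen only for illustration.

```python
import torch
import torch.nn as nn


class ToyKwargsModel(nn.Module):
    """Hypothetical model that is called as model(**data)."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, img, img_meta, rescale=False):
        out = self.fc(img)
        # Non-tensor arguments only select a branch or supply constants.
        return out * img_meta["scale"] if rescale else out


class KwargsModelWrapper(nn.Module):
    """Calls model(**kwargs) from positional tensors; non-tensor kwargs are frozen."""

    def __init__(self, model, data):
        super().__init__()
        self.model = model
        # Tensor entries become positional inputs, in the dict's insertion order.
        self.tensor_keys = [k for k, v in data.items() if torch.is_tensor(v)]
        # Everything else (metadata dicts, flags) is stored and reused as-is.
        self.constants = {k: v for k, v in data.items() if not torch.is_tensor(v)}

    def forward(self, *tensors):
        kwargs = dict(self.constants)
        kwargs.update(zip(self.tensor_keys, tensors))
        return self.model(**kwargs)

    def example_args(self, data):
        """Tuple of tensors in the order forward() expects them."""
        return tuple(data[k] for k in self.tensor_keys)


model = ToyKwargsModel()
data = {"img": torch.randn(2, 4), "img_meta": {"scale": 2.0}, "rescale": True}
wrapper = KwargsModelWrapper(model, data)
out = wrapper(*wrapper.example_args(data))
```

The intended hiddenlayer call would be `hl.build_graph(wrapper, wrapper.example_args(data))`. I have not verified that tracing handles the variable-length `forward(*tensors)` signature on every PyTorch version; if it doesn't, a wrapper with explicitly named tensor arguments avoids the issue.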

zhangjing1997 commented 5 years ago

Yeah. I also think a model graph tool that supports dicts would be very helpful, especially this hl module, because it really helps to understand and present large networks visually. Looking forward to your contribution. Thanks!

zhangjing1997 commented 5 years ago

Maybe you can try something like the following:

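from torchviz import make_dot

test_buff = getBuff(0)  # dict of input tensors for the network

graph = make_dot(net(test_buff), params=dict(net.named_parameters()))
graph.format = 'pdf'
graph.render("visPDF")  # writes the DOT source "visPDF" and the rendered "visPDF.pdf"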

As far as I can tell, make_dot only needs the network's output: it builds the graph by walking the autograd history backward from that tensor, so the form of the input doesn't matter. I just guess it may help in your case. BTW, credit goes to the AlexNet example at https://github.com/szagoruyko/pytorchviz/blob/master/examples.ipynb.
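To illustrate why a graph built from the output does not care about the input format: the autograd graph can be recovered by following `grad_fn.next_functions` backward from the output tensor. This is a minimal sketch of that traversal, not torchviz's actual code; `ToyDictLinear` and `autograd_node_names` are hypothetical names.

```python
import torch
import torch.nn as nn


class ToyDictLinear(nn.Module):
    """Hypothetical model whose forward() takes a dict of tensors."""

    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 3)

    def forward(self, batch):
        return self.fc(batch["x"]) + batch["y"]


def autograd_node_names(output):
    """Return class names of all autograd nodes reachable from output.grad_fn."""
    seen, names = set(), []
    stack = [output.grad_fn]
    while stack:
        fn = stack.pop()
        # None marks inputs that do not require grad (here, both dict entries).
        if fn is None or fn in seen:
            continue
        seen.add(fn)
        names.append(type(fn).__name__)
        stack.extend(next_fn for next_fn, _ in fn.next_functions)
    return names


model = ToyDictLinear()
out = model({"x": torch.randn(2, 4), "y": torch.randn(2, 3)})
names = autograd_node_names(out)
```

The dict never appears in the traversal; only the operations and the parameters' `AccumulateGrad` leaves do.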

SnowRipple commented 5 years ago

Thanks @zhangjing1997 !

What is getBuff(0) for? Which module defines this function?

I assume it is just a placeholder, so I changed it to: graph = make_dot(model(torch.zeros([1, 3, 224, 224])), params=dict(model.named_parameters()))

But it still complains about the missing dictionary arguments.

zhangjing1997 commented 5 years ago

In my project, getBuff(0) is just a function that returns a dict as the input to the network. I think your code should run. But your model's input actually seems to be a torch tensor, not a dict. If that's the case, I guess you don't need the approach I mentioned earlier, and maybe you could try exporting to ONNX instead:

import torch
import torch.onnx

dummy_input = torch.randn(1, 3, 224, 224)  # a plain tensor; the Variable wrapper is no longer needed
torch.onnx.export(model, dummy_input, "model.onnx")  # model: your nn.Module

Otherwise, if your model does need a dict input, could you post the specific error?

manesioz commented 4 years ago

Hey, any progress on this? Is anyone working on a PR?