Open zhangjing1997 opened 5 years ago
That is an issue with the backend this package uses, torch.jit.get_trace_graph, which does not support dict input. More and more models take dict inputs these days, and this hl module is great for analysis, so maybe we should figure out a way to add this capability to the package. I can help with that.
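For anyone hitting this in the meantime, one possible workaround is to wrap the model so that the tracer only ever sees positional tensors, and rebuild the dict inside the wrapper. The sketch below is an illustration under assumptions, not part of hl: DictModel, DictInputWrapper, and the keys "img" and "scale" are made-up names, and it assumes the model returns a plain tensor.

```python
import torch
import torch.nn as nn


class DictModel(nn.Module):
    """Hypothetical toy model whose forward() expects a dict of tensors."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 4, kernel_size=3, padding=1)

    def forward(self, data):
        return self.conv(data["img"]) * data["scale"]


class DictInputWrapper(nn.Module):
    """Hypothetical adapter: takes tensors positionally, rebuilds the dict."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, img, scale):
        # The dict only exists inside forward(), so the tracer never
        # receives it as an input.
        return self.model({"img": img, "scale": scale})


model = DictModel().eval()
data = {"img": torch.randn(1, 3, 8, 8), "scale": torch.ones(1)}

wrapped = DictInputWrapper(model)
args = (data["img"], data["scale"])

# Tracing works because the example inputs are a tuple of tensors.
traced = torch.jit.trace(wrapped, args)
```

The wrapped module and the tuple of tensors could then be passed to a graph tool in place of the original model, assuming the tool accepts a model plus a tuple of example inputs. The resulting graph will include the extra wrapper level.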
I have the same problem with dict input:
File "/home/snow_ripple/workspace/01_detection/mmdet/apis/inference.py", line 96, in _inference_single
    hl_graph = hl.build_graph(model, **data)
TypeError: build_graph() got an unexpected keyword argument 'img'
Is there a workaround?
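Note that this particular TypeError is raised by Python itself, before any tracing happens: `**data` spreads the dict keys into keyword arguments of build_graph, which has no parameter named img. A minimal sketch with a stand-in function shows the mechanism. The parameter names model, args, and input_names are assumed to mirror hl.build_graph (check your installed version), and the dict contents are placeholders.

```python
def build_graph(model=None, args=None, input_names=None):
    """Stand-in only; parameter names are assumed to mirror hl.build_graph."""
    return model, args


# Placeholder values standing in for the real model inputs.
data = {"img": "image tensor", "img_meta": "metadata"}

# Unpacking the dict turns each key into a keyword argument,
# and the function has no parameter called "img".
try:
    build_graph("model", **data)
    error = None
except TypeError as exc:
    error = str(exc)

# Passing the inputs through the `args` parameter avoids the TypeError.
result = build_graph("model", args=(data["img"],))
```

Getting past this error only changes how the inputs reach build_graph. Whether the tracing backend can then handle them is a separate question.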
Yeah. I also think a model graph tool that supports dict input would be very helpful, especially for this hl module, because it really helps to understand and present large networks visually. Looking forward to your contribution. Thanks!
Maybe you can try something like the following:
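from torchviz import make_dot

# getBuff is a project-specific helper that returns the model's input dict
test_buff = getBuff(0)
graph = make_dot(net(test_buff), params=dict(net.named_parameters()))
graph.format = 'pdf'
graph.render("visPDF")  # renders to visPDF.pdf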
This plotting interface seems to need only the output of a forward pass through the network, so I'm guessing it may be helpful to you. BTW, credit to the AlexNet example at https://github.com/szagoruyko/pytorchviz/blob/master/examples.ipynb.
Thanks @zhangjing1997 !
What is the point of getBuff(0)? Which module defines this function?
I assume it is just a placeholder, so I changed it to: graph = make_dot(model(torch.zeros([1, 3, 224, 224])), params=dict(model.named_parameters()))
But it still complains about the missing dictionary arguments.
In my project, getBuff(0) is just a function that returns a dict to use as the network input. I think your code should run as written. But your model input seems to be a torch tensor, not a dict. If that's the case, I guess you don't need the approach I mentioned, and maybe you could try something like this instead:
import torch
import torch.onnx

# Variable is deprecated since PyTorch 0.4; a plain tensor works as the dummy input
dummy_input = torch.randn(1, 3, 224, 224)
torch.onnx.export(model, dummy_input, "model.onnx")
Otherwise, if your model needs a dict input, can you show the specific error?
Hey, any progress on this? Anyone making a PR?
I got a runtime error when I tried to plot the graph of a model whose input x is actually a dict.
So I'm wondering whether this plotting tool can support dict input.