waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, Tensorflow, and Keras.
MIT License

Shape information is not plotted #65

Open andraspalffy opened 4 years ago

andraspalffy commented 4 years ago
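    def plot_network(self, model, hltrans, input_size=(1, 1, 36)):
        # Requires: import torch, torchvision, hiddenlayer as hl,
        # and import matplotlib.pyplot as plt.

        # Trace the model with a float64 dummy input; the model's
        # parameters must therefore be float64 as well.
        graph = hl.build_graph(model, torch.zeros(input_size).double())
        # model = torchvision.models.vgg16()
        # graph = hl.build_graph(model, torch.zeros([1, 3, 224, 224]))

        # Render the graph to a PNG file and load it as an image array.
        dot = graph.build_dot()
        dot.format = "png"
        im = dot.render(cleanup=True)
        net_img = plt.imread(im)
        return net_img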

I have a custom PyTorch model that I plot with the function above. With my model, no shape information appears on the edges connecting the blocks. If I uncomment the vgg16 lines, overwriting the model and the graph, the plot does show the shape information.

Can you help me pinpoint the difference? Here are the two output images (cropped).

[Images: the two cropped output graphs]
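One difference worth checking (an assumption based on hiddenlayer's PyTorch builder, `pytorch_builder.py`, so verify it against the installed copy): its shape helper appears to read each edge's shape by matching the node's printed output type against a pattern along the lines of `Float\(([\d\s\,]+)\)`. The function above traces with `.double()`, so the nodes print as `Double(...)`, which such a pattern never matches. Newer PyTorch releases also print `strides=[...]` and other fields inside the parentheses, which breaks the same pattern even for float32 models. The sketch below reproduces that failure on hand-written type strings and shows a more permissive parser. `ORIGINAL_RE` and `get_shape_from_type_str` are hypothetical names, and the list of scalar type names is an assumption.

```python
import re

# Pattern along the lines of the one hiddenlayer is assumed to use:
# it only accepts "Float(" followed by digits, spaces and commas.
ORIGINAL_RE = re.compile(r".*Float\(([\d\s\,]+)\).*")

# Hypothetical, more permissive pattern: any common scalar type name,
# capturing everything inside the parentheses.
_TYPE_RE = re.compile(
    r"\b(?:Float|Double|Half|BFloat16|Long|Int|Short|Char|Byte|Bool)"
    r"\(([^)]*)\)"
)


def get_shape_from_type_str(text):
    """Return the leading sizes of the first tensor type in `text`, or None."""
    m = _TYPE_RE.search(text)
    if m is None:
        return None
    dims = []
    for field in m.group(1).split(","):
        field = field.strip()
        if not field:
            continue
        head = field.split(":")[0]  # tolerate a "size:stride" form, if printed
        if not head.isdigit():
            break  # keyword fields such as strides=... follow the sizes
        dims.append(int(head))
    return tuple(dims) if dims else None


double_node = "%5 : Double(1, 4, 34) = aten::relu(%4)"
strided_node = ("%5 : Float(1, 4, 34, strides=[136, 34, 1], "
                "requires_grad=0, device=cpu) = aten::relu(%4)")

print(ORIGINAL_RE.match(double_node))            # None
print(ORIGINAL_RE.match(strided_node))           # None
print(get_shape_from_type_str(double_node))      # (1, 4, 34)
print(get_shape_from_type_str(strided_node))     # (1, 4, 34)
```

If this is the cause, keeping the model in float32 (dropping `.double()` here) may bring the labels back on PyTorch versions that do not print strides; otherwise the shape helper in the installed hiddenlayer would need a similar change.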

peterqtr11 commented 3 years ago

@paland3: have you solved this problem? If so, could you please share how you solved it? @waleedka, can you please help us?

HaoLuo2627 commented 3 years ago

This solution works for me. I hope this helps.