szagoruyko / pytorchviz

A small package to create visualizations of PyTorch execution graphs
MIT License
3.18k stars, 279 forks

Plotting model #74

Open dannyhow12 opened 2 years ago

dannyhow12 commented 2 years ago

Hi! Thanks for the great work!

I am planning to visualize my model, which is a UNet:

model = UNet(
    spatial_dims=3,
    in_channels=1,
    out_channels=3,
    channels=(16, 32, 64, 128, 256),
    strides=(2, 2, 2, 2),
    num_res_units=2,
    norm=Norm.BATCH,
)

make_dot(model.mean(), params=dict(model.named_parameters()))

This raises an error saying that UNet has no attribute mean. Could you clarify what is going wrong here, and whether I defined the call correctly?

Regards, Danny

nihirv commented 2 years ago

make_dot plots the autograd graph behind a tensor, not a module. You need to call your model on some input first, and then pass the output tensor of your model to make_dot.
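Here is a minimal sketch of why the original call fails. It assumes only PyTorch is installed and uses a single nn.Conv3d as a hypothetical stand-in for the UNet (any nn.Module behaves the same way here). A module has no mean attribute and no autograd history. The tensor it returns does carry a grad_fn, which is the node make_dot walks backward from.

```python
import torch
import torch.nn as nn

# Hypothetical stand-in for the UNet; any nn.Module behaves the same way here.
model = nn.Conv3d(in_channels=1, out_channels=3, kernel_size=3, padding=1)

# A module is not a tensor: it has no .mean() and no autograd history.
print(hasattr(model, "mean"))     # False
print(hasattr(model, "grad_fn"))  # False

# Calling the module on an input returns a tensor that records its autograd graph.
x = torch.randn(1, 1, 8, 8, 8)  # (batch, channels, depth, height, width)
output = model(x)
print(output.grad_fn is not None)  # True
print(output.mean().grad_fn)       # the autograd node make_dot would start from
```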

Not sure what your UNet implementation looks like, but the following might work.

e.g.

model = UNet(...)
output = model(image)
make_dot(output, params=dict(model.named_parameters()))

If the above throws an error, try passing output.mean() instead of output to make_dot.