mert-kurttutan / torchview

torchview: visualize pytorch models
https://torchview.dev
MIT License
793 stars, 36 forks

Possible to pass two tensors to torchview #90

Open turian opened 1 year ago

turian commented 1 year ago

Is your feature request related to a problem? Please describe.

I'm working with a model that requires two tensors as input. The second tensor is of dtype int.

Describe the solution you'd like

I'd like to be able to, optionally, specify multiple tensor shapes and their types.

Describe alternatives you've considered

I tried mocking the second (int) tensor inside forward when only one tensor is passed. However, I get the traceback shown below:

Screenshots / Text

Traceback (most recent call last):
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/torchview.py", line 256, in forward_prop
    _ = model.to(device)(*x, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/recorder_tensor.py", line 146, in _module_forward_wrapper
    out = _orig_module_forward(mod, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/diffwave-sashimi-sourcesep/models/sashimi.py", line 337, in forward
    x = layer(x, diffusion_step_embed, mel_spec=mel_spec)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/recorder_tensor.py", line 146, in _module_forward_wrapper
    out = _orig_module_forward(mod, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/diffwave-sashimi-sourcesep/models/sashimi.py", line 174, in forward
    y, _ = self.layer(y)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/recorder_tensor.py", line 146, in _module_forward_wrapper
    out = _orig_module_forward(mod, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/diffwave-sashimi-sourcesep/models/s4.py", line 1562, in forward
    k, k_state = self.kernel(
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/recorder_tensor.py", line 110, in _module_forward_wrapper
    return _orig_module_forward(mod, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/diffwave-sashimi-sourcesep/models/s4.py", line 1392, in forward
    return self.kernel(state=state, L=L, rate=rate)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torchview/recorder_tensor.py", line 110, in _module_forward_wrapper
    return _orig_module_forward(mod, *args, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1194, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/ubuntu/diffwave-sashimi-sourcesep/models/s4.py", line 789, in forward
    self._setup_C(self.l_max)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/autograd/grad_mode.py", line 27, in decorate_context
    return func(*args, **kwargs)
  File "/home/ubuntu/diffwave-sashimi-sourcesep/models/s4.py", line 636, in _setup_C
    prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
  File "/usr/lib/python3/dist-packages/opt_einsum/contract.py", line 507, in contract
    return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
  File "/usr/lib/python3/dist-packages/opt_einsum/contract.py", line 591, in _core_contract
    new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
  File "/usr/lib/python3/dist-packages/opt_einsum/sharing.py", line 151, in cached_einsum
    return einsum(*args, **kwargs)
  File "/usr/lib/python3/dist-packages/opt_einsum/contract.py", line 353, in _einsum
    return fn(einsum_str, *operands, **kwargs)
  File "<__array_function__ internals>", line 5, in einsum
  File "/usr/lib/python3/dist-packages/numpy/core/einsumfunc.py", line 1356, in einsum
    return c_einsum(*operands, **kwargs)
  File "/home/ubuntu/.local/lib/python3.8/site-packages/torch/_tensor.py", line 956, in __array__
    return self.numpy()
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

Additional context

I am trying to visualize Sashimi.

snimu commented 1 year ago

It would be helpful if you could provide your call to draw_graph. For example, I've gotten it to work on nn.MultiheadAttention, which also requires multiple inputs, in the following ways:

import torch
from torch import nn
from torchview import draw_graph

model = nn.MultiheadAttention(8, 8)

# Method 1
x = torch.randn(8, 8)
draw_graph(model, input_data=(x, x, x))

# Method 2
s = torch.Size([8, 8])
draw_graph(model, input_size=(s, s, s))

If this doesn't work for you, then I don't know the answer, unfortunately :)

turian commented 1 year ago

@mert-kurttutan sure! Check out this colab:

https://colab.research.google.com/drive/1FFcrCrjLw7UkFDyB9P954AKxpghRkjOR?usp=sharing

I suspect the issue may be the use of the opt_einsum library. That Colab lets you run and inspect it easily.

mert-kurttutan commented 1 year ago

As @snimu suggested, you can pass the tensors explicitly; they will be expanded into positional arguments (in your case, two tensors). But there is one caveat if your forward function requires an iterable of tensors. Say the signature of your forward function is as follows:

from typing import Tuple
import torch

def forward(self, tensor_tuple: Tuple[torch.Tensor, torch.Tensor]):
    ...
    return out

In that case, you can do the following:

input_data = [(torch.rand(1, 1, 16000), torch.rand(1, 1, 2))]
model_graph = draw_graph(model, input_data=input_data)

Caveat: the input_data tuple is wrapped inside a list. This is necessary because, during pre-processing, input_data is expanded into positional arguments if it is expandable. The extra list layer on top of the tuple prevents the tuple itself from being expanded into individual tensors.

More generally, any type of iterable should work, not just a list.
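To make the expansion rule concrete, here is a small pure-Python sketch. call_like_forward_prop is a hypothetical helper modeled on the branches of torchview's forward_prop that are visible in the traceback below (lists and tuples are unpacked with *x; the body of the Mapping branch is not shown there, so unpacking it with **x is an assumption). It is not torchview's actual code, and strings stand in for tensors so it runs without PyTorch.

```python
from collections.abc import Mapping


def call_like_forward_prop(model, x, **kwargs):
    """Hypothetical mimic of forward_prop's input dispatch (not torchview code)."""
    if isinstance(x, (list, tuple)):
        # Lists/tuples become positional arguments.
        return model(*x, **kwargs)
    elif isinstance(x, Mapping):
        # Assumed: mappings become keyword arguments.
        return model(**x, **kwargs)
    raise ValueError("Unknown input type")


def forward(tensor_tuple):
    # Stand-in for a forward() that expects ONE argument: a tuple of tensors.
    return tensor_tuple


audio, cond = "audio", "cond"

# Wrapped in a list: the outer list is unpacked, the inner tuple survives intact.
print(call_like_forward_prop(forward, [(audio, cond)]))  # ('audio', 'cond')

# Bare tuple: unpacked into two positional arguments, which forward() rejects.
try:
    call_like_forward_prop(forward, (audio, cond))
except TypeError as err:
    print("TypeError:", err)
```

The outer container is consumed by the unpacking step, which is why it has to be added on top of the tuple the model actually expects.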

turian commented 1 year ago

Yes, understood.

I have actually adapted the forward method to optionally take only one input (with the understanding that I could follow @snimu's approach). I am now getting a different error, which I believe is due to opt_einsum. Should I open a new issue?

---------------------------------------------------------------------------
TypeError                                 Traceback (most recent call last)
/usr/local/lib/python3.9/dist-packages/torchview/torchview.py in forward_prop(model, x, device, model_graph, mode, **kwargs)
    255                 if isinstance(x, (list, tuple)):
--> 256                     _ = model.to(device)(*x, **kwargs)
    257                 elif isinstance(x, Mapping):

26 frames
/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
    145         # this seems not to be necessary so far
--> 146         out = _orig_module_forward(mod, *args, **kwargs)
    147 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used

/content/diffwave-sashimi/models/sashimi.py in forward(self, input_data, mel_spec)
    303             outputs.append(x)
--> 304             x = layer(x, diffusion_step_embed, mel_spec=mel_spec)
    305 

/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
    145         # this seems not to be necessary so far
--> 146         out = _orig_module_forward(mod, *args, **kwargs)
    147 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used

/content/diffwave-sashimi/models/sashimi.py in forward(self, x, diffusion_step_embed, mel_spec)
    156         # dilated conv layer
--> 157         y, _ = self.layer(y)
    158 

/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
    145         # this seems not to be necessary so far
--> 146         out = _orig_module_forward(mod, *args, **kwargs)
    147 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used

/content/diffwave-sashimi/models/s4.py in forward(self, u, rate, state, **kwargs)
   1387         L_kernel = L if self.L is None else min(L, round(self.L / rate))
-> 1388         k, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)
   1389 

/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
    109         if not input_nodes:
--> 110             return _orig_module_forward(mod, *args, **kwargs)
    111 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used

/content/diffwave-sashimi/models/s4.py in forward(self, state, L, rate)
   1237     def forward(self, state=None, L=None, rate=None):
-> 1238         return self.kernel(state=state, L=L, rate=rate)
   1239 

/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
    109         if not input_nodes:
--> 110             return _orig_module_forward(mod, *args, **kwargs)
    111 

/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
   1193                 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194             return forward_call(*input, **kwargs)
   1195         # Do not call functions when jit is used

/content/diffwave-sashimi/models/s4.py in forward(self, state, rate, L)
    686         if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
--> 687             self._setup_C(self.l_max)
    688 

/usr/local/lib/python3.9/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
     26             with self.clone():
---> 27                 return func(*args, **kwargs)
     28         return cast(F, decorate_context)

/content/diffwave-sashimi/models/s4.py in _setup_C(self, L)
    544         C_ = _conj(C)
--> 545         prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
    546         if double_length: prod = -prod # Multiply by I + dA_L instead

/usr/local/lib/python3.9/dist-packages/opt_einsum/contract.py in contract(*operands, **kwargs)
    506 
--> 507     return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
    508 

/usr/local/lib/python3.9/dist-packages/opt_einsum/contract.py in _core_contract(operands, contraction_list, backend, evaluate_constants, **einsum_kwargs)
    590             # Do the contraction
--> 591             new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
    592 

/usr/local/lib/python3.9/dist-packages/opt_einsum/sharing.py in cached_einsum(*args, **kwargs)
    150         if not currently_sharing():
--> 151             return einsum(*args, **kwargs)
    152 

/usr/local/lib/python3.9/dist-packages/opt_einsum/contract.py in _einsum(*operands, **kwargs)
    352 
--> 353     return fn(einsum_str, *operands, **kwargs)
    354 

/usr/local/lib/python3.9/dist-packages/numpy/core/overrides.py in einsum(*args, **kwargs)

/usr/local/lib/python3.9/dist-packages/numpy/core/einsumfunc.py in einsum(out, optimize, *operands, **kwargs)
   1358             kwargs['out'] = out
-> 1359         return c_einsum(*operands, **kwargs)
   1360 

/usr/local/lib/python3.9/dist-packages/torch/_tensor.py in __array__(self, dtype)
    955         if dtype is None:
--> 956             return self.numpy()
    957         else:

TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.

The above exception was the direct cause of the following exception:

RuntimeError                              Traceback (most recent call last)
<ipython-input-10-b17eb2cc3c24> in <cell line: 13>()
     11 input_size = (1, 1, 16000)
     12 # Depth 2?
---> 13 model_graph = draw_graph(
     14     net,
     15     input_size=input_size,

/usr/local/lib/python3.9/dist-packages/torchview/torchview.py in draw_graph(model, input_data, input_size, graph_name, depth, device, dtypes, mode, strict, expand_nested, graph_dir, hide_module_functions, hide_inner_tensors, roll, show_shapes, save_graph, filename, directory, **kwargs)
    218     )
    219 
--> 220     forward_prop(
    221         model, input_recorder_tensor, device, model_graph,
    222         model_mode, **kwargs_record_tensor

/usr/local/lib/python3.9/dist-packages/torchview/torchview.py in forward_prop(model, x, device, model_graph, mode, **kwargs)
    262                     raise ValueError("Unknown input type")
    263     except Exception as e:
--> 264         raise RuntimeError(
    265             "Failed to run torchgraph see error message"
    266         ) from e

RuntimeError: Failed to run torchgraph see error message

mert-kurttutan commented 1 year ago

Can you also link the code with the updated input_data version, so I can play with it?

Edit: I see you already did that with the most recent Colab link. Sorry!

mert-kurttutan commented 1 year ago

I'm not sure the way you do it in the Colab notebook is correct. I get the warning WARNING: sashimi input_data is not a tuple!

But the input should be a tuple, right? My guess is that if you pass it the same way as in the Colab notebook above, torchview expands the tuple into positional arguments, which does NOT match Sashimi's input signature, IMO.
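One way to check this without running the model is to ask Python how each form of input would bind to the forward signature. The sketch below is an illustration only, not part of torchview: sashimi_forward is a hypothetical stand-in mirroring forward(self, input_data, mel_spec) from the traceback (self is dropped, and the mel_spec=None default is assumed), bound_args is a hypothetical helper, and strings stand in for tensors.

```python
import inspect


def sashimi_forward(input_data, mel_spec=None):
    """Hypothetical stand-in for Sashimi's forward; the mel_spec default is assumed."""
    ...


def bound_args(fn, args=(), kwargs=None):
    """Show how positional args/kwargs would bind to fn's parameters, without calling fn."""
    return dict(inspect.signature(fn).bind(*args, **(kwargs or {})).arguments)


audio, steps = "audio", "steps"

# Bare tuple, unpacked positionally: the second tensor silently lands in mel_spec.
print(bound_args(sashimi_forward, (audio, steps)))
# {'input_data': 'audio', 'mel_spec': 'steps'}

# Tuple wrapped in a list, then unpacked: input_data receives the whole tuple.
print(bound_args(sashimi_forward, [(audio, steps)]))
# {'input_data': ('audio', 'steps')}
```

The first case binds without any error, which would be consistent with seeing the "input_data is not a tuple" warning rather than an immediate crash.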

My suggestion is still to follow the approach in my previous comment above: https://github.com/mert-kurttutan/torchview/issues/90#issuecomment-1497007301