turian opened this issue 1 year ago
It would be helpful if you could provide your call to draw_graph. For example, I've gotten it to work on nn.MultiheadAttention, which also requires multiple inputs, in the following ways:
import torch
from torch import nn
from torchview import draw_graph
model = nn.MultiheadAttention(8, 8)
# Method 1: pass example input tensors (query, key, value for MultiheadAttention)
x = torch.randn(8, 8)
draw_graph(model, input_data=(x, x, x))
# Method 2: pass only the input sizes and let torchview generate the tensors
s = torch.Size([8, 8])
draw_graph(model, input_size=(s, s, s))
If this doesn't work for you, then I don't know the answer, unfortunately :)
@mert-kurttutan sure! Check out this colab:
https://colab.research.google.com/drive/1FFcrCrjLw7UkFDyB9P954AKxpghRkjOR?usp=sharing
I think possibly the issue is the use of the opt_einsum library? That colab allows you to easily run and inspect.
As @snimu suggested, you could pass the explicit tensors, which will be expanded into positional arguments (in this case, two tensors). But there is one caveat if your forward function requires an iterable of tensors. Say the signature of your model's forward function is as follows:
def forward(self, tensor_tuple: Tuple[torch.Tensor]):
...
return out
What you could do is the following:
input_data = [(torch.rand(1, 1,16000), torch.rand(1, 1, 2))]
model_graph = draw_graph(model, input_data=input_data)
Caveat: the input_data tuple is wrapped inside a list. This is necessary because, during pre-processing, input_data is expanded into positional arguments if it is expandable; the extra list on top of the tuple prevents the tuple itself from being expanded into individual tensors.
More generally, any type of iterable should work, not just a list.
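To make the expansion behavior concrete, here is a minimal, self-contained sketch. TupleModel is a hypothetical stand-in for the real model, and the shapes are only illustrative:

import torch
from torch import nn
from typing import Tuple
from torchview import draw_graph

# Hypothetical toy model whose forward takes a single tuple of tensors,
# mirroring the signature discussed above.
class TupleModel(nn.Module):
    def forward(self, tensor_tuple: Tuple[torch.Tensor, ...]):
        a, b = tensor_tuple
        return a.mean() + b.mean()

model = TupleModel()
t1, t2 = torch.rand(1, 1, 16000), torch.rand(1, 1, 2)

# Without the extra list, the tuple is expanded into two positional
# arguments, i.e. forward(t1, t2), which does not match the signature:
# draw_graph(model, input_data=(t1, t2))  # fails

# Wrapped in a list, it is the list that gets expanded, so forward
# receives the tuple intact, i.e. forward((t1, t2)):
model_graph = draw_graph(model, input_data=[(t1, t2)])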
Yes, understood.
I have actually adapted the forward method to optionally take only one input (with the understanding that I could follow @snimu's approach). I am now getting a different error, which I believe is due to opt_einsum. Should I open a new issue?
---------------------------------------------------------------------------
TypeError Traceback (most recent call last)
/usr/local/lib/python3.9/dist-packages/torchview/torchview.py in forward_prop(model, x, device, model_graph, mode, **kwargs)
255 if isinstance(x, (list, tuple)):
--> 256 _ = model.to(device)(*x, **kwargs)
257 elif isinstance(x, Mapping):
26 frames
/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
145 # this seems not to be necessary so far
--> 146 out = _orig_module_forward(mod, *args, **kwargs)
147
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
/content/diffwave-sashimi/models/sashimi.py in forward(self, input_data, mel_spec)
303 outputs.append(x)
--> 304 x = layer(x, diffusion_step_embed, mel_spec=mel_spec)
305
/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
145 # this seems not to be necessary so far
--> 146 out = _orig_module_forward(mod, *args, **kwargs)
147
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
/content/diffwave-sashimi/models/sashimi.py in forward(self, x, diffusion_step_embed, mel_spec)
156 # dilated conv layer
--> 157 y, _ = self.layer(y)
158
/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
145 # this seems not to be necessary so far
--> 146 out = _orig_module_forward(mod, *args, **kwargs)
147
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
/content/diffwave-sashimi/models/s4.py in forward(self, u, rate, state, **kwargs)
1387 L_kernel = L if self.L is None else min(L, round(self.L / rate))
-> 1388 k, k_state = self.kernel(L=L_kernel, rate=rate, state=state) # (C H L) (B C H L)
1389
/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
109 if not input_nodes:
--> 110 return _orig_module_forward(mod, *args, **kwargs)
111
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
/content/diffwave-sashimi/models/s4.py in forward(self, state, L, rate)
1237 def forward(self, state=None, L=None, rate=None):
-> 1238 return self.kernel(state=state, L=L, rate=rate)
1239
/usr/local/lib/python3.9/dist-packages/torchview/recorder_tensor.py in _module_forward_wrapper(mod, *args, **kwargs)
109 if not input_nodes:
--> 110 return _orig_module_forward(mod, *args, **kwargs)
111
/usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
1193 or _global_forward_hooks or _global_forward_pre_hooks):
-> 1194 return forward_call(*input, **kwargs)
1195 # Do not call functions when jit is used
/content/diffwave-sashimi/models/s4.py in forward(self, state, rate, L)
686 if self.L.item() == 0 and self.l_max is not None and self.l_max > 0:
--> 687 self._setup_C(self.l_max)
688
/usr/local/lib/python3.9/dist-packages/torch/autograd/grad_mode.py in decorate_context(*args, **kwargs)
26 with self.clone():
---> 27 return func(*args, **kwargs)
28 return cast(F, decorate_context)
/content/diffwave-sashimi/models/s4.py in _setup_C(self, L)
544 C_ = _conj(C)
--> 545 prod = contract("h m n, c h n -> c h m", dA_L.transpose(-1, -2), C_)
546 if double_length: prod = -prod # Multiply by I + dA_L instead
/usr/local/lib/python3.9/dist-packages/opt_einsum/contract.py in contract(*operands, **kwargs)
506
--> 507 return _core_contract(operands, contraction_list, backend=backend, **einsum_kwargs)
508
/usr/local/lib/python3.9/dist-packages/opt_einsum/contract.py in _core_contract(operands, contraction_list, backend, evaluate_constants, **einsum_kwargs)
590 # Do the contraction
--> 591 new_view = _einsum(einsum_str, *tmp_operands, backend=backend, **einsum_kwargs)
592
/usr/local/lib/python3.9/dist-packages/opt_einsum/sharing.py in cached_einsum(*args, **kwargs)
150 if not currently_sharing():
--> 151 return einsum(*args, **kwargs)
152
/usr/local/lib/python3.9/dist-packages/opt_einsum/contract.py in _einsum(*operands, **kwargs)
352
--> 353 return fn(einsum_str, *operands, **kwargs)
354
/usr/local/lib/python3.9/dist-packages/numpy/core/overrides.py in einsum(*args, **kwargs)
/usr/local/lib/python3.9/dist-packages/numpy/core/einsumfunc.py in einsum(out, optimize, *operands, **kwargs)
1358 kwargs['out'] = out
-> 1359 return c_einsum(*operands, **kwargs)
1360
/usr/local/lib/python3.9/dist-packages/torch/_tensor.py in __array__(self, dtype)
955 if dtype is None:
--> 956 return self.numpy()
957 else:
TypeError: can't convert cuda:0 device type tensor to numpy. Use Tensor.cpu() to copy the tensor to host memory first.
The above exception was the direct cause of the following exception:
RuntimeError Traceback (most recent call last)
<ipython-input-10-b17eb2cc3c24> in <cell line: 13>()
11 input_size = (1, 1, 16000)
12 # Depth 2?
---> 13 model_graph = draw_graph(
14 net,
15 input_size=input_size,
/usr/local/lib/python3.9/dist-packages/torchview/torchview.py in draw_graph(model, input_data, input_size, graph_name, depth, device, dtypes, mode, strict, expand_nested, graph_dir, hide_module_functions, hide_inner_tensors, roll, show_shapes, save_graph, filename, directory, **kwargs)
218 )
219
--> 220 forward_prop(
221 model, input_recorder_tensor, device, model_graph,
222 model_mode, **kwargs_record_tensor
/usr/local/lib/python3.9/dist-packages/torchview/torchview.py in forward_prop(model, x, device, model_graph, mode, **kwargs)
262 raise ValueError("Unknown input type")
263 except Exception as e:
--> 264 raise RuntimeError(
265 "Failed to run torchgraph see error message"
266 ) from e
RuntimeError: Failed to run torchgraph see error message
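One workaround that may be worth trying (an assumption on my part, not verified against Sashimi): the inner TypeError shows opt_einsum dispatching contract to the numpy backend, and numpy cannot read CUDA tensors. Tracing on the CPU via draw_graph's device argument might at least sidestep that conversion:

# Hedged sketch: same call as above, but forcing the trace onto the CPU so
# that any numpy fallback inside opt_einsum can actually convert the tensors.
model_graph = draw_graph(
    net,                        # the Sashimi model from the colab
    input_size=(1, 1, 16000),
    device='cpu',
)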
Can you also link the code with the updated input_data version? So I can play with it.
Edit: I guess you already did that with the most recent colab link. Sorry
I am not really sure if the way you do it in the colab link is correct. I am getting the warning WARNING: sashimi input_data is not a tuple!
But the input should be a tuple, right? My guess is that if you are doing the same as in the colab link above, it expands the tuple into positional arguments in the input signature, which is NOT correct for the Sashimi input signature, IMO.
My suggestion is still to use my previous comment above: https://github.com/mert-kurttutan/torchview/issues/90#issuecomment-1497007301
Is your feature request related to a problem? Please describe.
I'm working with a model that requires two tensors as input. The second tensor is of dtype int.
Describe the solution you'd like
I'd like to be able to, optionally, specify multiple tensor shapes and their types.
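For what it's worth, the draw_graph signature visible in the traceback above already accepts a dtypes argument alongside input_size, so something like the following might already cover this. This is a hedged sketch: the shapes are illustrative, and I haven't confirmed the exact semantics of dtypes:

import torch
from torchview import draw_graph

# Hypothetical example: two inputs, the second with an integer dtype.
model_graph = draw_graph(
    model,
    input_size=[(1, 1, 16000), (1, 1, 2)],  # one shape per input tensor
    dtypes=[torch.float32, torch.int64],    # per-input dtypes
)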
Describe alternatives you've considered
I tried mocking the second int tensor in forward when it's passed only one tensor. However, I get the traceback shown below:
Screenshots / Text
Additional context
I am trying to visualize Sashimi