If my forward function takes a dict argument, what can I do?
For example:
import torch
import torch.nn as nn
import hiddenlayer as hl
import torch.jit as jit
class TestNet(nn.Module):
    """Minimal module whose ``forward`` takes a dict as its third argument.

    The module holds a single 3-in/3-out convolution with kernel size 1 and
    stride 1, plus an in-place ReLU that ``forward`` never applies.
    """

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1, 1)
        # Registered on the module but not used by forward().
        self.relu = nn.ReLU(True)

    def forward(self, x, y, loss):
        """Apply the convolution to ``x`` when ``loss['flag']`` is truthy.

        Args:
            x: Input passed to ``self.conv``.
            y: Accepted but never read.
            loss: Mapping that must contain the key ``'flag'``.

        Returns:
            The result of ``self.conv(x)``.

        Raises:
            AssertionError: If ``loss['flag']`` is falsy.
            KeyError: If ``loss`` has no ``'flag'`` entry.
        """
        if not loss['flag']:
            # Raised explicitly rather than via ``assert False`` so that the
            # check is not stripped when Python runs with -O.
            raise AssertionError("TestNet.forward requires loss['flag'] to be truthy")
        return self.conv(x)
# Reproduction: one random tensor is passed as both the `x` and `y` inputs.
x = torch.randn(1,3,224,224)
net = TestNet()
# The third element of the input tuple is a plain dict; this is the call that
# the traceback below reports as failing with "... but got dict".
graph = hl.build_graph(net, (x, x, {'flag': True}))
Running it fails with the following error:
Traceback (most recent call last):
File "/home/gaozhihua/program/mmdetection/ignore_dir/1.py", line 27, in <module>
graph = hl.build_graph(net, (x, x, {'flag': True}))
File "/home/gaozhihua/program/hiddenlayer/hiddenlayer/graph.py", line 143, in build_graph
import_graph(g, model, args)
File "/home/gaozhihua/program/hiddenlayer/hiddenlayer/pytorch_builder.py", line 70, in import_graph
trace, out = torch.jit.get_trace_graph(model, args)
File "/home/gaozhihua/anaconda2/envs/open-mmlab/lib/python3.6/site-packages/torch/jit/__init__.py", line 196, in get_trace_graph
return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
File "/home/gaozhihua/anaconda2/envs/open-mmlab/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
result = self.forward(*input, **kwargs)
File "/home/gaozhihua/anaconda2/envs/open-mmlab/lib/python3.6/site-packages/torch/jit/__init__.py", line 242, in forward
in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got dict
If my forward function takes a dict argument, what can I do? For example:
Running it fails with the following error:
What should I do? @waleedka @FerumFlex @ss18