waleedka / hiddenlayer

Neural network graphs and training metrics for PyTorch, Tensorflow, and Keras.
MIT License

Arguments Problem #29

Closed: ZhihuaGao closed this issue 5 years ago

ZhihuaGao commented 5 years ago

If my forward function takes a dict argument, what can I do? For example:

import torch
import torch.nn as nn
import hiddenlayer as hl
import torch.jit as jit

class TestNet(nn.Module):

    def __init__(self):
        super(TestNet, self).__init__()
        self.conv = nn.Conv2d(3, 3, 1, 1)
        self.relu = nn.ReLU(True)

    def forward(self, x, y, loss):
        if loss['flag']:
            x = self.conv(x)
        else:
            assert False
        return x


x = torch.randn(1, 3, 224, 224)
net = TestNet()
graph = hl.build_graph(net, (x, x, {'flag': True}))

Running it fails with the following error:

Traceback (most recent call last):
  File "/home/gaozhihua/program/mmdetection/ignore_dir/1.py", line 27, in <module>
    graph = hl.build_graph(net, (x, x, {'flag': True}))
  File "/home/gaozhihua/program/hiddenlayer/hiddenlayer/graph.py", line 143, in build_graph
    import_graph(g, model, args)
  File "/home/gaozhihua/program/hiddenlayer/hiddenlayer/pytorch_builder.py", line 70, in import_graph
    trace, out = torch.jit.get_trace_graph(model, args)
  File "/home/gaozhihua/anaconda2/envs/open-mmlab/lib/python3.6/site-packages/torch/jit/__init__.py", line 196, in get_trace_graph
    return LegacyTracedModule(f, _force_outplace)(*args, **kwargs)
  File "/home/gaozhihua/anaconda2/envs/open-mmlab/lib/python3.6/site-packages/torch/nn/modules/module.py", line 489, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/gaozhihua/anaconda2/envs/open-mmlab/lib/python3.6/site-packages/torch/jit/__init__.py", line 242, in forward
    in_vars, in_desc = _flatten(args)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs, but got dict
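The last line of the traceback states the rule: the tracer flattens the positional arguments and accepts only tuples, lists and tensors (Variables), so the dict in the third position is rejected. The sketch below mirrors that rule to locate the offending argument before calling `build_graph`. `find_unsupported_inputs` is a hypothetical helper that only reproduces the rule quoted in the error message, not PyTorch's internal `_flatten` code, and other PyTorch versions may accept more input types.

```python
import torch


def find_unsupported_inputs(args, path=()):
    """Hypothetical helper: return the index paths of values in `args` that the
    rule quoted in the error would reject (anything other than a tuple, list,
    or tensor)."""
    if isinstance(args, (tuple, list)):
        bad = []
        for i, item in enumerate(args):
            # Recurse into nested containers, extending the index path.
            bad.extend(find_unsupported_inputs(item, path + (i,)))
        return bad
    if torch.is_tensor(args):
        return []
    # Any other type (dict, int, str, ...) would be rejected by the tracer.
    return [path]


x = torch.randn(1, 3, 224, 224)
print(find_unsupported_inputs((x, x, {'flag': True})))  # [(2,)]
```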

What should I do? @waleedka @FerumFlex @ss18

waleedka commented 5 years ago

I'm afraid that's not supported yet. Sorry!
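A common workaround, sketched here on the assumption that the dict only carries configuration (no tensors), is to wrap the model in a small module that stores the dict as a plain attribute. That way `forward()` receives only tensors. `DictBoundWrapper` is a hypothetical name, and the example calls `torch.jit.trace` directly rather than hiddenlayer.

```python
import torch
import torch.nn as nn


class TestNet(nn.Module):
    """Same forward() as in the issue: the third argument is a dict."""

    def __init__(self):
        super().__init__()
        self.conv = nn.Conv2d(3, 3, 1, 1)

    def forward(self, x, y, loss):
        if loss['flag']:
            x = self.conv(x)
        else:
            assert False
        return x


class DictBoundWrapper(nn.Module):
    """Hypothetical wrapper: keeps the dict as a plain Python attribute so that
    forward() only takes tensor inputs, which the tracer can flatten."""

    def __init__(self, model, loss):
        super().__init__()
        self.model = model  # registered as a submodule
        self.loss = loss    # e.g. {'flag': True}; not passed as a traced input

    def forward(self, x, y):
        # Re-attach the dict when calling the original model.
        return self.model(x, y, self.loss)


x = torch.randn(1, 3, 224, 224)
wrapped = DictBoundWrapper(TestNet(), {'flag': True})
# The tracer now only sees the two tensor inputs.
traced = torch.jit.trace(wrapped, (x, x))
```

Because `if loss['flag']` is ordinary Python control flow, a trace records only the branch that was taken, so the flag ends up baked into the graph either way; binding it in the wrapper loses nothing. With hiddenlayer, the call would presumably become `hl.build_graph(wrapped, (x, x))`, though that has not been checked against the hiddenlayer version shown in the traceback.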