quark0 / darts

Differentiable architecture search for convolutional and recurrent networks
https://arxiv.org/abs/1806.09055
Apache License 2.0

How to convert to ONNX. #147

Closed crook52 closed 4 years ago

crook52 commented 4 years ago

Hi. I eventually want to convert a searched DARTS model to TFLite. As a first step, I tried to convert it to ONNX with the code below.

import torch
import torch.nn as nn
import genotypes
from model import NetworkCIFAR as Network

genotype = eval("genotypes.%s" % 'DARTS')
model = Network(36, 10, 20, True, genotype)
model.load_state_dict(torch.load('./weights.pt'))
model = model.cuda()

onnx_model_path = './darts_model.onnx'
dummy_input = torch.randn(8,3,32,32)
input_names = ['image_array']
output_names = ['category']
torch.onnx.export(model,dummy_input, onnx_model_path,
                  input_names=input_names, output_names=output_names)

However, the conversion failed. The error is below.

 Traceback (most recent call last):
  File "<stdin>", line 2, in <module>
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/__init__.py", line 168, in export
    custom_opsets, enable_onnx_checker, use_external_data_format)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 69, in export
    use_external_data_format=use_external_data_format)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 488, in _export
    fixed_batch_size=fixed_batch_size)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 334, in _model_to_graph
    graph, torch_out = _trace_and_get_graph_from_model(model, args, training)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/onnx/utils.py", line 291, in _trace_and_get_graph_from_model
    torch.jit._get_trace_graph(model, args, _force_outplace=False, _return_inputs_states=True)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/jit/__init__.py", line 278, in _get_trace_graph
    outs = ONNXTracedModule(f, _force_outplace, return_inputs, _return_inputs_states)(*args, **kwargs)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/nn/modules/module.py", line 550, in __call__
    result = self.forward(*input, **kwargs)
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/jit/__init__.py", line 361, in forward
    self._force_outplace,
  File "/home/XXXX/darts/cnn/eval-EXP-20200710-150423/venv/lib/python3.6/site-packages/torch/jit/__init__.py", line 351, in wrapper
    out_vars, _ = _flatten(outs)
RuntimeError: Only tuples, lists and Variables supported as JIT inputs/outputs. Dictionaries and strings are also accepted but their usage is not recommended. But got unsupported type NoneType

Could you tell me how to convert it? Thank you!

My environment

crook52 commented 4 years ago

This issue has been resolved on the PyTorch forum. Thank you!
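
The forum answer is not reproduced in this thread. One plausible reading of the traceback, assuming `NetworkCIFAR.forward` returns a `(logits, logits_aux)` pair in which `logits_aux` stays `None` outside training, is that the tracer rejects the `None` output. A minimal sketch of a workaround under that assumption wraps the model so that only the logits are returned. `ToyNetwork`, `LogitsOnly`, and `export_logits_only` are hypothetical names, and `ToyNetwork` is only a stand-in for the real model.

```python
import torch
import torch.nn as nn


class ToyNetwork(nn.Module):
    """Hypothetical stand-in that mimics the assumed (logits, logits_aux) output."""

    def __init__(self, num_classes=10):
        super().__init__()
        self.pool = nn.AdaptiveAvgPool2d(1)
        self.classifier = nn.Linear(3, num_classes)

    def forward(self, x):
        logits_aux = None  # assumption: stays None outside training
        out = self.pool(x)
        logits = self.classifier(out.view(out.size(0), -1))
        return logits, logits_aux


class LogitsOnly(nn.Module):
    """Wrapper that drops the auxiliary output so the traced graph has no None."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        logits, _ = self.model(x)
        return logits


def export_logits_only(model, onnx_model_path, batch_size=8):
    """Export `model` through the wrapper, with the input on the model's device."""
    wrapped = LogitsOnly(model).eval()
    device = next(model.parameters()).device
    dummy_input = torch.randn(batch_size, 3, 32, 32, device=device)
    torch.onnx.export(wrapped, dummy_input, onnx_model_path,
                      input_names=['image_array'], output_names=['category'])


# Check that the wrapped model yields a single tensor rather than a tuple with None.
wrapped = LogitsOnly(ToyNetwork()).eval()
with torch.no_grad():
    output = wrapped(torch.randn(8, 3, 32, 32))
print(type(output).__name__, tuple(output.shape))
```

With the real model, the call would look like `export_logits_only(model, './darts_model.onnx')` after loading the weights. Creating the dummy input on the model's device also avoids a CPU/GPU mismatch, since the snippet in the question moves the model to CUDA but keeps the input on the CPU.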

Sun2018421 commented 2 years ago

> This issue has been resolved in PyTorch forum. Thank you!

When I try to convert the model to ONNX, an error occurs:

  File "/home/huawei/sxf_workdir/darts/cnn/model.py", line 150, in forward
    s0, s1 = s1, cell(s0, s1, self.drop_path_prob)
  File "/home/huawei/anaconda3/envs/darts/lib/python3.7/site-packages/torch/nn/modules/module.py", line 594, in __getattr__
    type(self).__name__, name))
AttributeError: 'NetworkCIFAR' object has no attribute 'drop_path_prob'

Have you encountered this problem before? Thank you!
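
The traceback suggests that `forward` reads `self.drop_path_prob` and that the constructor never sets it, so the attribute has to be assigned on the model before the first forward pass. A minimal sketch, assuming drop path is only needed during training and can be disabled for export; `ToyNetwork` is a hypothetical stand-in that reads the attribute in the same way.

```python
import torch
import torch.nn as nn


class ToyNetwork(nn.Module):
    """Hypothetical stand-in: forward reads an attribute that __init__ never sets."""

    def __init__(self):
        super().__init__()
        self.classifier = nn.Linear(4, 2)

    def forward(self, x):
        # Read unconditionally, as the failing line in the traceback does.
        drop_prob = self.drop_path_prob
        if self.training and drop_prob > 0.0:
            x = nn.functional.dropout(x, p=drop_prob)
        return self.classifier(x)


model = ToyNetwork().eval()

# Without the attribute, the forward pass raises AttributeError.
try:
    model(torch.randn(1, 4))
    failed_without_attribute = False
except AttributeError:
    failed_without_attribute = True

# Assign the attribute before any forward pass; 0.0 disables drop path.
model.drop_path_prob = 0.0
with torch.no_grad():
    logits = model(torch.randn(1, 4))
print(failed_without_attribute, tuple(logits.shape))
```

For the real model, the equivalent step would be a single assignment such as `model.drop_path_prob = 0.0` placed after `load_state_dict` and before `torch.onnx.export`.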