megvii-research / RevCol

Official Code of Paper "Reversible Column Networks" "RevColv2"
Apache License 2.0

How to export an ONNX model with save_memory=True? #18

Open BearCooike opened 10 months ago

BearCooike commented 10 months ago

We are trying to convert RevCol to TensorRT format, but the intermediate ONNX export fails when save_memory=True. Here is our conversion test code:

import torch
from models.revcol import revcol_tiny

model = revcol_tiny(save_memory=True, inter_supv=False, drop_path=0.1, num_classes=10, kernel_size=3)

# Workaround: turn save_memory off on every subnet before exporting.
# Without this loop the export fails with the error shown below.
for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False

x = torch.zeros(1, 3, 224, 224)  # dummy input used for tracing
torch.onnx.export(
    model, x, './weights/revcol_tiny.onnx',
    verbose=False,
    opset_version=17,
    training=torch.onnx.TrainingMode.EVAL,
    do_constant_folding=True,
    input_names=['images'],
    output_names=['output'],
    dynamic_axes=None,
)

When save_memory=True is left on (that is, without the loop above), the export fails with the following error:

File d:\SoftWare\anaconda3\envs\torch\lib\site-packages\torch\onnx\utils.py:506, in export(model, args, f, export_params, verbose, training, input_names, output_names, operator_export_type, opset_version, do_constant_folding, dynamic_axes, keep_initializers_as_inputs, custom_opsets, export_modules_as_functions)
    188 @_beartype.beartype
    189 def export(
    190     model: Union[torch.nn.Module, torch.jit.ScriptModule, torch.jit.ScriptFunction],
   (...)
    206     export_modules_as_functions: Union[bool, Collection[Type[torch.nn.Module]]] = False,
    207 ) -> None:
    208     r"""Exports a model into ONNX format.
    209 
    210     If ``model`` is not a :class:`torch.jit.ScriptModule` nor a
   (...)
    503             All errors are subclasses of :class:`errors.OnnxExporterError`.
...
    511         '(vmap, grad, jvp, jacrev, ...), it must override the setup_context '
    512         'staticmethod. For more details, please see '
    513         'https://pytorch.org/docs/master/notes/extending.func.html')

RuntimeError: invalid unordered_map<K, T> key

If we add the following code, the export works, but then the exported model presumably cannot take advantage of the low memory footprint of the reversible network.

for i in range(model.num_subnet):
    getattr(model, f'subnet{i}').save_memory = False

Is there a way to export to ONNX while keeping save_memory=True?

nightsnack commented 10 months ago

Do you use ONNX for inference? You can set save_memory = False when converting the weights, then set save_memory = True again for later inference. The low memory footprint only benefits the training process.
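
A minimal sketch of that suggestion, written as a context manager that switches save_memory off for the duration of the export and restores the previous values afterwards. It relies only on the attribute layout shown earlier in this thread (`model.num_subnet` and `subnet0`, `subnet1`, ...). The helper name `save_memory_disabled` and the `demo` stand-in object are hypothetical and not part of the RevCol code; the stand-in is there only so the snippet runs without the repository.

```python
from contextlib import contextmanager
from types import SimpleNamespace


@contextmanager
def save_memory_disabled(model):
    """Temporarily set save_memory=False on every subnet, then restore it.

    Assumes the layout used in this thread: `model.num_subnet` subnets
    stored as attributes `subnet0`, `subnet1`, ...
    """
    subnets = [getattr(model, f"subnet{i}") for i in range(model.num_subnet)]
    previous = [s.save_memory for s in subnets]
    for s in subnets:
        s.save_memory = False
    try:
        yield model
    finally:
        # Restore the original flags even if the export raises.
        for s, flag in zip(subnets, previous):
            s.save_memory = flag


# Hypothetical stand-in for a RevCol model, used only for demonstration.
demo = SimpleNamespace(
    num_subnet=2,
    subnet0=SimpleNamespace(save_memory=True),
    subnet1=SimpleNamespace(save_memory=True),
)

with save_memory_disabled(demo):
    # With a real model, the torch.onnx.export(model, x, ...) call goes here.
    flags_during = [demo.subnet0.save_memory, demo.subnet1.save_memory]

flags_after = [demo.subnet0.save_memory, demo.subnet1.save_memory]
```

With the real model, the `torch.onnx.export(...)` call from the first post would sit inside the `with` block, so the flag is only off while the graph is traced.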