icandle / CAMixerSR

CAMixerSR: Only Details Need More “Attention” (CVPR 2024)
https://arxiv.org/abs/2402.19289
Apache License 2.0
225 stars 13 forks source link

onnx export #25

Open ws-zhongm opened 4 months ago

ws-zhongm commented 4 months ago

你好,我导出onnx时,发现耗时很久,

并且在使用netron可视化时,提示 "This large graph layout might take a very long time to complete." 看起来似乎是一张很大的图, 使用的代码如下:

import argparse
import os 
import torch
import archs.CAMixerSR_arch as arch

def main():
    """Export a pretrained CAMixerSR checkpoint to an ONNX file.

    Command-line arguments:
        --model_path: path to the ``.pth`` checkpoint; the weights are read
            from its ``'params_ema'`` entry.
        --output: destination ONNX file. When omitted, the checkpoint path
            with its extension replaced by ``.onnx`` is used.

    The model is traced with a single random ``1x3x64x64`` input on CUDA when
    available, otherwise on CPU.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        '--model_path',
        type=str,
        default='../../pretrained_models/LargeSR/CAMixerSR_S.pth'  # noqa: E501
    )
    parser.add_argument('--output', type=str, default=None, help='output ONNX model file')
    args = parser.parse_args()

    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

    # set up model
    model = arch.CAMixerSR(n_feats=36, scale=4)  # model-S:36 / model-M:48 / model-L:60

    # map_location remaps the stored tensors onto the selected device, so a
    # checkpoint saved from GPU tensors still loads on a CPU-only machine.
    # NOTE: torch.load unpickles the file; only load checkpoints you trust.
    checkpoint = torch.load(args.model_path, map_location=device)
    model.load_state_dict(checkpoint['params_ema'], strict=True)
    model.eval()
    model = model.to(device)

    if args.output is None:
        output_file = os.path.splitext(args.model_path)[0] + '.onnx'
    else:
        output_file = args.output

    dummy_input = torch.randn(1, 3, 64, 64).to(device)  # Adjust the size as needed

    # Export the model to ONNX.
    # No dynamic_axes are passed, so the exported graph is fixed to the shape
    # of dummy_input. To allow variable batch/height/width, pass e.g.
    # dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
    #               'output': {0: 'batch_size', 2: 'height', 3: 'width'}}.
    torch.onnx.export(
        model,                       # model being run
        dummy_input,                 # model input (or a tuple for multiple inputs)
        output_file,                 # where to save the model (can be a file or file-like object)
        export_params=True,          # store the trained parameter weights inside the model file
        opset_version=16,            # the ONNX version to export the model to
        do_constant_folding=True,    # whether to execute constant folding for optimization
        input_names=['input'],       # the model's input names
        output_names=['output'],     # the model's output names
    )
    print(f"Model has been converted to ONNX and saved to {output_file}")

# Run the export only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()