caiyuanhao1998 / Retinexformer

"Retinexformer: One-stage Retinex-based Transformer for Low-light Image Enhancement" (ICCV 2023) & (NTIRE 2024 Challenge)
https://arxiv.org/abs/2303.06705
MIT License

ONNX conversion of FiveK model? #78

Closed zelenooki87 closed 3 months ago

zelenooki87 commented 3 months ago

Thank you to the author for this amazing project. I would like to convert the FiveK model to ONNX format for use in VapourSynth to process video clips. I have converted the model with the code below, but the output image from the ONNX model does not match what I get when running the PyTorch model in the conda environment. Is there anything I need to correct in the conversion code?

import torch
import torch.onnx
import yaml
from basicsr.models import create_model

# Load the YAML configuration file
with open('Options/RetinexFormer_FiveK.yml', 'r') as f:
    opt = yaml.safe_load(f)

# Add the 'is_train' and 'dist' keys that create_model expects
opt['is_train'] = False
opt['dist'] = False

# Create the PyTorch model (net_g may be wrapped in nn.DataParallel)
model = create_model(opt).net_g

# Load the pretrained weights; if the network is wrapped in DataParallel,
# the state-dict keys need a 'module.' prefix
checkpoint = torch.load('pretrained_weights/FiveK.pth')

try:
    model.load_state_dict(checkpoint['params'])
except RuntimeError:
    model.load_state_dict({'module.' + k: v for k, v in checkpoint['params'].items()})

# Set the model to evaluation mode
model.eval()

# Define the input shape for the ONNX model
input_shape = (1, 3, 256, 256)  # Adjust the input shape as needed

# Create a dummy input tensor
dummy_input = torch.randn(input_shape).cuda()

# Export the model to ONNX format (FP32, the model's native precision)
with torch.no_grad():
    torch.onnx.export(
        model,
        dummy_input,
        "retinexformer_fivek_Opset-17.onnx",
        opset_version=17,  # Choose an appropriate opset version
        input_names=['input'],
        output_names=['output'],
        dynamic_axes={'input': {0: 'batch_size', 2: 'height', 3: 'width'},
                      'output': {0: 'batch_size', 2: 'height', 3: 'width'}},
        export_params=True,
        # Constant folding is disabled here; enabling it usually yields a
        # smaller, equivalent graph
        do_constant_folding=False,
    )
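
As a sanity check on the export, you can compare the ONNX output with the PyTorch output on the same tensor. A minimal sketch with onnxruntime (the file and input names match the export call above; the tolerance is an arbitrary choice):

import numpy as np
import onnxruntime as ort

# Run the same dummy input through both models and compare the outputs
with torch.no_grad():
    torch_out = model(dummy_input).cpu().numpy()

sess = ort.InferenceSession("retinexformer_fivek_Opset-17.onnx",
                            providers=["CPUExecutionProvider"])
onnx_out = sess.run(None, {"input": dummy_input.cpu().numpy()})[0]

# Small numerical differences are expected; large ones point to an export problem
print("max abs diff:", np.abs(torch_out - onnx_out).max())
print("close:", np.allclose(torch_out, onnx_out, atol=1e-4))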

Thank you very much in advance. Best regards.

zelenooki87 commented 3 months ago

The problem was that my VapourSynth script was not using the sRGB color space. Everything is OK with the converted model (tested with chaiNNer). However, during PyTorch inference I am getting this error:

python Enhancement/test_from_dataset.py --opt Options/RetinexFormer_FiveK.yml --weights pretrained_weights/FiveK.pth --dataset FiveK
export CUDA_VISIBLE_DEVICES=0
dataset FiveK
Not using Automatic Mixed Precision
===>Testing using weights: pretrained_weights/FiveK.pth
C:\Users\Miki\Retinexformer\input
C:\Users\Miki\Retinexformer\input
  0%|          | 0/47 [00:00<?, ?it/s]
Traceback (most recent call last):
  File "C:\Users\Miki\Retinexformer\Enhancement\test_from_dataset.py", line 251, in <module>
    restored_1 = model_restoration(input_1, model_restoration)
  File "C:\Users\Miki\anaconda3\envs\cdformer\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "C:\Users\Miki\anaconda3\envs\cdformer\lib\site-packages\torch\nn\parallel\data_parallel.py", line 169, in forward
    return self.module(*inputs[0], **kwargs[0])
  File "C:\Users\Miki\anaconda3\envs\cdformer\lib\site-packages\torch\nn\modules\module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: RetinexFormer.forward() takes 2 positional arguments but 3 were given

EDIT: this happens only on some images.
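
For the color-space issue: in VapourSynth, the clip can be converted to sRGB float RGB before inference with the resize plugin. A minimal sketch, assuming clip is an already-loaded BT.709 YUV source (the parameter choices are illustrative, not from this repo):

import vapoursynth as vs

core = vs.core

# Assumes 'clip' is an existing BT.709 YUV clip loaded earlier in the script.
# Convert to planar float RGB with an sRGB transfer curve so the model sees
# the same color space as in the PyTorch pipeline.
rgb = core.resize.Bicubic(clip, format=vs.RGBS,
                          matrix_in_s="709",
                          transfer_in_s="709",
                          transfer_s="srgb")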

caiyuanhao1998 commented 3 months ago

Hi, thanks for your interest. It seems that you pass three arguments while only two are accepted. Besides, running on Windows is not recommended. Please follow the README to install the environment and run our code.
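
For anyone hitting the same TypeError: judging from the traceback, RetinexFormer.forward() takes a single input tensor besides self, so the extra model_restoration argument at line 251 is the likely culprit. A hedged one-line sketch of the fix, with variable names taken from the traceback:

# forward(self, x) accepts one tensor; drop the extra argument
restored_1 = model_restoration(input_1)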