Dragonisss / RAM

[ECCV 2024] Restore Anything with Masks: Leveraging Mask Image Modeling for Blind All-in-One Image Restoration

Issue with ONNX conversion of PromptIR model - inconsistent frames in Vapoursynth #2

Open zelenooki87 opened 3 weeks ago

zelenooki87 commented 3 weeks ago

Hi @Dragonisss

I'm really impressed with your RAM project and the PromptIR model in particular! I'm trying to use the ONNX-converted version of the model within Vapoursynth, but I'm encountering an issue where the processed frames are inconsistent, with some appearing darker or brighter than expected.

I've been trying to convert the ram_promptir_finetune.pth model to ONNX using the following script:

https://pastebin.com/HAh1iiwh

I've experimented with several approaches, including:

Different ONNX opset versions (11, 12, 17)

Using dynamic_axes for variable input sizes, as well as fixed input dimensions (128×128, 256×256, ... matching my input images).

Ensuring the model is in eval mode (model.eval()) before conversion.

Using a batch size of 1 during export.

Optimizing the ONNX graph with onnxsim.

Removing normalization steps during both conversion and inference.

Verifying data types (float32) and channel order (RGB/BGR) consistency throughout the pipeline.

Comparing intermediate tensor values between the original PyTorch model and the ONNX model.

Despite these efforts, the inconsistency persists. The original PyTorch model processes image sequences correctly, but the ONNX version produces these variations in brightness.

Could you please offer any insights or advice on what might be causing this issue? Is there a pre-existing ONNX version of the PromptIR model available, or perhaps some specific considerations I should be aware of when converting it?

Thank you for your time and for creating such a fantastic project!

Dragonisss commented 2 weeks ago

We are very pleased that you are interested in our model. Regarding the brightness inconsistencies you encountered when using the ONNX model in Vapoursynth, here are a few suggestions that might help you troubleshoot and resolve the issue:

  1. Precision Differences (FP32 vs. FP16/FP64): Although you have verified the data types, different libraries may handle precision differently. Ensure that the same precision is used in both the PyTorch model and the ONNX model. If FP32 is the default, make sure the entire pipeline strictly follows this standard, including input data and internal model operations.
  2. Input Frame Preprocessing: Ensure the input frames go through the same preprocessing as our pipeline expects. In particular, check that normalization (and any standardization step) is applied, or omitted, identically during PyTorch inference and ONNX inference; a mismatch here typically shows up exactly as brightness shifts.
  3. Operator Conversion Check: When converting the PyTorch model to ONNX, certain layers or operators may be implemented differently in the exported graph. Tools like Netron can be used to compare the network structure before and after conversion. For inputs that trigger the issue, compare the intermediate output values of the PyTorch and ONNX models layer by layer (e.g. with breakpoints); this is usually very effective at pinpointing where the divergence starts.

Hope this helps! Let me know if you have any further questions.