vztu / maxim-pytorch

[CVPR 2022 Oral] PyTorch re-implementation for "MAXIM: Multi-Axis MLP for Image Processing", with *training code*. Official Jax repo: https://github.com/google-research/maxim
Apache License 2.0

How to load pretrained Enhancement model in maxim_torch.py #9

Closed (marisancans closed this issue 9 months ago)

marisancans commented 1 year ago

I downloaded the weights from the MAXIM repo. Then, in the maxim-pytorch repo, I ran the jax2torch.py script with:

```
python maxim_pytorch/jax2torch.py -c maxim_ckpt_Enhancement_FiveK_checkpoint.npz
```

It works and I get a torch_weight.pth file.

I then try to load it, but I can't tell whether I'm passing the wrong arguments or your code is wrong:
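One way to check what the converted checkpoint contains, before choosing a model class, is to list its output-convolution keys per stage. This is a minimal sketch, not part of maxim-pytorch: `output_convs_per_stage` is a hypothetical helper, and it assumes the keys follow the `stage_{s}_output_conv_{i}` pattern that appears in the error message in this issue.

```python
import re
from collections import defaultdict

# Assumption: output-conv keys look like "stage_1_output_conv_0.weight" (torch)
# or a "/"-separated variant in the .npz; anything after the index is ignored.
OUTPUT_CONV = re.compile(r"stage_(\d+)_output_conv_(\d+)")


def output_convs_per_stage(keys):
    """Map stage index -> sorted list of output_conv indices found in `keys`."""
    found = defaultdict(set)
    for key in keys:
        match = OUTPUT_CONV.search(key)
        if match:
            found[int(match.group(1))].add(int(match.group(2)))
    return {stage: sorted(indices) for stage, indices in sorted(found.items())}


# In practice the keys would come from the converted file, e.g.
#   keys = torch.load("torch_weight.pth", map_location="cpu").keys()
keys = [
    "stage_1_output_conv_0.bias", "stage_1_output_conv_0.weight",
    "stage_1_output_conv_1.bias", "stage_1_output_conv_1.weight",
    "stage_1_output_conv_2.bias", "stage_1_output_conv_2.weight",
]
print(output_convs_per_stage(keys))  # {1: [0, 1, 2]}
```

If the official JAX naming carries over (stages indexed from 0), the highest stage index hints at how many stages the checkpoint was trained with, and the number of `output_conv` indices hints at how many output scales its last stage has. Treat this as a hint to compare against the model class and its arguments, not as a confirmed mapping.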

```python
import torch

from maxim_pytorch.maxim_torch import MAXIM_dns_3s

# These params are from https://github.com/google-research/maxim/blob/3c8265171ffccc80c3c9124844aef0d381609956/maxim/models/maxim.py#L910
s2 = {
    "features": 32,
    "depth": 3,
    "num_stages": 2,
    "num_groups": 2,
    "num_bottleneck_blocks": 2,
    "block_gmlp_factor": 2,
    "grid_gmlp_factor": 2,
    "input_proj_factor": 2,
    "channels_reduction": 4,
}

# s2 is not used below: num_stages, num_groups and num_bottleneck_blocks
# are not passed to the constructor.
model = MAXIM_dns_3s(
    features=32,
    depth=3,
    block_gmlp_factor=2,
    grid_gmlp_factor=2,
    input_proj_factor=2,
    channels_reduction=4,
    num_supervision_scales=2,
)
# Weights produced by jax2torch.py; map_location lets this load without a GPU.
state = torch.load("torch_weight.pth", map_location="cpu")

model.load_state_dict(state)  # raises the RuntimeError below
model.eval()
```

I get this error:

```
RuntimeError: Error(s) in loading state_dict for MAXIM_dns_3s:
    Unexpected key(s) in state_dict: "stage_1_output_conv_0.bias", "stage_1_output_conv_0.weight", "stage_1_output_conv_1.bias", "stage_1_output_conv_1.weight", "stage_1_output_conv_2.bias", "stage_1_output_conv_2.weight".
```
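A strict `load_state_dict` fails on any key mismatch, and long key lists are easier to read when grouped by module. The sketch below is a hypothetical helper (not part of maxim-pytorch or PyTorch) that works on plain key collections, so it can compare `model.state_dict().keys()` with the keys of the converted checkpoint; it assumes the module name is everything before the first `.` in a key.

```python
from collections import Counter


def summarize_key_mismatch(model_keys, ckpt_keys):
    """Group missing/unexpected state_dict keys by their prefix before the first '.'.

    Each entry counts how many tensors of that module are affected, so
    "stage_1_output_conv_0.weight" and ".bias" collapse into one line.
    """
    model_keys, ckpt_keys = set(model_keys), set(ckpt_keys)

    def by_module(keys):
        return dict(Counter(key.split(".", 1)[0] for key in keys))

    return {
        "missing": by_module(model_keys - ckpt_keys),      # model expects, file lacks
        "unexpected": by_module(ckpt_keys - model_keys),   # file has, model lacks
    }


# Hypothetical usage with the objects from the snippet above:
#   report = summarize_key_mismatch(model.state_dict().keys(), state.keys())
#   print(report)
```

PyTorch itself can report the same sets: `model.load_state_dict(state, strict=False)` returns a named tuple with `missing_keys` and `unexpected_keys`. With `strict=False`, though, unmatched parameters simply keep their freshly initialized values, so it is useful as a diagnostic rather than as a fix.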
wj320 commented 9 months ago


I ran into the same problem. Have you solved it?

marisancans commented 9 months ago

No, this repo is dead. We reverse-engineered the code and implemented it ourselves in PyTorch. We got really bad results at large resolutions and won't use this in the future.