marisancans closed this issue 9 months ago.
I downloaded the weights from the maxim repo. Then, in the maxim-pytorch repo, I run the jax2torch.py script with:
python maxim_pytorch/jax2torch.py -c maxim_ckpt_Enhancement_FiveK_checkpoint.npz
It works and I get a torch_weight.pth file. I then try to load it, but I can't tell whether I'm passing the wrong arguments or your code is wrong:
from maxim_pytorch.maxim_torch import MAXIM_dns_3s
import torch
import cv2
import numpy as np
from torchvision import transforms
from pathlib import Path

# These params are from https://github.com/google-research/maxim/blob/3c8265171ffccc80c3c9124844aef0d381609956/maxim/models/maxim.py#L910
s2 = {
    "features": 32,
    "depth": 3,
    "num_stages": 2,
    # "num_groups": 2,
    # "num_bottleneck_blocks": 2,
    # "block_gmlp_factor": 2,
    "grid_gmlp_factor": 2,
    "input_proj_factor": 2,
    "channels_reduction": 4,
}

model = MAXIM_dns_3s(
    features=32,
    depth=3,
    block_gmlp_factor=2,
    grid_gmlp_factor=2,
    input_proj_factor=2,
    channels_reduction=4,
    num_supervision_scales=2,
)
state = torch.load("torch_weight.pth")
model.load_state_dict(state)
model.eval()
I get this error:
RuntimeError: Error(s) in loading state_dict for MAXIM_dns_3s: Unexpected key(s) in state_dict: "stage_1_output_conv_0.bias", "stage_1_output_conv_0.weight", "stage_1_output_conv_1.bias", "stage_1_output_conv_1.weight", "stage_1_output_conv_2.bias", "stage_1_output_conv_2.weight".
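A small, framework-agnostic sketch for narrowing down an error like this one: it diffs the checkpoint's parameter names against the model's (the same comparison that strict `load_state_dict` performs) and lists which `stage_N` indices appear on each side. The helper names `stage_indices` and `diff_state_dict_keys` are hypothetical, and so is the sample key `stage_0_input_conv.weight`; only the six `stage_1_output_conv_*` keys come from the error above. In a real session you would pass `torch.load("torch_weight.pth").keys()` and `model.state_dict().keys()` instead of the sample lists.

```python
import re
from typing import Iterable


def stage_indices(keys: Iterable[str]) -> set:
    """Collect every N that appears as a 'stage_N_' prefix in the given parameter names."""
    found = set()
    for key in keys:
        for match in re.finditer(r"stage_(\d+)_", key):
            found.add(int(match.group(1)))
    return found


def diff_state_dict_keys(checkpoint_keys: Iterable[str], model_keys: Iterable[str]) -> dict:
    """Compare two collections of state_dict keys the way strict loading does."""
    ckpt = set(checkpoint_keys)
    model = set(model_keys)
    return {
        "unexpected": sorted(ckpt - model),  # present in the checkpoint, absent from the model
        "missing": sorted(model - ckpt),  # present in the model, absent from the checkpoint
        "checkpoint_stages": sorted(stage_indices(ckpt)),
        "model_stages": sorted(stage_indices(model)),
    }


if __name__ == "__main__":
    # Hypothetical shared key plus the six unexpected keys reported in the error message.
    checkpoint = ["stage_0_input_conv.weight"] + [
        f"stage_1_output_conv_{i}.{p}" for i in range(3) for p in ("weight", "bias")
    ]
    model = ["stage_0_input_conv.weight"]  # hypothetical model key list
    for name, value in diff_state_dict_keys(checkpoint, model).items():
        print(f"{name}: {value}")
```

PyTorch can also report this directly: `model.load_state_dict(state, strict=False)` returns an object with `missing_keys` and `unexpected_keys` attributes. Note that `strict=False` silently skips the mismatched weights, so it is useful for diagnosis but is not a fix for a configuration mismatch.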
I ran into the same problem. Have you solved it?
No, this repo is dead. We reverse-engineered the code and reimplemented it ourselves in PyTorch. We got really bad results at large resolutions and won't use it in the future.