google-research / maxim

[CVPR 2022 Oral] Official repository for "MAXIM: Multi-Axis MLP for Image Processing". SOTA for denoising, deblurring, deraining, dehazing, and enhancement.
https://arxiv.org/abs/2201.02973
Apache License 2.0

flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". #16

Closed: zarmondo11 closed this issue 2 years ago.

zarmondo11 commented 2 years ago

Hello. I'm trying to run this repo on Google Colab. It works fine with the Enhancement pre-trained model, but when I load the Deblurring model and call predict() to get the output, this error appears:

MODEL_PATH = "Deblurring/GoPro/checkpoint.npz"
FLAGS = DummyFlags(ckpt_path = MODEL_PATH, task = "Enhancement") 
params = get_params(FLAGS.ckpt_path)
model = build_model()

import requests
from io import BytesIO

import numpy as np
import matplotlib.pyplot as plt
from PIL import Image

url = 'https://replicate.com/api/models/google-research/maxim/files/6707a57f-4957-4047-b020-2160aed1d27a/1fromGOPR0950.png'
image_bytes = BytesIO(requests.get(url).content)

result = predict(image_bytes)

f, ax = plt.subplots(1,2, figsize = (35,20))

ax[0].imshow(np.array(Image.open(image_bytes)))
ax[1].imshow(result) 

ax[0].set_title("Original Image")
ax[1].set_title("Enhanced Image")

plt.show()

UnfilteredStackTrace                      Traceback (most recent call last)
in ()
      8
----> 9 result = predict(image_bytes)
     10

18 frames
UnfilteredStackTrace: flax.errors.ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)

The stack trace below excludes JAX-internal frames. The preceding is the original exception that occurred, unmodified.


The above exception was the direct cause of the following exception:

ScopeParamNotFoundError                   Traceback (most recent call last)
/usr/local/lib/python3.7/dist-packages/flax/linen/linear.py in __call__(self, inputs)
    356
    357     kernel = self.param('kernel', self.kernel_init, kernel_shape,
--> 358                         self.param_dtype)
    359     kernel = jnp.asarray(kernel, self.dtype)
    360

ScopeParamNotFoundError: No parameter named "kernel" exists in "/stage_1_output_conv_2". (https://flax.readthedocs.io/en/latest/flax.errors.html#flax.errors.ScopeParamNotFoundError)
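This error means the model being applied asks for a parameter path (here `stage_1_output_conv_2/kernel`) that the loaded parameter tree does not contain, which usually points to an architecture/checkpoint mismatch. A minimal sketch of how one might list the differing paths, assuming plain nested mappings as stand-ins for Flax parameter trees; `flatten_paths` and `diff_params` are hypothetical helpers, and `stage_2_output_conv_0` is an invented key used only for illustration:

```python
from collections.abc import Mapping
from typing import Any, Set, Tuple


def flatten_paths(tree: Mapping, prefix: Tuple[str, ...] = ()) -> Set[str]:
    """Return every leaf path of a nested mapping as a '/'-joined string."""
    paths: Set[str] = set()
    for key, value in tree.items():
        path = prefix + (str(key),)
        if isinstance(value, Mapping):  # also matches Flax FrozenDicts
            paths |= flatten_paths(value, path)
        else:
            paths.add("/".join(path))
    return paths


def diff_params(expected: Mapping, loaded: Mapping) -> Tuple[Set[str], Set[str]]:
    """Return (paths missing from the checkpoint, paths the model never uses)."""
    exp = flatten_paths(expected)
    got = flatten_paths(loaded)
    return exp - got, got - exp


# Toy trees: strings stand in for weight arrays (hypothetical key names).
expected = {"stage_1_output_conv_2": {"kernel": "w", "bias": "b"}}
loaded = {"stage_2_output_conv_0": {"kernel": "w", "bias": "b"}}

missing, unexpected = diff_params(expected, loaded)
print(sorted(missing))
print(sorted(unexpected))
```

With real models, `expected` could come from the `params` collection returned by the model's `init`, and `loaded` from the checkpoint parameters; a non-empty `missing` set suggests the model was built with a different variant than the checkpoint.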

What's the problem? Is it with the pre-trained models? If so, how can I fix it, or train my own model?

vztu commented 2 years ago

Hi, please change the following line:

model = build_model(task="Deblurring")

because the default model is for Dehazing:

import importlib
import ml_collections

def build_model(task = "Dehazing"):
  model_mod = importlib.import_module(f'maxim.models.{_MODEL_FILENAME}')
  model_configs = ml_collections.ConfigDict(_MODEL_CONFIGS)

  # The task selects the architecture variant; it must match the checkpoint.
  model_configs.variant = _MODEL_VARIANT_DICT[task]

  model = model_mod.Model(**model_configs)
  return model

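One way to keep the `build_model()` task and the checkpoint in sync is to derive the task from the checkpoint path. A minimal sketch, assuming checkpoints sit under a directory named after the task (as in `Deblurring/GoPro/checkpoint.npz` above); `task_from_ckpt_path` is a hypothetical helper, and the variant strings in `MODEL_VARIANT_DICT` are our assumption about what the notebook's `_MODEL_VARIANT_DICT` holds, so check them against your copy:

```python
from pathlib import PurePosixPath

# Assumed task -> variant table; verify against _MODEL_VARIANT_DICT in the notebook.
MODEL_VARIANT_DICT = {
    "Denoising": "S-3",
    "Deblurring": "S-3",
    "Deraining": "S-2",
    "Dehazing": "S-2",
    "Enhancement": "S-2",
}


def task_from_ckpt_path(ckpt_path: str) -> str:
    """Hypothetical helper: return the first path component that names a known task."""
    for part in PurePosixPath(ckpt_path).parts:
        if part in MODEL_VARIANT_DICT:
            return part
    raise ValueError(f"Cannot infer a MAXIM task from {ckpt_path!r}")


task = task_from_ckpt_path("Deblurring/GoPro/checkpoint.npz")
print(task, MODEL_VARIANT_DICT[task])  # Deblurring S-3
```

The result could then be passed along as `model = build_model(task=task)`. If the assumed table is right, Enhancement shares the S-2 variant with the default Dehazing, which would explain why the Enhancement checkpoint loaded without error while the Deblurring one did not.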
zarmondo11 commented 2 years ago

Thanks a lot! <3