krasserm / super-resolution

Tensorflow 2.x based implementation of EDSR, WDSR and SRGAN for single image super-resolution
Apache License 2.0
1.51k stars, 352 forks

AssertionError #91

Status: Open. otsebriy opened this issue 2 years ago.

otsebriy commented 2 years ago

AssertionError                            Traceback (most recent call last)
in
----> 1 model = wdsr_b(scale=scale, num_res_blocks=depth)
      2 model.load_weights(weights_file)

~/Desktop/GAN-research/super-resolution/model/wdsr.py in wdsr_b(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling)
     12 
     13 def wdsr_b(scale, num_filters=32, num_res_blocks=8, res_block_expansion=6, res_block_scaling=None):
---> 14     return wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block_b)
     15 
     16 

~/Desktop/GAN-research/super-resolution/model/wdsr.py in wdsr(scale, num_filters, num_res_blocks, res_block_expansion, res_block_scaling, res_block)
     20 
     21     # main branch
---> 22     m = conv2d_weightnorm(num_filters, 3, padding='same')(x)
     23     for i in range(num_res_blocks):
     24         m = res_block(m, num_filters, res_block_expansion, kernel_size=3, scaling=res_block_scaling)

~/Desktop/GAN-research/super-resolution/model/wdsr.py in conv2d_weightnorm(filters, kernel_size, padding, activation, **kwargs)
     57 
     58 def conv2d_weightnorm(filters, kernel_size, padding='same', activation=None, **kwargs):
---> 59     return tfa.layers.WeightNormalization(Conv2D(filters, kernel_size, padding=padding, activation=activation, **kwargs), data_init=False)

~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/typeguard/__init__.py in wrapper(*args, **kwargs)
   1030         memo = _CallMemo(python_func, _localns, args=args, kwargs=kwargs)
   1031         check_argument_types(memo)
-> 1032         retval = func(*args, **kwargs)
   1033         try:
   1034             check_return_type(retval, memo)

~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/tensorflow_addons/layers/wrappers.py in __init__(self, layer, data_init, **kwargs)
     58     @typechecked
     59     def __init__(self, layer: tf.keras.layers, data_init: bool = True, **kwargs):
---> 60         super().__init__(layer, **kwargs)
     61         self.data_init = data_init
     62         self._track_trackable(layer, name="layer")

~/.pyenv/versions/3.6.8/envs/srgan/lib/python3.6/site-packages/keras/layers/wrappers.py in __init__(self, layer, **kwargs)
     44 
     45     def __init__(self, layer, **kwargs):
---> 46         assert isinstance(layer, Layer)
     47         self.layer = layer
     48         super(Wrapper, self).__init__(**kwargs)

AssertionError:

I wasn't able to download weights for `WDSR`. Any ideas what the problem is?
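The last frame shows the Keras `Wrapper` base class asserting `isinstance(layer, Layer)`. One plausible reading (an assumption, not confirmed in this thread) is that the `Conv2D` used in `model/wdsr.py` comes from a different copy of Keras than the one `tensorflow_addons` wraps, so it is not an instance of that copy's `Layer` class. The pure-Python sketch below mimics this with two independently built, hypothetical "Keras" namespaces (`make_keras_copy`, `keras_a`, `keras_b` and `try_wrap` are invented for illustration); it needs no TensorFlow.

```python
import types


def make_keras_copy():
    """Build a hypothetical, self-contained 'Keras' with its own Layer hierarchy."""

    class Layer:
        pass

    class Conv2D(Layer):
        def __init__(self, filters, kernel_size):
            self.filters = filters
            self.kernel_size = kernel_size

    class Wrapper(Layer):
        def __init__(self, layer):
            # Same check as keras/layers/wrappers.py line 46 in the traceback.
            assert isinstance(layer, Layer)
            self.layer = layer

    return types.SimpleNamespace(Layer=Layer, Conv2D=Conv2D, Wrapper=Wrapper)


keras_a = make_keras_copy()  # stands in for the Keras that the wrapper uses
keras_b = make_keras_copy()  # stands in for a second, separate Keras copy


def try_wrap(wrapper_cls, layer):
    """Return True if wrapper_cls accepts layer, False if its isinstance assertion fails."""
    try:
        wrapper_cls(layer)
        return True
    except AssertionError:
        return False


same = try_wrap(keras_a.Wrapper, keras_a.Conv2D(32, 3))
mixed = try_wrap(keras_a.Wrapper, keras_b.Conv2D(32, 3))
print(same, mixed)  # True False
```

Both `Conv2D` classes look identical, but each subclasses its own `Layer`, so the check only passes when layer and wrapper come from the same copy. If this reading is right, a fix worth trying is to import `Conv2D` from the same Keras that `tensorflow_addons` wraps (for example via `tensorflow.keras.layers`), which would keep weight normalization in place; this is untested here.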
adavradou commented 2 years ago

Hello @otsebriy! I get the same AssertionError here. Did you manage to find a solution? I was able to download the WDSR weights without any problem, though.

adavradou commented 2 years ago

Hi! I think I managed to solve the AssertionError.

According to this Reddit thread, https://www.reddit.com/r/tensorflow/comments/dn0hjv/applying_weight_normalization_layer_in_tf_2/, you need to modify the conv2d_weightnorm function in model/wdsr.py as follows:

def conv2d_weightnorm(filters, kernel_size, padding='same', activation=None, **kwargs):
    # Workaround: return a plain Conv2D, i.e. drop the tfa.layers.WeightNormalization wrapper
    return Conv2D(filters, kernel_size, padding=padding, activation=activation, **kwargs)

For me it worked!
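Note that this workaround does not repair the wrapper; it removes weight normalization from every convolution. For reference, weight normalization reparameterizes each output filter's weight vector as w = g * v / ||v||, with a learned per-filter scale g. The minimal sketch below shows that reparameterization in pure Python on flattened filters; `weight_norm` is a hypothetical helper, not the tensorflow_addons implementation.

```python
import math


def weight_norm(v_filters, g):
    """Weight normalization: w_c = g_c * v_c / ||v_c|| for each output filter c.

    v_filters: one flattened weight vector per output filter (hypothetical layout).
    g: one learned scale per output filter.
    """
    if len(v_filters) != len(g):
        raise ValueError("need exactly one scale g per output filter")
    w_filters = []
    for v_c, g_c in zip(v_filters, g):
        norm = math.sqrt(sum(x * x for x in v_c))  # Euclidean norm of this filter
        w_filters.append([g_c * x / norm for x in v_c])
    return w_filters


# Two 2-element "filters": the direction comes from v, the length from g.
w = weight_norm([[3.0, 4.0], [0.0, 2.0]], g=[10.0, 0.5])
print(w)  # [[6.0, 8.0], [0.0, 0.5]]
```

Each resulting filter has norm exactly g_c. Because g is an extra trained variable, a model built from plain Conv2D layers presumably no longer has the same set of weights as a checkpoint trained with weight normalization.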

adavradou commented 2 years ago

Now, however, I get `ValueError: axes don't match array` when loading the weights...
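A plausible but unconfirmed explanation for this new error: the pretrained weights were saved from layers wrapped in `WeightNormalization`, which carry more variables than a plain `Conv2D`, so the saved arrays no longer line up with the rebuilt model. The pure-Python sketch below pairs hypothetical per-layer weight shapes position by position and reports the first mismatch. The variable names, shapes and the helper `first_mismatch` are illustrative assumptions, not read from the actual checkpoint, and the sketch does not reproduce Keras's exact error message.

```python
def first_mismatch(saved, model):
    """Pair saved weight shapes with model weight shapes positionally.

    Both arguments are lists of (name, shape) tuples in loading order.
    Returns None if everything lines up, otherwise a short description.
    """
    if len(saved) != len(model):
        return f"count mismatch: {len(saved)} saved vs {len(model)} in model"
    for (s_name, s_shape), (m_name, m_shape) in zip(saved, model):
        if s_shape != m_shape:
            return f"{s_name} {s_shape} cannot be loaded into {m_name} {m_shape}"
    return None


# Hypothetical first conv of a WDSR-B model: 32 filters, 3x3 kernel, RGB input.
# Assumption: the weight-normalized layer stores an extra per-filter scale "g".
saved_wn = [("kernel", (3, 3, 3, 32)), ("bias", (32,)), ("g", (32,))]
plain_conv = [("kernel", (3, 3, 3, 32)), ("bias", (32,))]

print(first_mismatch(saved_wn, saved_wn))    # None
print(first_mismatch(saved_wn, plain_conv))  # count mismatch: 3 saved vs 2 in model
```

If this is the cause, keeping the `WeightNormalization` wrapper (and resolving the AssertionError some other way) would keep the model's weight layout in line with the checkpoint.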