google-research / maxim

[CVPR 2022 Oral] Official repository for "MAXIM: Multi-Axis MLP for Image Processing". SOTA for denoising, deblurring, deraining, dehazing, and enhancement.
https://arxiv.org/abs/2201.02973
Apache License 2.0

maxim - translate to keras #29

Closed: innat closed this issue 2 years ago

innat commented 2 years ago

@vztu @Yinxiaoli

I tried to translate MAXIM from JAX to Keras. Everything looks fine, but the number of trainable parameters looks abnormally large: for the following config, I get a total of 674,238,099 params.

```python
H, W = 224, 224

INS, MODEL = MAXIM(
    features=32,
    depth=3,
    num_stages=1,
    num_groups=2,
    num_bottleneck_blocks=2,
    block_gmlp_factor=2,
    grid_gmlp_factor=2,
    input_proj_factor=2,
    channels_reduction=4,
    num_supervision_scales=3,
    use_bias=True,
    lrelu_slope=0.1,
    use_global_mlp=10,
    use_cross_gating=False,
    high_res_stages=1,
    block_size_hr=[2, 2],
    block_size_lr=[2, 2],
    grid_size_hr=[2, 2],
    grid_size_lr=[2, 2],
    num_outputs=3,
    dropout_rate=0.5,
)
```
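As a rough sanity check on that total, here is a minimal sketch of the parameter arithmetic for the spatial Dense inside a gMLP gating unit. It assumes the JAX reference's design, where each grid/block gating unit mixes tokens only along a short axis: gh*gw grid positions for the grid branch, or fh*fw pixels for the block branch. The helper names `dense_params` and `gating_dense_params` are hypothetical and are not part of either codebase. The "swapped" case is likewise a hypothetical wiring mistake shown for scale, not a diagnosis of the Keras port.

```python
# Hypothetical helpers for counting parameters by hand (not from either codebase).

def dense_params(n_in: int, n_out: int, use_bias: bool = True) -> int:
    """Parameters of a fully connected layer mapping n_in -> n_out features."""
    return n_in * n_out + (n_out if use_bias else 0)


def gating_dense_params(num_tokens: int, use_bias: bool = True) -> int:
    """A square Dense(num_tokens) applied along the token axis of a gating unit.

    Assumption: as in the JAX reference, this Dense mixes tokens only, so its
    size does not depend on the channel count.
    """
    return dense_params(num_tokens, num_tokens, use_bias)


H, W = 224, 224
gh, gw = 2, 2               # grid_size_hr from the config above
fh, fw = H // gh, W // gw   # each grid cell then covers a 112x112 window

# Intended: the grid gating unit mixes only the gh*gw = 4 grid positions.
intended = gating_dense_params(gh * gw)
# Hypothetical axis mix-up: mixing over the 112*112 = 12544 in-window pixels.
swapped = gating_dense_params(fh * fw)

print(intended)  # 20
print(swapped)   # 157364480
```

At these settings, each correctly wired gating unit should add only a handful of parameters. A single layer that mixes the wrong axis adds roughly 157M, so a per-layer breakdown of the Keras model (for example via `model.summary()`) is a quick way to spot where the count blows up.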

Could you please check the plot diagram below in case you notice any misconnection? (Clicking on the image opens it in a new tab, where it is easier to inspect.)

[Image: plot diagram of the Keras MAXIM model]

vztu commented 2 years ago

Hi, thanks for the enormous effort, @innat! Could you please share a pointer to the Keras code?

innat commented 2 years ago

@vztu Thanks for your response. Here is the gist.

Above, sayak-san has made great progress. I will compare my attempt with his to find the mismatch.
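One way to catch such a mismatch is to diff per-layer parameter counts between two implementations. Below is a minimal, hypothetical sketch that assumes each side has already been reduced to a `{layer_name: param_count}` dict with matching names. On the Keras side, `{layer.name: layer.count_params() for layer in model.layers}` gives such a dict. On the JAX/Flax side, the params could be flattened with `flax.traverse_util.flatten_dict` and summed per module. Mapping names between the two ports is assumed to be done beforehand. The function name `diff_param_counts` and the layer names in the example are illustrative only.

```python
from typing import Dict, Optional, Tuple


def diff_param_counts(
    ref: Dict[str, int], port: Dict[str, int]
) -> Dict[str, Tuple[Optional[int], Optional[int]]]:
    """Return {name: (ref_count, port_count)} for layers whose counts differ.

    Layers present on only one side are reported with None for the missing count.
    """
    mismatches = {}
    for name in sorted(set(ref) | set(port)):
        a, b = ref.get(name), port.get(name)
        if a is None or b is None or a != b:
            mismatches[name] = (a, b)
    return mismatches


# Toy, hypothetical counts purely to show the output shape.
ref = {"stem_conv": 896, "grid_gating": 20}
port = {"stem_conv": 896, "grid_gating": 157364480, "extra_dense": 10}
print(diff_param_counts(ref, port))
# {'extra_dense': (None, 10), 'grid_gating': (20, 157364480)}
```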