[CVPR 2022 Oral] PyTorch re-implementation of "MAXIM: Multi-Axis MLP for Image Processing", with *training code*. Official JAX repo: https://github.com/google-research/maxim
This network uses too much GPU memory. An input of shape (4, 1, 128, 128) typically takes up to 16 GB of GPU memory, so high-resolution images have to be fed in as patches. How did the original JAX network handle this?
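For reference, here is a minimal sketch of the patch-based workaround mentioned above. It only computes overlapping tile coordinates for a large image; it is not taken from either repo and says nothing about how the JAX code handles this. The names `tile_starts` and `tile_boxes`, the 128-pixel tile size, and the 16-pixel overlap are all assumptions for illustration.

```python
def tile_starts(length, tile, overlap):
    """Start offsets of tiles of size `tile` that cover [0, length).

    Neighbouring tiles overlap by at least `overlap` pixels. The last tile
    is shifted back so that it ends exactly at the image border.
    """
    if tile <= overlap:
        raise ValueError("tile must be larger than overlap")
    if length <= tile:
        return [0]  # the whole axis fits in a single tile
    stride = tile - overlap
    starts = list(range(0, length - tile, stride))
    starts.append(length - tile)  # final tile, flush with the border
    return starts


def tile_boxes(height, width, tile=128, overlap=16):
    """Return (top, left, h, w) boxes that together cover the whole image."""
    h = min(tile, height)
    w = min(tile, width)
    return [
        (top, left, h, w)
        for top in tile_starts(height, tile, overlap)
        for left in tile_starts(width, tile, overlap)
    ]


if __name__ == "__main__":
    # Hypothetical 300x256 image split into 128x128 tiles.
    for box in tile_boxes(300, 256):
        print(box)
```

Each box could then be cropped from the input (for example `x[..., top:top + h, left:left + w]`), run through the model one patch at a time, and written back into a full-size output, averaging the predictions where tiles overlap to reduce visible seams. Whether patch-wise output matches full-image output for this network is something I have not checked.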