wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

Add UNet resize-conv layers #1

Closed: tom-andersson closed this pull request 2 years ago

tom-andersson commented 2 years ago

This PR addresses the checkerboard artifacts produced by the existing UNet class. It replaces the transposed convolutions with resize-convolution layers. A boolean argument allows reverting to transposed convolutions, which should help with assessing the impact of this change.
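To illustrate why the swap matters, here is a minimal 1D sketch in plain Python. It is not the PR's TensorFlow code: it assumes nearest-neighbour resizing, a hypothetical `resize_convs` flag standing in for the PR's boolean argument, and no output cropping. With stride-2 upsampling and a kernel of width 3 (not divisible by the stride), the transposed convolution overlaps unevenly, so a constant input comes out as an alternating pattern; resizing first and then convolving gives a uniform interior.

```python
def transposed_conv1d(x, w, stride):
    """Transposed 1D convolution: scatter-add each input, scaled by w, at stride offsets."""
    out = [0.0] * ((len(x) - 1) * stride + len(w))
    for i, xi in enumerate(x):
        for j, wj in enumerate(w):
            out[i * stride + j] += xi * wj
    return out


def upsample_nearest1d(x, factor):
    """Nearest-neighbour resize: repeat every element `factor` times."""
    return [xi for xi in x for _ in range(factor)]


def conv1d_same(x, w):
    """Zero-padded 'same' cross-correlation with an odd-length kernel."""
    pad = len(w) // 2
    padded = [0.0] * pad + list(x) + [0.0] * pad
    return [sum(padded[i + j] * w[j] for j in range(len(w))) for i in range(len(x))]


def upsample(x, w, factor=2, resize_convs=True):
    """Hypothetical toggle: resize-conv (default) or transposed conv."""
    if resize_convs:
        return conv1d_same(upsample_nearest1d(x, factor), w)
    return transposed_conv1d(x, w, stride=factor)
```

Running `upsample([1.0] * 4, [1.0] * 3, resize_convs=False)` gives `[1.0, 1.0, 2.0, 1.0, 2.0, 1.0, 2.0, 1.0, 1.0]`, while the default resize-conv path gives `[2.0, 3.0, 3.0, 3.0, 3.0, 3.0, 3.0, 2.0]`, where only the zero-padded edges differ. The output lengths differ (9 vs 8) because this sketch does no cropping on the transposed path, which framework layers normally handle through their padding settings.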

I have tested the 2D case with TensorFlow, but it will be worth testing the 1D and 3D cases and the PyTorch backend as well.

Please review - happy to make any changes.

wesselb commented 2 years ago

This looks great, @tom-andersson! Would you be able to enable these convs in test_architectures? I think it is best to add a new config:

            (
                "construct_convgnp",
                {
                    "num_basis_functions": 4,
                    "points_per_unit": 16,
                    "unet_channels": (8, 16),
                    "unet_kernels": (5,) * 2,
                    "unet_activations": (B.relu,) * 2,
                    "unet_use_resizes_convs": True,  # This is added.
                    "epsilon": 1e-2,
                },
            ),

tom-andersson commented 2 years ago

Thank you for the feedback, @wesselb! Regarding the test config: I've made the boolean argument default to the resize-convs, so all code and tests that don't set it will now use the resize-conv layers. I've therefore added two test configs, one with the unet_uses_resize_convs argument set to False and one with it set to True, in https://github.com/wesselb/neuralprocesses/pull/1/commits/7fdab9958d5b4c978298ef1b652fee5d9f4a1fd0. Any feedback on this?
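Since the two configs differ only in the boolean flag, one way to keep them in sync is to derive both from a shared base. This is a hypothetical sketch, not the repository's test code: it assumes the keyword is spelled `unet_uses_resize_convs` as in this comment (the earlier suggestion spells it differently), and it omits `unet_activations` so that it runs without the backend (`B`) import.

```python
# Shared settings for the ConvGNP test configs (values taken from the review
# comment; unet_activations omitted so this sketch has no backend dependency).
BASE_CONVGNP_KWARGS = {
    "num_basis_functions": 4,
    "points_per_unit": 16,
    "unet_channels": (8, 16),
    "unet_kernels": (5,) * 2,
    "epsilon": 1e-2,
}


def convgnp_configs(flag_name="unet_uses_resize_convs"):
    """Return one (constructor name, kwargs) pair per value of the resize-conv flag.

    `flag_name` is an assumption; adjust it to the argument name actually used.
    """
    configs = []
    for use_resize_convs in (False, True):
        # Build a fresh dict so the shared base is never mutated.
        kwargs = {**BASE_CONVGNP_KWARGS, flag_name: use_resize_convs}
        configs.append(("construct_convgnp", kwargs))
    return configs
```

Deriving both entries from one base means a later change to, say, `unet_channels` cannot accidentally apply to only one of the two configs.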

coveralls commented 2 years ago

Pull Request Test Coverage Report for Build 2071972135

Changes Missing Coverage            Covered Lines   Changed/Added Lines   %
neuralprocesses/tensorflow/nn.py    19              23                    82.61%
Total                               32              36                    88.89%
Totals Coverage Status
Change from base Build 2030075431: 0.3%
Covered Lines: 777
Relevant Lines: 956

wesselb commented 2 years ago

The Python 3.8 job is failing, but that seems to be due to bad luck, so I'm merging this. Thanks @tom-andersson for the contribution!! :)