inferno-pytorch / inferno

A utility library around PyTorch

UNet: Allow to freely define the number of channels per depth in subclasses #165

Closed: mys007 closed this pull request 5 years ago

mys007 commented 5 years ago

This PR proposes replacing the hard-coded number of channels, 2**depth, with a function _get_num_channels() that subclasses can override, e.g. with return list_of_widths[depth - 1]. In addition, pre_conv_op_regularizer_factory and post_conv_op_regularizer_factory are removed, as they appear to be unused.
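To illustrate the overridable hook described above, here is a minimal pure-Python sketch. It is not inferno's actual UNet class: MiniUNetSpec and CustomWidthUNetSpec are hypothetical names, only the per-depth channel counts are modeled (no PyTorch layers), and depth is assumed to run from 1 to the total depth, which matches the list_of_widths[depth - 1] example.

```python
class MiniUNetSpec:
    """Hypothetical UNet stand-in that only records per-depth channel counts."""

    def __init__(self, depth):
        if depth < 1:
            raise ValueError("depth must be >= 1")
        self.depth = depth
        # Encoder widths are queried through the overridable hook,
        # instead of being hard-coded in the constructor.
        self.encoder_channels = [
            self._get_num_channels(d) for d in range(1, depth + 1)
        ]

    def _get_num_channels(self, depth):
        # Default: reproduce the previously hard-coded rule 2**depth.
        return 2 ** depth


class CustomWidthUNetSpec(MiniUNetSpec):
    """Hypothetical subclass that reads its widths from a list."""

    def __init__(self, list_of_widths):
        # Must be set BEFORE super().__init__(), because the base
        # constructor calls _get_num_channels().
        self.list_of_widths = list(list_of_widths)
        super().__init__(depth=len(self.list_of_widths))

    def _get_num_channels(self, depth):
        # depth is assumed to be 1-based, hence the -1.
        return self.list_of_widths[depth - 1]


if __name__ == "__main__":
    print(MiniUNetSpec(3).encoder_channels)                    # default rule
    print(CustomWidthUNetSpec([16, 32, 64]).encoder_channels)  # custom widths
```

Because the base constructor calls the hook, the subclass has to store list_of_widths before calling super().__init__(); otherwise the hook would read an attribute that does not exist yet.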

mys007 commented 5 years ago

Cool, what kind of tests do you suggest?

DerThorsten commented 5 years ago

I think the tests should include a custom UNet that overrides _get_num_channels(self, depth) to cover some corner cases: for example, a constant number of channels, a single channel, or a number of channels that decreases as we go down the UNet. Using the side outputs, one could even check that the layers at each depth actually return the right number of channels.
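A sketch of the suggested corner-case tests follows. It is pure Python and not inferno's test suite: ToyUNetPlan, make_plan and check_plan are hypothetical helpers, and the "side outputs" are simulated by listing the channel count each encoder depth produces, rather than by running real layers.

```python
class ToyUNetPlan:
    """Hypothetical UNet stand-in: builds (in, out) channel pairs per depth."""

    def __init__(self, in_channels, depth):
        self.depth = depth
        self.encoder = []  # one (in_channels, out_channels) pair per depth
        prev = in_channels
        for d in range(1, depth + 1):
            out = self._get_num_channels(d)
            if out < 1:
                raise ValueError(f"depth {d}: channel count must be >= 1, got {out}")
            self.encoder.append((prev, out))
            prev = out

    def _get_num_channels(self, depth):
        return 2 ** depth

    def side_output_channels(self):
        # Stand-in for side outputs: channels produced at each depth.
        return [out for _, out in self.encoder]


def make_plan(widths, in_channels=1):
    """Create a subclass whose hook returns widths[depth - 1]."""

    class _Custom(ToyUNetPlan):
        def _get_num_channels(self, depth):
            return widths[depth - 1]

    return _Custom(in_channels, depth=len(widths))


def check_plan(widths):
    plan = make_plan(widths)
    # Each depth must produce exactly the requested number of channels...
    assert plan.side_output_channels() == list(widths)
    # ...and consume what the previous depth produced.
    for (cin, _), (_, prev_out) in zip(plan.encoder[1:], plan.encoder[:-1]):
        assert cin == prev_out


# Corner cases mentioned in the review.
check_plan([8, 8, 8])     # constant number of channels
check_plan([1, 1, 1])     # a single channel everywhere
check_plan([64, 32, 16])  # channels that decrease with depth
```

With real PyTorch modules, the same check would compare the channel dimension of each side-output tensor against the requested width.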

DerThorsten commented 5 years ago

btw, thanks for the PR; this functionality was indeed missing.

DerThorsten commented 5 years ago

@mys007 If you think these additional tests are a bit overkill, we could also merge now and extend the tests later.

constantinpape commented 5 years ago

LGTM from my side.

mys007 commented 5 years ago

Thanks a lot for the review. I will add the tests on Monday!

mys007 commented 5 years ago

Tests have been added.

DerThorsten commented 5 years ago

Once Travis is done, I'll merge it. Thanks for the tests and the PR!