Open rachtibat opened 1 year ago
Hey Reduan,
thanks for the issue as always!
I think having a way to use the AvgPool2d gradient for MaxPool2d layers is a must-have. I have some proof-of-concept code which I implemented back in the day to directly and explicitly compute the avg-pool gradient with MaxPool parameters using transposed convolutions.
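For anyone following along, the transposed-convolution idea can be sketched roughly like this. It is not the proof-of-concept code mentioned above, just a minimal illustration: the function name avgpool_grad_via_conv_transpose is made up, and the sketch assumes no padding and an input size that the pooling windows cover exactly.

```python
import torch
import torch.nn.functional as F


def avgpool_grad_via_conv_transpose(grad_output, kernel_size, stride):
    """Hypothetical helper: gradient of average pooling w.r.t. its input,
    computed explicitly with a depthwise transposed convolution.

    Assumes padding=0 and that the windows tile the input exactly.
    """
    channels = grad_output.shape[1]
    # Average pooling is a depthwise convolution with constant weights 1/k^2.
    weight = torch.full(
        (channels, 1, kernel_size, kernel_size),
        1.0 / (kernel_size * kernel_size),
        dtype=grad_output.dtype,
        device=grad_output.device,
    )
    # The transposed convolution spreads each output gradient evenly over its window.
    return F.conv_transpose2d(grad_output, weight, stride=stride, groups=channels)


# Compare against autograd on a small example that uses MaxPool-style parameters.
x = torch.randn(1, 3, 4, 4, requires_grad=True)
out = F.avg_pool2d(x, kernel_size=2, stride=2)
grad_out = torch.randn_like(out)
(reference,) = torch.autograd.grad(out, x, grad_out)
explicit = avgpool_grad_via_conv_transpose(grad_out, kernel_size=2, stride=2)
```

In this setup explicit and reference should agree up to floating-point error. Padding and inputs that are not fully covered by the windows would need extra handling (for example output_padding), which is left out here.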
While going over your code and seeing the BasicHook.backward structure copied, I had the idea that we could also add a layer of abstraction above ParamMod: a ModuleMod or FuncMod, which is a general modifier of the forward function.
This way, one could add very flexible custom rules based on BasicHook, not only limited to the parameters of the module, which would be especially useful for parameter-less modules like MaxPool.
I have a different approach to attributing MaxPool in the pipeline, which could benefit from this abstraction. Do you maybe know of another use case for arbitrary function override? Or maybe @sebastian-lapuschkin does?
If it is only for MaxPool, implementing an explicit rule based on Hook may be better, where we could instead use my existing proof-of-concept code. Although, and I guess that's why you based this off BasicHook rather than Hook, stabilizer would not automatically be part of the rule, which I think may not be necessary for pooling anyway.
As for the name, maybe it's better to call it something like AvgPoolRule, since for AvgPool this would also be correct, although one could just use the EpsilonRule there.
Hey,
thank you for your prompt and thoughtful response as always.
I like the idea of adding a FuncMod.
I asked Sebastian, and he told me that another use case would be to change the 1x1 CNN downsample layer with stride=2 in ResNets, which also creates such a checkerboard pattern. See:
The question is whether we should implement it with a FuncMod.
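As a small illustration of why a 1x1 convolution with stride=2 produces such a pattern, here is a minimal sketch in plain PyTorch. The toy input and the all-ones weight are assumptions for demonstration only, not taken from a ResNet.

```python
import torch
import torch.nn.functional as F

# Toy 4x4 single-channel input; the 1x1 kernel with stride 2 only ever
# reads every second pixel in each spatial dimension.
x = torch.randn(1, 1, 4, 4, requires_grad=True)
weight = torch.ones(1, 1, 1, 1)  # assumed weight, for illustration only

out = F.conv2d(x, weight, stride=2)
out.sum().backward()

# 1 where the input received gradient, 0 where it was skipped entirely.
grad_mask = (x.grad != 0).int()[0, 0]
```

Only the pixels at even row and column indices receive any gradient, so an attribution based on this backward pass inherits the same regular grid of zeros.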
A spontaneous idea that would change the backward pass function instead:
With a FuncMod we could do:
Best
Hey,
we'd like to add a new rule that smooths the MaxPool2d operation by replacing its backward pass with that of AvgPool2d:
You can test the code with
Do you think that's fine? I can create a pull request if you want.
Best, Reduan
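For readers without the original snippet at hand, the general idea of a max-pool forward pass with an average-pool backward pass can be sketched in plain PyTorch as below. This is not the rule proposed in this issue and it does not use the zennit API. The class name MaxPoolAvgGrad is hypothetical, and the sketch only swaps the plain gradient, so it ignores the relevance-specific parts (such as a stabilizer) that an actual rule would need.

```python
import torch
import torch.nn.functional as F


class MaxPoolAvgGrad(torch.autograd.Function):
    """Hypothetical helper: max-pool forward, average-pool backward."""

    @staticmethod
    def forward(ctx, input, kernel_size, stride, padding):
        ctx.save_for_backward(input)
        ctx.pool_args = (kernel_size, stride, padding)
        return F.max_pool2d(input, kernel_size, stride, padding)

    @staticmethod
    def backward(ctx, grad_output):
        (input,) = ctx.saved_tensors
        kernel_size, stride, padding = ctx.pool_args
        # Re-run the forward pass as average pooling and differentiate that instead.
        with torch.enable_grad():
            proxy = input.detach().requires_grad_(True)
            avg_out = F.avg_pool2d(proxy, kernel_size, stride, padding)
        (grad_input,) = torch.autograd.grad(avg_out, proxy, grad_output)
        # One gradient per forward argument; the pooling parameters get None.
        return grad_input, None, None, None


x = torch.arange(16.0).reshape(1, 1, 4, 4).requires_grad_(True)
y = MaxPoolAvgGrad.apply(x, 2, 2, 0)
y.sum().backward()
```

With a 2x2 window and stride 2, y still holds the max-pooled values, while x.grad is 0.25 everywhere instead of being concentrated on the four maxima.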