Closed vict0rsch closed 3 years ago
Check that if SPADE is used for `m`, then `s` and `d` are in the tasks.
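That check could be sketched as follows (a minimal sketch; `check_spade_tasks` and its arguments are hypothetical names for illustration, not the repo's actual config API):

```python
def check_spade_tasks(tasks, use_spade_for_m):
    # Hypothetical helper: if SPADE is used for the mask head ("m"),
    # the segmentation ("s") and depth ("d") tasks must also be enabled,
    # since the SPADE mask decoder conditions on their outputs.
    if use_spade_for_m and "m" in tasks:
        missing = {"s", "d"} - set(tasks)
        assert not missing, f"SPADE mask decoder requires tasks: {missing}"


check_spade_tasks(["m", "s", "d"], use_spade_for_m=True)  # passes silently
```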
```python
class MaskSpadeDecoder(nn.Module):
    def __init__(self, ...):
        ...
        # 3x3 conv projecting the last feature map to a 1-channel mask
        self.mask_conv = nn.Conv2d(self.final_nc, 1, 3, padding=1)

    def forward(self, z, cond):
        y = self.spade_1(z, cond)
        y = self.up(y)
        y = self.spade_2(y, cond)  # chain on y, not z
        y = self.up(y)
        y = self.spade_3(y, cond)  # chain on y, not z
        y = self.up(y)
        return self.mask_conv(y)
```
@tianyu-z I have a slight reservation about your implementation of the forward pass in the case of dlv3 resnets.
My intuition would have been to somehow concatenate and project the low- and high-level features, so as to keep a single SPADE line of processing. What you have implemented makes the SPADE block treat the two feature sets in two parallel paths, and I'm not sure it makes sense for the block to know what kind of features it gets and treat each differently. Do you see what I mean?
I'd have done something simpler, like what I implemented for the resnet case in the BaseDecoder with the low_level_conv and merge_feats_conv.
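The concatenate-and-project idea could look roughly like this (a minimal sketch, not the actual BaseDecoder code: the channel sizes, module name, and the bilinear upsampling of the high-level features are all assumptions):

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MergeFeats(nn.Module):
    """Sketch: project the low-level features, upsample the high-level
    features to match spatially, concatenate, and project back to a single
    tensor so that one SPADE path can process it. Channel sizes are
    illustrative (typical dlv3/resnet values)."""

    def __init__(self, low_nc=256, high_nc=2048, out_nc=256):
        super().__init__()
        # 1x1 conv reducing low-level channels (like low_level_conv)
        self.low_level_conv = nn.Conv2d(low_nc, 48, 1, bias=False)
        # 1x1 conv merging the concatenation (like merge_feats_conv)
        self.merge_feats_conv = nn.Conv2d(48 + high_nc, out_nc, 1, bias=False)

    def forward(self, low, high):
        low = self.low_level_conv(low)
        # bring high-level feats up to the low-level spatial resolution
        high = F.interpolate(
            high, size=low.shape[-2:], mode="bilinear", align_corners=False
        )
        return self.merge_feats_conv(torch.cat([low, high], dim=1))


m = MergeFeats()
low = torch.randn(1, 256, 64, 64)
high = torch.randn(1, 2048, 16, 16)
out = m(low, high)
print(out.shape)  # torch.Size([1, 256, 64, 64])
```

The decoder would then run its spade_1/spade_2/spade_3 chain on the single merged tensor, instead of maintaining two parallel paths.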
If you don't agree with this path feel free to discuss as always :) !
@tianyu-z I did a quick pass to fix things and make the overall API uniform, but I left the SPADE decoder itself untouched as I want your opinion first.
@tianyu-z Can you turn that into a `MaskSpadeDecoder` class, instead of `G.spade`?