cc-ai / climategan

Code and pre-trained model for the algorithm that generates visualisations of three climate-change-related events: floods, wildfires and smog.
https://thisclimatedoesnotexist.com
GNU General Public License v3.0

tianyu dev: spade masker #182

vict0rsch closed this 3 years ago

vict0rsch commented 3 years ago

@tianyu-z Can you turn that into a MaskSpadeDecoder class instead of G.spade?

vict0rsch commented 3 years ago

Check that, if SPADE is used for m, both s and d are in the tasks.
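A minimal sketch of that check in Python. The options layout (a dict with a tasks list and a use_spade flag) and the function name check_spade_tasks are hypothetical, not taken from the repository.

```python
def check_spade_tasks(opts: dict) -> None:
    """Raise if the SPADE masker is enabled without the s and d tasks.

    Hypothetical keys: opts["tasks"] is a list of task letters and
    opts["use_spade"] is a bool flag for the mask decoder.
    """
    tasks = opts.get("tasks", [])
    if opts.get("use_spade") and "m" in tasks:
        # The SPADE masker is conditioned on s and d, so both must be enabled
        missing = [t for t in ("s", "d") if t not in tasks]
        if missing:
            raise ValueError(f"SPADE masker requires tasks {missing} to be enabled")
```

Running it once when the options are parsed would fail fast instead of crashing later inside the forward pass.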

vict0rsch commented 3 years ago
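import torch.nn as nn


class MaskSpadeDecoder(nn.Module):

    def __init__(self, ...):
        super().__init__()
        # self.spade_1, self.spade_2, self.spade_3, self.up and self.final_nc
        # are set up here (elided)
        # Final 3x3 conv projecting the last feature map to a 1-channel mask
        self.mask_conv = nn.Conv2d(self.final_nc, 1, 3, padding=1)

    def forward(self, z, cond):
        # Each SPADE block refines the previous output y, conditioned on cond
        y = self.spade_1(z, cond)
        y = self.up(y)
        y = self.spade_2(y, cond)
        y = self.up(y)
        y = self.spade_3(y, cond)
        y = self.up(y)
        return self.mask_conv(y)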
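For reference, a self-contained PyTorch sketch of the decoder above, with the elided pieces filled in by assumption. The SPADE layer follows the usual formulation (parameter-free BatchNorm, then a per-pixel scale and shift predicted from the resized conditioning map). SpadeBlock, the channel widths and the x2 nearest-neighbour upsampling are hypothetical choices, and cond is assumed to be a stack of conditioning maps (for instance the s and d outputs).

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class SPADE(nn.Module):
    """Spatially-adaptive normalization, minimal form."""

    def __init__(self, nc: int, cond_nc: int, hidden: int = 64):
        super().__init__()
        # Parameter-free normalization; scale and shift come from cond instead
        self.norm = nn.BatchNorm2d(nc, affine=False)
        self.shared = nn.Sequential(nn.Conv2d(cond_nc, hidden, 3, padding=1), nn.ReLU())
        self.gamma = nn.Conv2d(hidden, nc, 3, padding=1)
        self.beta = nn.Conv2d(hidden, nc, 3, padding=1)

    def forward(self, x, cond):
        # Resize the conditioning map to the feature map's spatial size
        cond = F.interpolate(cond, size=x.shape[-2:], mode="nearest")
        h = self.shared(cond)
        return self.norm(x) * (1 + self.gamma(h)) + self.beta(h)


class SpadeBlock(nn.Module):
    """Hypothetical block: SPADE -> LeakyReLU -> 3x3 conv (changes channels)."""

    def __init__(self, in_nc: int, out_nc: int, cond_nc: int):
        super().__init__()
        self.spade = SPADE(in_nc, cond_nc)
        self.conv = nn.Conv2d(in_nc, out_nc, 3, padding=1)

    def forward(self, x, cond):
        return self.conv(F.leaky_relu(self.spade(x, cond), 0.2))


class MaskSpadeDecoder(nn.Module):
    def __init__(self, z_nc: int, cond_nc: int, final_nc: int = 32):
        super().__init__()
        self.final_nc = final_nc
        self.spade_1 = SpadeBlock(z_nc, 128, cond_nc)
        self.spade_2 = SpadeBlock(128, 64, cond_nc)
        self.spade_3 = SpadeBlock(64, final_nc, cond_nc)
        self.up = nn.Upsample(scale_factor=2, mode="nearest")
        # Project the last feature map to a single-channel mask logit
        self.mask_conv = nn.Conv2d(self.final_nc, 1, 3, padding=1)

    def forward(self, z, cond):
        # Each stage consumes the running y, not z, so no stage is discarded
        y = self.up(self.spade_1(z, cond))
        y = self.up(self.spade_2(y, cond))
        y = self.up(self.spade_3(y, cond))
        return self.mask_conv(y)
```

With z of shape (2, 256, 32, 32) and cond of shape (2, 12, 256, 256), MaskSpadeDecoder(z_nc=256, cond_nc=12) returns a (2, 1, 256, 256) mask logit map: three x2 upsamplings take 32 to 256.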
vict0rsch commented 3 years ago

@tianyu-z I have a slight reservation about your implementation of the forward pass in the case of dlv3 (DeepLabv3) ResNet backbones.

My intuition would have been to concatenate the low- and high-level features and project them, so as to keep a single SPADE processing path. Your implementation makes the SPADE block process each kind of feature along two separate parallel paths, but I'm not sure it makes sense for the block to know which kind of features it receives and how to treat each one differently. Do you see what I mean?

For ResNet backbones, I'd have done something simpler, as I implemented in the BaseDecoder with low_level_conv and merge_feats_conv.
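A minimal PyTorch sketch of that single-path alternative: fuse the two feature levels first, then hand one tensor to the SPADE decoder. The layer names low_level_conv and merge_feats_conv come from the comment, but their configuration here (a 1x1 reduction to 48 channels, borrowed from DeepLabv3+, then a 3x3 projection) and the MergedFeatures class are assumptions, not the actual BaseDecoder code.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class MergedFeatures(nn.Module):
    """Hypothetical fusion of low- and high-level encoder features."""

    def __init__(self, low_nc: int, high_nc: int, out_nc: int, low_red: int = 48):
        super().__init__()
        # 1x1 conv shrinks the (wide, high-resolution) low-level features
        self.low_level_conv = nn.Sequential(
            nn.Conv2d(low_nc, low_red, 1, bias=False),
            nn.BatchNorm2d(low_red),
            nn.ReLU(),
        )
        # 3x3 conv projects the concatenation to a single feature tensor
        self.merge_feats_conv = nn.Sequential(
            nn.Conv2d(high_nc + low_red, out_nc, 3, padding=1, bias=False),
            nn.BatchNorm2d(out_nc),
            nn.ReLU(),
        )

    def forward(self, low, high):
        low = self.low_level_conv(low)
        # Bring high-level features up to the low-level spatial resolution
        high = F.interpolate(
            high, size=low.shape[-2:], mode="bilinear", align_corners=False
        )
        return self.merge_feats_conv(torch.cat([high, low], dim=1))
```

The fused output could then go through a MaskSpadeDecoder along a single processing path, so the SPADE blocks never need to know which kind of features they receive.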

If you don't agree with this approach, feel free to discuss it as always :) !

vict0rsch commented 3 years ago

@tianyu-z I did a quick pass to fix and unify the overall API, but I left the SPADE decoder itself untouched because I want your opinion on it first.