hamzapehlivan / StyleRes


Question about the SynthesisNetwork - Thank you #4

Closed jimb2834 closed 6 months ago

jimb2834 commented 6 months ago

Hello @hamzapehlivan, amazing work. Any advice is greatly appreciated.

A colleague and I are trying to understand your SynthesisNetwork and wanted to ask you a question. We tried to use it in our script, which is vanilla StyleGAN2, but we got unexpected results and hoped you could help explain them.

The 9th-layer features from your SynthesisNetwork do not seem to match the original StyleGAN2 features. Could you provide any information about this?

For troubleshooting, we show your version and the original StyleGAN2 version (labeled "ours") below, with a sketch of how we compared them after the two scripts.

Is there something we're missing, perhaps?

Script 1 "yours"

# Excerpt from StyleRes; assumes the stylegan2-ada-pytorch context it lives in,
# i.e. SynthesisBlock and the module-level `latent_space` flag are defined in
# the same file.
import numpy as np
import torch
from torch_utils import misc

class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.num_fp16_res = num_fp16_res
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, highres_outs=None, return_f = False, return_styles=False, **block_kwargs):
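        # Differences from vanilla StyleGAN2: `highres_outs` carries the encoder's
        # skip features, `return_f` returns intermediate features early instead of
        # an image, `return_styles` collects per-layer styles, and the module-level
        # `latent_space` flag switches between w+ and style-space inputs.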
        block_ws = []
        global latent_space
        with torch.autograd.profiler.record_function('split_ws'):
            if latent_space == 'w+':
                misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
                ws = ws.to(torch.float32)
            w_idx = 0
            s_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                if latent_space == 'w+':
                    block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                else:
                    block_ws.append(ws[s_idx: s_idx + block.num_conv + block.num_torgb])
                w_idx += block.num_conv
                s_idx += block.num_conv + block.num_torgb

        x = img = None
        styles = []
        conv_idx = 0
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, highres_outs, return_f, return_styles, **block_kwargs)
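            # Under return_f, the capture block returns its features in x with
            # img=None, so we return the features instead of finishing the pass.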
            if return_f and img is None:
                return x
            if return_styles:
                styles.extend(x)

            conv_idx += block.num_conv
        if return_styles:   return styles
        return img

    def extra_repr(self):
        return ' '.join([
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'num_fp16_res={self.num_fp16_res:d}'])

Ours, which somehow does not produce matching output:

# Same context and imports as above (numpy, torch, torch_utils.misc, SynthesisBlock).
class SynthesisNetwork(torch.nn.Module):
    def __init__(self,
        w_dim,                      # Intermediate latent (W) dimensionality.
        img_resolution,             # Output image resolution.
        img_channels,               # Number of color channels.
        channel_base    = 32768,    # Overall multiplier for the number of channels.
        channel_max     = 512,      # Maximum number of channels in any layer.
        num_fp16_res    = 4,        # Use FP16 for the N highest resolutions.
        **block_kwargs,             # Arguments for SynthesisBlock.
    ):
        assert img_resolution >= 4 and img_resolution & (img_resolution - 1) == 0
        super().__init__()
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_resolution_log2 = int(np.log2(img_resolution))
        self.img_channels = img_channels
        self.num_fp16_res = num_fp16_res
        self.block_resolutions = [2 ** i for i in range(2, self.img_resolution_log2 + 1)]
        channels_dict = {res: min(channel_base // res, channel_max) for res in self.block_resolutions}
        fp16_resolution = max(2 ** (self.img_resolution_log2 + 1 - num_fp16_res), 8)

        self.num_ws = 0
        for res in self.block_resolutions:
            in_channels = channels_dict[res // 2] if res > 4 else 0
            out_channels = channels_dict[res]
            use_fp16 = (res >= fp16_resolution)
            is_last = (res == self.img_resolution)
            block = SynthesisBlock(in_channels, out_channels, w_dim=w_dim, resolution=res,
                img_channels=img_channels, is_last=is_last, use_fp16=use_fp16, **block_kwargs)
            self.num_ws += block.num_conv
            if is_last:
                self.num_ws += block.num_torgb
            setattr(self, f'b{res}', block)

    def forward(self, ws, **block_kwargs):
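        # Same as the stock stylegan2-ada-pytorch forward, except we also collect
        # each block's raw features in `feats` (feats[i] is the output of the block
        # at block_resolutions[i]).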
        block_ws = []
        with torch.autograd.profiler.record_function('split_ws'):
            misc.assert_shape(ws, [None, self.num_ws, self.w_dim])
            ws = ws.to(torch.float32)
            w_idx = 0
            for res in self.block_resolutions:
                block = getattr(self, f'b{res}')
                block_ws.append(ws.narrow(1, w_idx, block.num_conv + block.num_torgb))
                w_idx += block.num_conv

        x = img = None
        feats = []
        for res, cur_ws in zip(self.block_resolutions, block_ws):
            block = getattr(self, f'b{res}')
            x, img = block(x, img, cur_ws, **block_kwargs)
            feats.append(x)
        return img, feats

    def extra_repr(self):
        return ' '.join([
            f'w_dim={self.w_dim:d}, num_ws={self.num_ws:d},',
            f'img_resolution={self.img_resolution:d}, img_channels={self.img_channels:d},',
            f'num_fp16_res={self.num_fp16_res:d}'])
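
For reference, here is roughly how we compared the two (a minimal sketch with hypothetical names: `G_vanilla` and `G_styleres` are the two loaded generators sharing one checkpoint, `ws` is a `[1, num_ws, 512]` latent, `skips` holds our encoder outputs, and we assume `return_f` captures at the block we call the 9th layer):

    import torch

    with torch.no_grad():
        # Our vanilla-StyleGAN2 copy returns the image plus every block's features.
        img_vanilla, feats = G_vanilla.synthesis(ws)
        f9_vanilla = feats[8]                        # what we call the 9th layer

        # StyleRes, asked to return its intermediate features early.
        f9_styleres = G_styleres.synthesis(ws, highres_outs=skips, return_f=True)

        # These did not match, which is what prompted this issue.
        print((f9_vanilla - f9_styleres).abs().max())
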
hamzapehlivan commented 6 months ago

Hi, our model modifies the StyleGAN features if there is an editing/inversion request from the user:

    # Get F space features F_feats for the original image
    skips['F_feats'] = self.generator(latents, skips, return_f=True, **self.G_kwargs_val)
    # Transform F_feats to the incoming edited image
    images = self.generator(latents_edited, skips, **self.G_kwargs_val)

So, if you do not want to change the StyleGAN features, the code should look like:

    skips['F_feats'] = None
    images = self.generator(latents_edited, skips, **self.G_kwargs_val)
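
Putting the convention together (a rough sketch; `model` is a hypothetical wrapper holding `generator` and `G_kwargs_val`, and `latents`, `latents_edited`, and `skips` come from the encoder):

    # Editing path: capture the F-space features of the original image first,
    # then let the generator transform them toward the edited latents.
    skips['F_feats'] = model.generator(latents, skips, return_f=True, **model.G_kwargs_val)
    edited = model.generator(latents_edited, skips, **model.G_kwargs_val)

    # Reconstruction-only path: disable the feature transform, so the synthesis
    # features stay identical to vanilla StyleGAN2's.
    skips['F_feats'] = None
    recon = model.generator(latents, skips, **model.G_kwargs_val)

With `F_feats` set to `None`, the per-block features should match your vanilla StyleGAN2 copy.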

Do you follow this convention?

jimb2834 commented 6 months ago

@hamzapehlivan - Hello,

OK, great, this makes sense. I will try it.

Thanks!