yael-vinker / CLIPasso

Other
850 stars 91 forks source link

代码问题 #14

Open zhou431496 opened 2 years ago

zhou431496 commented 2 years ago

class CLIPVisualEncoder(nn.Module):
    """Wrap a CLIP model and capture the output of every visual resblock.

    A forward hook is registered on each residual block of
    ``clip_model.visual.transformer``. Whenever a block runs, its output is
    stored in ``self.featuremaps`` under the block's integer index, so the
    dict is refreshed (overwritten key by key) on every pass.

    Args:
        clip_model: CLIP model exposing ``visual.transformer.resblocks``
            (an iterable of modules supporting ``register_forward_hook``).
    """

    def __init__(self, clip_model):
        # This must be ``__init__``: a method merely named ``init`` is never
        # invoked on construction, so no hooks would be registered.
        super().__init__()
        self.clip_model = clip_model
        # Start with an empty dict rather than ``None``: the hooks assign into
        # it by key, which would raise ``TypeError`` on ``None``.
        self.featuremaps = {}
        # One hook per resblock that actually exists, instead of assuming
        # exactly 12. Keep the handles so callers can ``.remove()`` the hooks,
        # which otherwise stay attached to the shared ``clip_model``.
        self.hook_handles = [
            block.register_forward_hook(self.make_hook(i))
            for i, block in enumerate(self.clip_model.visual.transformer.resblocks)
        ]

    def make_hook(self, name):
        """Return a forward hook that stores a module's output under ``name``.

        Args:
            name: Key under which the output is stored in ``self.featuremaps``
                (the resblock index when called from ``__init__``).

        Returns:
            A callable with the ``(module, input, output)`` signature expected
            by ``nn.Module.register_forward_hook``.
        """

        def hook(module, input, output):
            if len(output.shape) == 3:
                # Assumes a sequence-first (L, N, D) layout, per the original
                # "LND -> NLD" note; permute to batch-first (N, L, D).
                # TODO confirm for the CLIP variant in use.
                self.featuremaps[name] = output.permute(1, 0, 2)
            else:
                self.featuremaps[name] = output

        return hook

这个函数（通过 make_hook 注册的钩子）无法获得特征图。（The hooks registered via make_hook do not capture the feature maps.）