Open zhou431496 opened 2 years ago
class CLIPVisualEncoder(nn.Module): def init(self, clip_model): super().init() self.clip_model = clip_model self.featuremaps = None
for i in range(12): # 12 resblocks in VIT visual transformer self.clip_model.visual.transformer.resblocks[i].register_forward_hook( self.make_hook(i)) def make_hook(self, name): def hook(module, input, output): if len(output.shape) == 3: self.featuremaps[name] = output.permute( 1, 0, 2) # LND -> NLD bs, smth, 768 else: self.featuremaps[name] = output return hook
这个函数无法获得特征图
class CLIPVisualEncoder(nn.Module): def init(self, clip_model): super().init() self.clip_model = clip_model self.featuremaps = None
这个函数无法获得特征图