kuprel / min-dalle

min(DALL·E) is a fast, minimal port of DALL·E Mini to PyTorch
MIT License

Cannot register hook? #99

Closed: AntonotnaWang closed this issue 1 year ago

AntonotnaWang commented 1 year ago

Dear kuprel,

Thanks for your nice work. I would like to get the forward feature maps and backward gradients of intermediate hidden layers. However, I cannot register hooks by adding a hook function to the source code, and I do not know why.

I added the following methods to the original MinDalle class:

    # Register forward/backward hooks on the selected sub-models
    # for the layers named in target_hook_layers.
    def hook(self,
        encoder_need_hook = False,
        decoder_need_hook = False,
        detokenizer_need_hook = False,
        target_hook_layers = []):

        self.encoder_need_hook = encoder_need_hook
        self.decoder_need_hook = decoder_need_hook
        self.detokenizer_need_hook = detokenizer_need_hook
        self.target_hook_layers = target_hook_layers
        self.fmap_pool = dict()
        self.grad_pool = dict()
        self.handlers = []

        if self.encoder_need_hook:
            self.hook_func(self.encoder, self.target_hook_layers)

        if self.decoder_need_hook:
            self.hook_func(self.decoder, self.target_hook_layers)

        if self.detokenizer_need_hook:
            self.hook_func(self.detokenizer, self.target_hook_layers)

    def remove_hook(self):
        for handler in self.handlers:
            handler.remove()
        self.fmap_pool = dict()
        self.grad_pool = dict()
        self.handlers = []

    def hook_func(self, cur_model, target_layers):
        def forward_hook(key):
            def forward_hook_(module, input_im, output_im):
                print(str(module), "forward_hook_", input_im[0].shape, output_im.shape)
                # keep the layer's output feature map
                self.fmap_pool[key] = output_im.detach().clone().cpu().numpy()
            return forward_hook_

        def backward_hook(key):
            def backward_hook_(module, grad_in, grad_out):
                print(str(module), "backward_hook_", grad_in[0].shape, grad_out[0].shape)
                # keep the gradient w.r.t. the layer's output
                self.grad_pool[key] = grad_out[0].detach().clone().cpu().numpy()
            return backward_hook_

        for name, module in cur_model.named_modules():
            if name in target_layers:
                print("register hook for "+str(name)+" in "+str(type(cur_model).__name__))
                self.handlers.append(module.register_forward_hook(forward_hook(name)))
                self.handlers.append(module.register_backward_hook(backward_hook(name)))

Then, I tested it with the following code:

dtype = "float32" #@param ["float32", "float16", "bfloat16"]

device = "cuda:0"

model = MinDalle(
    models_root = "../Learn/min-dalle/pretrained",
    dtype=getattr(torch, dtype),
    device=device,
    is_mega=True, 
    is_reusable=True,
)

target_layer_names = ["decoder.up.3"]

for name, module in model.detokenizer.named_modules():
    if name in target_layer_names:
        print(name)

model.hook(
    encoder_need_hook = False,
    decoder_need_hook = False,
    detokenizer_need_hook = True,
    target_hook_layers = target_layer_names)

output = model.detokenizer(False, (torch.rand(1, 256)*1000).long().to(device))

I hoped to add a hook to the detokenizer, but it does not work: the registered hook functions are never called. Sorry to bother you. May I know how to fix the problem? Thanks a lot.

Best

AntonotnaWang commented 1 year ago

Hi, I have solved the problem by simply changing all the direct .forward(x) calls to (x) in the model code.
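
For context, that change works because PyTorch only dispatches registered hooks when a module is invoked through nn.Module.__call__ (i.e. module(x)); calling module.forward(x) directly bypasses them. A minimal, standalone sketch of that behavior (the Linear layer and shapes are arbitrary, not taken from min-dalle):

import torch
import torch.nn as nn

layer = nn.Linear(4, 4)
layer.register_forward_hook(
    lambda module, inp, out: print("hook fired:", tuple(out.shape))
)

x = torch.randn(1, 4)
layer(x)          # goes through nn.Module.__call__ -> hook fires
layer.forward(x)  # bypasses __call__               -> hook stays silent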