openai / guided-diffusion


Updated "checkpoint" to use "self.use_checkpoint" flag #144

Open gauenk opened 5 months ago

gauenk commented 5 months ago

I want to use the gradient in a loss function, so I cannot use "checkpoint" since I need to run the backward pass twice. This code change simply makes the "use_checkpoint" flag turn checkpointing off when requested.
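For context, AttentionBlock.forward in guided_diffusion/unet.py currently hard-codes the flag passed to checkpoint, so the change amounts to roughly the following (a sketch of the intent, not the exact diff in this PR):

# Sketch, inside guided_diffusion/unet.py where nn and checkpoint are already imported.
# guided_diffusion's checkpoint(func, inputs, params, flag) calls func(*inputs)
# directly when flag is False, so passing self.use_checkpoint through disables
# checkpointing for blocks constructed with use_checkpoint=False.
class AttentionBlock(nn.Module):
    ...
    def forward(self, x):
        # before: return checkpoint(self._forward, (x,), self.parameters(), True)
        return checkpoint(self._forward, (x,), self.parameters(), self.use_checkpoint)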

francois-rozet commented 1 month ago

I have the same issue. The checkpoint function does not support multiple backward passes with retain_graph=True, and I cannot disable checkpointing because the Attention module does not use its use_checkpoint flag.

My current solution is to monkey patch checkpoint:

from guided_diffusion import unet

def monkey_checkpoint(func, inputs, params, flag):
    # Same signature as the original checkpoint, but never checkpoints:
    # the full graph is kept, so backward can be called more than once
    # (e.g. with retain_graph=True).
    return func(*inputs)

unet.checkpoint = monkey_checkpoint
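If I read unet.py correctly, this takes effect even for a model constructed before the patch is applied, because AttentionBlock.forward resolves the name checkpoint from the unet module's globals at call time; the patch only needs to be in place before the forward pass whose graph you want to backpropagate through more than once.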