gauenk opened 5 months ago
I have the same issue. The `checkpoint` function does not support several backward passes with `retain_graph=True`, and I cannot disable checkpointing because the `Attention` module does not use its `use_checkpoint` flag.
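To illustrate that second point, the attention block calls the checkpoint wrapper with a hardcoded flag instead of its own `use_checkpoint` attribute. The pattern looks roughly like the block below; this is an illustrative stand-in, not the actual `unet.py` source:

```python
import torch.nn as nn
from guided_diffusion import unet

class HardcodedFlagBlock(nn.Module):
    """Illustrative stand-in for the attention block: the flag argument
    passed to checkpoint() is a literal True, not self.use_checkpoint."""

    def __init__(self, use_checkpoint=False):
        super().__init__()
        self.use_checkpoint = use_checkpoint  # stored but never consulted
        self.proj = nn.Linear(8, 8)

    def forward(self, x):
        # Hardcoded True: setting use_checkpoint=False has no effect here.
        return unet.checkpoint(self._forward, (x,), self.parameters(), True)

    def _forward(self, x):
        return self.proj(x)
```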
My current solution is to monkey patch `checkpoint`:
```python
from guided_diffusion import unet

def monkey_checkpoint(func, inputs, params, flag):
    # Plain forward pass; ignore `params`/`flag` so checkpointing never runs.
    return func(*inputs)

# Rebind the name that unet.py resolves at call time.
unet.checkpoint = monkey_checkpoint
```
I want to use the gradient in a loss function, so the backward pass has to run through the graph twice; that is why I cannot use checkpointing here. The patch above simply turns checkpointing off, which is what setting `use_checkpoint=False` should have done in the first place.
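For reference, this is the kind of double-backward pattern I mean; a minimal sketch where a small `nn.Sequential` stands in for the UNet (names and shapes are placeholders, not guided_diffusion code):

```python
import torch
import torch.nn as nn

# Stand-in for the diffusion UNet; any differentiable module works for the sketch.
model = nn.Sequential(
    nn.Conv2d(3, 8, 3, padding=1),
    nn.SiLU(),
    nn.Conv2d(8, 3, 3, padding=1),
)

x = torch.randn(2, 3, 16, 16, requires_grad=True)
out = model(x)

# First pass: gradient of the output w.r.t. the input, with create_graph=True
# so the gradient itself stays differentiable.
grad_x, = torch.autograd.grad(out.sum(), x, create_graph=True)

# The gradient enters the loss, so loss.backward() traverses the graph a
# second time. That second pass is what fails under guided_diffusion's
# checkpoint wrapper; with the monkey patch (plain forward) it goes through.
loss = out.mean() + grad_x.pow(2).mean()
loss.backward()
```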