crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

Blocking ops in several samplers, usually from sigmas being a GPU tensor #108

Open drhead opened 2 months ago

drhead commented 2 months ago

There are blocking operations in several of the samplers. Looking at DPM-Solver++(2M), for example:

@torch.no_grad()
def sample_dpmpp_2m(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M)."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigmas[i + 1] == 0:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x

sigmas is generally going to be a GPU tensor. If it is, then on the line if old_denoised is None or sigmas[i + 1] == 0: the value has to be synced to the CPU as part of control flow, which blocks PyTorch dispatch until every operation queued so far has completed. As a result there is a significant gap between steps where the GPU sits idle. The actual impact varies with hardware, but keeping the dispatch queue unblocked is very beneficial: PyTorch can usually line up several steps of inference in advance, and the GPU then executes them completely uninterrupted.
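A minimal illustration of the mechanism (assuming a CUDA device is available; the schedule values are just placeholders):

import torch

sigmas = torch.linspace(14.6, 0.0, 10, device='cuda')  # placeholder schedule
i = 3

# The comparison itself is an ordinary asynchronous GPU op that yields a
# one-element CUDA bool tensor.
cond = sigmas[i + 1] == 0

# Using that tensor in Python control flow calls Tensor.__bool__, which copies
# the value back to the host and therefore has to wait for every kernel queued
# so far to finish before dispatch can continue.
if cond:
    pass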

For this sampler, you could avoid the sync by putting sigmas on the CPU, or by making a copy of it that is used specifically for control flow. But putting sigmas on the CPU outright breaks other samplers (Heun and Euler at least), because they use sigmas in a way that actually does need a GPU tensor.
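A rough sketch of the copy-for-control-flow option (the name sigmas_cpu is mine; the one-time .to('cpu') does block, but only once, before the loop starts):

import torch

sigmas = torch.linspace(14.6, 0.0, 10, device='cuda')  # placeholder schedule

# Hypothetical per-sampler workaround: sigmas stays on the GPU for the model
# call and the update math, while the branch condition is read from a CPU copy.
sigmas_cpu = sigmas.to('cpu')  # single blocking copy, up front

for i in range(len(sigmas) - 1):
    # ... model call and x update keep using sigmas[i] on the GPU as before ...
    if sigmas_cpu[i + 1] == 0:  # host-side check, no stream synchronization
        pass  # final-step branch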

I think most samplers could have all of the values used for control flow precalculated before the for loop, and that would solve the problem. I also believe that tensors used as scalars would in general be better kept on the CPU where possible, since that usually results in fewer kernel launches than doing the same operation off of a GPU tensor.
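For the sampler above, the only quantity Python control flow actually depends on is whether sigmas[i + 1] is zero, so it can be resolved once up front. A sketch of what that could look like (untested; the function name and the sigma_is_zero list are mine, not part of the library):

import torch
from tqdm.auto import trange

@torch.no_grad()
def sample_dpmpp_2m_precomputed(model, x, sigmas, extra_args=None, callback=None, disable=None):
    """DPM-Solver++(2M) with control-flow values resolved before the loop."""
    extra_args = {} if extra_args is None else extra_args
    s_in = x.new_ones([x.shape[0]])
    sigma_fn = lambda t: t.neg().exp()
    t_fn = lambda sigma: sigma.log().neg()
    old_denoised = None
    # One blocking device-to-host transfer here, instead of an implicit sync
    # on every iteration of the loop below.
    sigma_is_zero = (sigmas == 0).tolist()

    for i in trange(len(sigmas) - 1, disable=disable):
        denoised = model(x, sigmas[i] * s_in, **extra_args)
        if callback is not None:
            callback({'x': x, 'i': i, 'sigma': sigmas[i], 'sigma_hat': sigmas[i], 'denoised': denoised})
        t, t_next = t_fn(sigmas[i]), t_fn(sigmas[i + 1])
        h = t_next - t
        if old_denoised is None or sigma_is_zero[i + 1]:
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised
        else:
            h_last = t - t_fn(sigmas[i - 1])
            r = h_last / h
            denoised_d = (1 + 1 / (2 * r)) * denoised - (1 / (2 * r)) * old_denoised
            x = (sigma_fn(t_next) / sigma_fn(t)) * x - (-h).expm1() * denoised_d
        old_denoised = denoised
    return x

With this change the per-step branch reads a plain Python bool, so PyTorch dispatch is never forced to wait on the CUDA stream mid-loop.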