crowsonkb / k-diffusion

Karras et al. (2022) diffusion models for PyTorch
MIT License

Conditional *image* generation (img2img) #70

Open CallShaul opened 1 year ago

CallShaul commented 1 year ago

Hi,

To add support for conditional image generation, in addition to embedding the initial image into unet_cond (extra_args['unet_cond'] = img_cond), what should I put in extra_args['cross_cond'] and extra_args['cross_cond_padding']?

(before the loss calculation in the line: losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args))
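For context, my current (untested) guess at what they expect, based on the cross-attention arguments in the image_v1 denoiser's forward signature; the shapes below are my assumption, not something I've confirmed:

    batch_size, seq_len, cross_cond_dim = 16, 77, 768  # hypothetical sizes

    # guess: a [batch, seq, dim] sequence of conditioning embeddings for the cross-attention blocks
    extra_args['cross_cond'] = torch.zeros([batch_size, seq_len, cross_cond_dim], device=device)
    # guess: a [batch, seq] mask with 1 where a position is padding and should be ignored
    extra_args['cross_cond_padding'] = torch.zeros([batch_size, seq_len], device=device)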

@crowsonkb @nekoshadow1 @brycedrennan

Thanks!

drscotthawley commented 4 months ago

I'd be interested in seeing an answer to this as well, e.g. for the simple case of MNIST: how might we implement (or activate) class-conditional generation?

I see class_cond in the code, and a cond_dropout_rate in the config files, so maybe it's already training that way... but in the output from demo(), the classes seem to just be random. Perhaps we just need to change line 369 in train.py from this...

            class_cond = torch.randint(0, num_classes, [accelerator.num_processes, n_per_proc], generator=demo_gen).to(device)

To something more "intentional", such as this?

            # cycle deterministically through the classes instead of sampling them at random;
            # arange must cover all accelerator.num_processes * n_per_proc elements for the reshape
            class_cond = torch.remainder(torch.arange(0, accelerator.num_processes * n_per_proc), num_classes).reshape([accelerator.num_processes, n_per_proc]).int().to(device)
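As a quick sanity check of the label pattern that line produces (with hypothetical sizes: 2 processes, 5 images per process, 10 classes):

    import torch

    # each demo-grid row gets consecutive class labels instead of random ones
    labels = torch.remainder(torch.arange(0, 2 * 5), 10).reshape([2, 5]).int()
    # tensor([[0, 1, 2, 3, 4],
    #         [5, 6, 7, 8, 9]], dtype=torch.int32)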

Update: yep! That worked! :-)

[attached image: demo grid of class-conditional samples]

CallShaul commented 2 months ago

Solution:

I've made it work. Here are the main steps (some additional workarounds are needed to make it run, in inference as well, but this is the main idea):

  1. get the conditioning image in each training batch iteration:

    unet_cond = get_condition_channels(model_config, img_cond)  # user-defined helper
    extra_args['unet_cond'] = unet_cond.to(device)
  2. modify the loss calculation line so the image condition is passed through via **extra_args (see the sketch after this list):

    losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)
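For context, here is a minimal sketch of how the two steps fit into the training loop. The batch unpacking, the 'image' / 'cond_image' keys, and get_condition_channels are specific to my setup and will differ per dataset; sample_density and aug_cond follow the stock train.py:

    for batch in train_dl:
        # unpacking is specific to my dataset; 'image' / 'cond_image' are hypothetical keys
        reals = batch['image'].to(device)
        img_cond = batch['cond_image'].to(device)
        noise = torch.randn_like(reals)
        sigma = sample_density([reals.shape[0]], device=device)

        extra_args = {}
        # step 1: build the conditioning channels and hand them to the model
        unet_cond = get_condition_channels(model_config, img_cond)  # user-defined helper
        extra_args['unet_cond'] = unet_cond.to(device)

        # step 2: the condition reaches the model's forward pass through **extra_args
        losses = model.loss(reals, noise, sigma, aug_cond=aug_cond, **extra_args)
        loss = losses.mean()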