hatchetProject / QuEST

QuEST: Efficient Finetuning for Low-bit Diffusion Models

Discrepancy in the order of conditioning #12

Closed adilhasan927 closed 2 weeks ago

adilhasan927 commented 2 weeks ago

Hello,

In utils.py:

def get_train_samples(args, sample_data, custom_steps=None):
    num_samples, num_st = int(args.cali_n), args.cali_st
...
    if args.cond:
        # Duplicate the latents and timesteps so each calibration sample
        # appears twice: once paired with a conditional embedding (cs) and
        # once with an unconditional embedding (ucs).
        xs_lst += xs_lst
        ts_lst += ts_lst
        conds_lst = [sample_data["cs"][i][:num_samples] for i in timesteps] \
                  + [sample_data["ucs"][i][:num_samples] for i in timesteps]
    xs = torch.cat(xs_lst, dim=0)
    ts = torch.cat(ts_lst, dim=0)
    if args.cond:
        conds = torch.cat(conds_lst, dim=0)
        return xs, ts, conds
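
The resulting layout along dim 0 is therefore (a sketch of the ordering only, with one block per calibration timestep):

    # xs    = [x_t1, ..., x_tk, x_t1, ..., x_tk ]   # duplicated
    # ts    = [t1,   ..., tk,   t1,   ..., tk   ]   # duplicated
    # conds = [c_t1, ..., c_tk, uc_t1, ..., uc_tk]  # cs first, then ucs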

In ddim.py:

x_in = torch.cat([x] * 2)
t_in = torch.cat([t] * 2)
if isinstance(c, dict):
    assert isinstance(unconditional_conditioning, dict)
    c_in = dict()
    for k in c:
        if isinstance(c[k], list):
            c_in[k] = [
                torch.cat([unconditional_conditioning[k][i], c[k][i]])
                for i in range(len(c[k]))
            ]
        else:
            c_in[k] = torch.cat([unconditional_conditioning[k], c[k]])
else:
    # here the unconditional embeddings come first in the batch
    c_in = torch.cat([unconditional_conditioning, c])
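
For context, this concatenated batch is later split back apart, as in the stock ldm DDIM sampler (a sketch of the standard classifier-free guidance pattern; the exact lines in this repo may differ):

    e_t_uncond, e_t = self.model.apply_model(x_in, t_in, c_in).chunk(2)
    e_t = e_t_uncond + unconditional_guidance_scale * (e_t - e_t_uncond)

so the unconditional half is expected to sit in the first half of the batch.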

Why is it that in utils.py, we put the cs first, but in ddim.py, we put the ucs first?

Thank you for your time

hatchetProject commented 2 weeks ago

Hi, I don't think the two scripts above are related. The second script is for sampling from the quantized/FP model. The first script is for collecting calibration data for quantization and does not go through the sampling process, so the orders do not need to match.
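
Roughly, the calibration data is only used for forward passes to collect statistics, so the two halves are never re-paired with chunk(2). A minimal hypothetical sketch (run_calibration and the model signature are illustrative, not the actual QuEST code):

    import torch

    def run_calibration(model, xs, ts, conds, batch_size=32):
        # Each (x, t, cond) triple stays aligned by index; batches are just
        # forward passes used to collect activation statistics, so whether
        # the cs or the ucs come first along dim 0 makes no difference.
        for i in range(0, xs.size(0), batch_size):
            _ = model(xs[i:i + batch_size], ts[i:i + batch_size],
                      conds[i:i + batch_size])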

adilhasan927 commented 2 weeks ago

Ah, thank you, I understand now.

adilhasan927 commented 2 weeks ago

closed