XLabs-AI / x-flux

Apache License 2.0
1.63k stars 118 forks source link

Validation logs for controlnet training #89

Open nemoooooooooo opened 2 months ago

nemoooooooooo commented 2 months ago

I have been trying to add validation image logging to train_flux_deepspeed_controlnet.py, but I'm running into errors related to incorrect tensor shapes. Here is the code:

` def validate(dit, controlnet, vae, t5, clip, args, accelerator, step): logger.info("Running validation... ")

# Purpose (as far as the code below shows): for each (hint image path, prompt) pair
# zipped from args.validation_hint_paths and args.validation_prompts, run one
# controlnet + dit forward pass under torch.no_grad(), collect the result, and on
# the main process save the input and output as PNG files in args.output_dir.
# NOTE(review): indentation was lost in this paste; the statements below presumably
# form the body of validate() -- confirm against the author's actual file.

# Put controlnet in eval mode for validation; it is switched back with
# controlnet.train() after the loop.
controlnet.eval()

validation_results = []
for validation_image_path, validation_prompt in zip(args.validation_hint_paths, args.validation_prompts):
    validation_image = Image.open(validation_image_path).convert("RGB")
    validation_image = validation_image.resize((args.data_config.img_size, args.data_config.img_size))

    # Convert the image to a tensor and rearrange dimensions to match [batch_size, channels, height, width]
    # NOTE(review): the next line is cut off in the paste (it ends with '>'), so the
    # final cast cannot be read from here. No rescaling of pixel values is visible in
    # the part that is shown -- TODO confirm the value range vae.encode expects.
    validation_image_tensor = torch.tensor(np.array(validation_image)).permute(2, 0, 1).unsqueeze(0).to(accelerator.device).to(to>

    with torch.no_grad():
        x_1 = vae.encode(validation_image_tensor)
        # prepare() receives the un-packed latent; the packing rearrange happens after it.
        inp = prepare(t5=t5, clip=clip, img=x_1, prompt=[validation_prompt])

        # Fold every 2x2 spatial patch into the last axis and flatten (h, w) into a
        # single sequence axis: b c (h*2) (w*2) -> b (h*w) (c*4).
        x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2)

        bs = 1
        # One timestep per batch element: sigmoid of a standard-normal draw, so t is in (0, 1).
        t = torch.sigmoid(torch.randn((bs,), device=accelerator.device))

        # x_0 is Gaussian noise with the same shape as the packed x_1.
        x_0 = torch.randn_like(x_1).to(accelerator.device)
        # x_t starts as (1 - t) * x_1, with t expanded to x_1's last two dims, plus a
        # second term scaled by t.
        # NOTE(review): this line is cut off in the paste (it ends with '>'); the second
        # operand is not visible -- presumably x_0, confirm against the author's file.
        x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(>
        # NOTE(review): bsz is not referenced anywhere else in the visible code.
        bsz = x_1.shape[0]
        # Constant guidance value of 4 for every batch element.
        guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype)

        block_res_samples = controlnet(
            img=x_t.to(torch.float32),
            img_ids=inp['img_ids'].to(torch.float32),
            # NOTE(review): unlike validation_image_tensor above, this array is NOT
            # permuted with (2, 0, 1), so it is laid out as (1, H, W, C) rather than
            # channels-first. This may be the source of the reported shape error --
            # confirm against the layout controlnet expects for controlnet_cond.
            controlnet_cond=torch.tensor(np.array(validation_image)).unsqueeze(0).to(accelerator.device).to(torch.float32),
            txt=inp['txt'].to(torch.float32),
            txt_ids=inp['txt_ids'].to(torch.float32),
            y=inp['vec'].to(torch.float32),
            timesteps=t.to(torch.float32),
            guidance=guidance_vec.to(torch.float32),
        )

        # The controlnet outputs are passed to dit as block_controlnet_hidden_states.
        model_pred = dit(
            img=x_t.to(torch.float32),
            img_ids=inp['img_ids'].to(torch.float32),
            txt=inp['txt'].to(torch.float32),
            txt_ids=inp['txt_ids'].to(torch.float32),
            block_controlnet_hidden_states=[
                sample.to(dtype=torch.float32) for sample in block_res_samples
            ],
            y=inp['vec'].to(torch.float32),
            timesteps=t.to(torch.float32),
            guidance=guidance_vec.to(torch.float32),
        )

        # Process the output to generate the final image
        # NOTE(review): model_pred[0] is the raw dit output for the first batch element,
        # converted to NumPy with no un-packing or VAE decode visible here, and it is
        # later passed straight to Image.fromarray. Confirm it is an image-shaped array
        # of a dtype PIL accepts; otherwise the save step will likely fail or be meaningless.
        generated_image = model_pred[0].cpu().numpy()
        validation_results.append((validation_image, generated_image, validation_prompt))

# Restore training mode after validation.
controlnet.train()

# Only the main process writes the validation images to disk.
if accelerator.is_main_process:
    logger.info(f"Logging validation images for step {step}")

    # NOTE(review): prompt is unpacked but not used in this loop.
    for i, (input_image, output_image, prompt) in enumerate(validation_results):
        input_image.save(os.path.join(args.output_dir, f"validation_input_{i}_{step}.png"))
        Image.fromarray(output_image).save(os.path.join(args.output_dir, f"validation_output_{i}_{step}.png"))

    logger.info("Validation complete.")

`

DidiD1 commented 2 months ago

You can use xflux_pipeline.py for validation, which makes it easy to process the data.