Open nemoooooooooo opened 2 months ago
I have trying to add validation image logging to train_flux_deepspeed_controlnet.py but Im running into errors related to incorrect tensors shape. Here is the code:
` def validate(dit, controlnet, vae, t5, clip, args, accelerator, step): logger.info("Running validation... ")
controlnet.eval() validation_results = [] for validation_image_path, validation_prompt in zip(args.validation_hint_paths, args.validation_prompts): validation_image = Image.open(validation_image_path).convert("RGB") validation_image = validation_image.resize((args.data_config.img_size, args.data_config.img_size)) # Convert the image to a tensor and rearrange dimensions to match [batch_size, channels, height, width] validation_image_tensor = torch.tensor(np.array(validation_image)).permute(2, 0, 1).unsqueeze(0).to(accelerator.device).to(to> with torch.no_grad(): x_1 = vae.encode(validation_image_tensor) inp = prepare(t5=t5, clip=clip, img=x_1, prompt=[validation_prompt]) x_1 = rearrange(x_1, "b c (h ph) (w pw) -> b (h w) (c ph pw)", ph=2, pw=2) bs = 1 t = torch.sigmoid(torch.randn((bs,), device=accelerator.device)) x_0 = torch.randn_like(x_1).to(accelerator.device) x_t = (1 - t.unsqueeze(1).unsqueeze(2).repeat(1, x_1.shape[1], x_1.shape[2])) * x_1 + t.unsqueeze(1).unsqueeze(2).repeat(> bsz = x_1.shape[0] guidance_vec = torch.full((x_t.shape[0],), 4, device=x_t.device, dtype=x_t.dtype) block_res_samples = controlnet( img=x_t.to(torch.float32), img_ids=inp['img_ids'].to(torch.float32), controlnet_cond=torch.tensor(np.array(validation_image)).unsqueeze(0).to(accelerator.device).to(torch.float32), txt=inp['txt'].to(torch.float32), txt_ids=inp['txt_ids'].to(torch.float32), y=inp['vec'].to(torch.float32), timesteps=t.to(torch.float32), guidance=guidance_vec.to(torch.float32), ) model_pred = dit( img=x_t.to(torch.float32), img_ids=inp['img_ids'].to(torch.float32), txt=inp['txt'].to(torch.float32), txt_ids=inp['txt_ids'].to(torch.float32), block_controlnet_hidden_states=[ sample.to(dtype=torch.float32) for sample in block_res_samples ], y=inp['vec'].to(torch.float32), timesteps=t.to(torch.float32), guidance=guidance_vec.to(torch.float32), ) # Process the output to generate the final image generated_image = model_pred[0].cpu().numpy() validation_results.append((validation_image, generated_image, validation_prompt)) controlnet.train() if accelerator.is_main_process: logger.info(f"Logging validation images for step {step}") for i, (input_image, output_image, prompt) in enumerate(validation_results): input_image.save(os.path.join(args.output_dir, f"validation_input_{i}_{step}.png")) Image.fromarray(output_image).save(os.path.join(args.output_dir, f"validation_output_{i}_{step}.png")) logger.info("Validation complete.")
`
you can use the xflux_pipeline.py to do validation, which is easy to process data.
I have trying to add validation image logging to train_flux_deepspeed_controlnet.py but Im running into errors related to incorrect tensors shape. Here is the code:
` def validate(dit, controlnet, vae, t5, clip, args, accelerator, step): logger.info("Running validation... ")
`