DataCTE / SDXL-Training-Improvements


Training flow disrupted by VAE finetuning steps - needs seamless integration #2

Open DataCTE opened 6 days ago


Description

Currently, the training loop pauses noticeably every 10 steps to run a separate VAE finetuning step. This creates an uneven training cadence and reduces overall efficiency. We should modify the implementation so the transition between UNet and VAE training steps is seamless.

Current Behavior

if args.finetune_vae and step % args.vae_train_freq == 0:
    vae_losses = vae_finetuner.training_step(
        batch["latents"].to("cuda"),
        batch["original_images"].to("cuda")
    )

Proposed Solution

  1. Integrate VAE training into the main training loop (a hedged, runnable sketch follows the pseudocode below):

    ... existing code ...
    with torch.amp.autocast('cuda', dtype=dtype):
        # Process UNet step
        unet_loss = training_loss(...)
        # Process VAE step in the same autocast context if needed
        if args.finetune_vae:
            vae_loss = vae_finetuner.training_step(...)
        # Combine losses if both are present
        total_loss = unet_loss + (vae_loss if args.finetune_vae else 0)
    # Single backward pass
    total_loss.backward()
    ... existing code ...
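
A more complete sketch of what the fused step could look like is below. It is only a proposal under assumptions: it assumes `training_loss(...)` returns the UNet loss as a scalar tensor, that `vae_finetuner.training_step(...)` returns a scalar VAE loss rather than performing its own backward pass, and the `vae_loss_weight` knob is a hypothetical addition for balancing the two terms (see the Questions section).

    import torch

    def combined_training_step(batch, training_loss, vae_finetuner, optimizer, args,
                               dtype=torch.float16):
        """One fused optimizer step covering the UNet and (optionally) the VAE."""
        optimizer.zero_grad(set_to_none=True)

        with torch.amp.autocast('cuda', dtype=dtype):
            # UNet denoising loss (stands in for the existing loss computation).
            unet_loss = training_loss(batch)

            # VAE reconstruction loss, computed in the same autocast context
            # instead of as a separate, blocking step every N iterations.
            vae_loss = torch.zeros((), device="cuda")
            if args.finetune_vae:
                vae_loss = vae_finetuner.training_step(
                    batch["latents"].to("cuda"),
                    batch["original_images"].to("cuda"),
                )

            # Hypothetical weight so the VAE term does not swamp the UNet objective.
            total_loss = unet_loss + getattr(args, "vae_loss_weight", 1.0) * vae_loss

        # Single backward pass through both models, then one optimizer step.
        total_loss.backward()
        optimizer.step()
        return unet_loss.detach(), vae_loss.detach()

If `vae_finetuner.training_step` currently calls `backward()` internally, it would need to be refactored to return the loss instead, so that a single `total_loss.backward()` covers both models.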

Expected Benefits

  1. Smoother training progression
  2. Better memory utilization
  3. Potentially faster overall training
  4. More consistent logging and monitoring

Implementation Checklist

Questions to Consider

  1. Should the VAE and UNet share an optimizer/scheduler?
  2. How should the VAE and UNet loss contributions be balanced?
  3. What is the best way to handle different learning rates? (One possible arrangement for 1-3 is sketched below.)
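
As a starting point for discussion, here is a minimal sketch of one possible answer: a single optimizer with separate parameter groups, so the UNet and VAE keep independent learning rates while sharing one scheduler, plus a weight on the VAE term to balance the two losses. The model stand-ins and all numeric values are illustrative assumptions, not recommendations from this repo.

    import torch
    import torch.nn as nn

    # Stand-ins for the real models; only their parameters matter for this sketch.
    unet = nn.Linear(4, 4)
    vae = nn.Linear(4, 4)
    max_steps = 1000

    # One optimizer, two parameter groups: independent learning rates, shared scheduler.
    optimizer = torch.optim.AdamW([
        {"params": unet.parameters(), "lr": 1e-5},   # UNet group
        {"params": vae.parameters(), "lr": 1e-6},    # VAE group, typically smaller
    ])
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=max_steps)

    # Balancing the loss contributions: scale the VAE term before the shared backward pass,
    # e.g. total_loss = unet_loss + vae_loss_weight * vae_loss
    vae_loss_weight = 0.1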

Notes

This improvement would make the training loop smoother and more efficient, matching the seamless UNet/VAE integration seen in other Stable Diffusion training implementations.