## Description

Currently, the training process pauses noticeably every 10 steps when performing VAE finetuning. This creates an uneven training flow and reduces overall efficiency. We should modify the implementation to make the transition between UNet and VAE training steps seamless.
## Current Behavior
```python
if args.finetune_vae and step % args.vae_train_freq == 0:
    vae_losses = vae_finetuner.training_step(
        batch["latents"].to("cuda"),
        batch["original_images"].to("cuda")
    )
```
- Training pauses every 10 steps
- Clear interruption in training flow
- Separate VAE processing creates a noticeable delay
## Proposed Solution

Integrate VAE training into the main training loop:
```python
# ... existing code ...
with torch.amp.autocast('cuda', dtype=dtype):
    # Process the UNet step
    unet_loss = training_loss(...)

    # Process the VAE step in the same iteration if enabled
    if args.finetune_vae:
        vae_loss = vae_finetuner.training_step(...)

# Combine losses when both are present
total_loss = unet_loss + (vae_loss if args.finetune_vae else 0)

# Single backward pass
total_loss.backward()
# ... existing code ...
```
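For the single backward pass above to work, `vae_finetuner.training_step` has to return a differentiable loss tensor rather than running its own backward/optimizer step. A minimal sketch of what that could look like, assuming a diffusers-style `AutoencoderKL` and a plain MSE reconstruction loss (the class shape and attribute names are illustrative, not the current implementation):

```python
import torch.nn.functional as F

class VAEFinetuner:
    """Hypothetical finetuning wrapper around a diffusers-style AutoencoderKL."""

    def __init__(self, vae, scaling_factor=0.18215):
        self.vae = vae
        # Assumed latent scaling factor; in practice vae.config.scaling_factor.
        self.scaling_factor = scaling_factor

    def training_step(self, latents, original_images):
        # Decode latents back to pixel space and compare against the originals.
        # No backward() or optimizer.step() here: the caller combines this loss
        # with the UNet loss and backpropagates once.
        decoded = self.vae.decode(latents / self.scaling_factor).sample
        return F.mse_loss(decoded.float(), original_images.float())
```

With this shape, the returned loss stays on the autograd graph, so the single `total_loss.backward()` in the loop above covers both models.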
## Expected Benefits
- Smoother training progression
- Better memory utilization
- Potentially faster overall training
- More consistent logging and monitoring
## Implementation Checklist
- [ ] Modify the main training loop to include VAE steps
- [ ] Update gradient accumulation to handle both models (see the sketch after this list)
- [ ] Ensure proper memory management for combined training
- [ ] Update the progress bar to show both UNet and VAE metrics continuously
- [ ] Modify logging to track both processes seamlessly
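For the gradient-accumulation item, one possible shape for the combined update is shown below. It assumes separate optimizers for the UNet and the VAE and that `args.gradient_accumulation_steps` and `args.vae_loss_weight` exist; all names are illustrative, not the existing code:

```python
# Inside the training loop: unet, vae, their optimizers, step, unet_loss and
# vae_loss are assumed to be defined as in the snippet above.
loss = unet_loss
if args.finetune_vae:
    # Hypothetical weight to balance the VAE contribution against the UNet loss.
    loss = loss + args.vae_loss_weight * vae_loss

# Scale by the accumulation window so gradients average over it.
(loss / args.gradient_accumulation_steps).backward()

if (step + 1) % args.gradient_accumulation_steps == 0:
    torch.nn.utils.clip_grad_norm_(unet.parameters(), max_norm=1.0)
    unet_optimizer.step()
    unet_optimizer.zero_grad(set_to_none=True)

    if args.finetune_vae:
        torch.nn.utils.clip_grad_norm_(vae.parameters(), max_norm=1.0)
        vae_optimizer.step()
        vae_optimizer.zero_grad(set_to_none=True)
```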
## Questions to Consider
- Should VAE and UNet share an optimizer/scheduler? (one option is sketched after this list)
- How to balance VAE vs UNet loss contributions?
- Best way to handle different learning rates?
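One option for the optimizer/scheduler and learning-rate questions is a single optimizer with separate parameter groups, so the schedule stays shared while each model keeps its own learning rate. The learning rates, scheduler choice, and `args.max_train_steps` below are placeholders, not values from this codebase:

```python
import torch

# Separate parameter groups: a per-group "lr" overrides the optimizer default.
optimizer = torch.optim.AdamW(
    [
        {"params": unet.parameters(), "lr": 1e-5},  # hypothetical UNet LR
        {"params": vae.parameters(), "lr": 1e-6},   # hypothetical VAE LR
    ],
    weight_decay=1e-2,
)

# One scheduler then scales both groups' learning rates in lockstep.
lr_scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer, T_max=args.max_train_steps
)
```

The alternative is fully separate optimizers (as in the accumulation sketch above), which keeps the two models easy to step or freeze independently but means maintaining two schedulers.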
## Notes
This improvement would make the training process smoother and more efficient, matching the seamless integration seen in other Stable Diffusion training implementations.