DataCTE / SDXL-Training-Improvements

Apache License 2.0

Validation Image Decoding Fails #3

Open DataCTE opened 6 days ago

DataCTE commented 6 days ago

Description

The validation image generation is currently producing only random noise patterns (see attached example image) instead of proper decoded images. This appears to be a systematic failure in the VAE decoding pipeline, particularly with bfloat16 handling.

[attached example image: validation output showing only noise]

Current Implementation

def prepare_image(self, img):
    with torch.cuda.amp.autocast():
        # Decode 4-channel latents to an RGB image with the bf16 VAE
        if img.shape[1] == 4:
            img = self.default_vae.decode(img / 0.18215).sample
        img = img.float()
    return img
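For reference, a minimal repro sketch (not from the repo) that decodes the same random latent with the bfloat16 VAE under autocast and with a float32 copy, then prints the output statistics. The latent shape is illustrative, and the 0.18215 divisor simply mirrors the code above:

import torch
from diffusers import AutoencoderKL

device = "cuda"
latents = torch.randn(1, 4, 128, 128, device=device)  # illustrative SDXL-sized latent

vae_bf16 = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae",
    torch_dtype=torch.bfloat16,
).to(device)
vae_fp32 = AutoencoderKL.from_pretrained(
    "stabilityai/stable-diffusion-xl-base-1.0", subfolder="vae",
).to(device)

with torch.no_grad():
    with torch.cuda.amp.autocast(dtype=torch.bfloat16):
        out_bf16 = vae_bf16.decode(latents.to(torch.bfloat16) / 0.18215).sample
    out_fp32 = vae_fp32.decode(latents / 0.18215).sample

for name, out in [("bf16", out_bf16.float()), ("fp32", out_fp32)]:
    print(f"{name}: range=({out.min():.3f}, {out.max():.3f}), mean={out.mean():.3f}")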

Root Cause Analysis

  1. VAE Initialization:

    self.default_vae = AutoencoderKL.from_pretrained(
        "stabilityai/stable-diffusion-xl-base-1.0",
        subfolder="vae",
        torch_dtype=torch.bfloat16,
    ).to(device)
  2. Potential Issues (a quick diagnostic sketch for the dtype and scaling-factor points follows this list):

    • VAE weights not properly loaded in bfloat16
    • Decoder normalization failing
    • Latent scaling factor incorrect
    • Memory corruption during decode
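The first and third points can be checked directly. A minimal diagnostic sketch, assuming only the diffusers AutoencoderKL handle the validator already holds (inspect_vae is a hypothetical helper name):

from diffusers import AutoencoderKL

def inspect_vae(vae: AutoencoderKL) -> None:
    # Which dtype did the weights actually load in?
    print("VAE weight dtype:", next(vae.parameters()).dtype)
    # Which scaling factor does this checkpoint's config expect?
    # (the SDXL VAE config is expected to report 0.13025, whereas
    #  prepare_image hard-codes the SD 1.x value 0.18215)
    print("config.scaling_factor:", vae.config.scaling_factor)

# e.g. inspect_vae(self.default_vae) inside ModelValidator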

Proposed Fix

class ModelValidator:
    def __init__(self, ...):
        # Load VAE in float32 first (no torch_dtype override)
        self.default_vae = AutoencoderKL.from_pretrained(
            "stabilityai/stable-diffusion-xl-base-1.0",
            subfolder="vae"
        ).to("cuda")

    # Explicit dtype handling for decode
    def prepare_image(self, img):
        with torch.no_grad():
            # Convert to float32 for decode
            if img.dtype == torch.bfloat16:
                img = img.float()

            # Proper scaling and decode
            if img.shape[1] == 4:
                img = img / 0.18215
                img = self.default_vae.decode(img).sample

            # Ensure valid image range
            img = torch.clamp(img, -1, 1)

        return img
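The clamped output is still in the VAE's [-1, 1] range; one way to turn it into a saveable image for the validation log, sketched outside the proposed patch (to_pil is a hypothetical helper):

import numpy as np
import torch
from PIL import Image

def to_pil(img: torch.Tensor) -> Image.Image:
    # Map from [-1, 1] (as returned by prepare_image above) to [0, 255] uint8.
    img = (img / 2 + 0.5).clamp(0, 1)
    arr = (img[0].permute(1, 2, 0).cpu().numpy() * 255).round().astype(np.uint8)
    return Image.fromarray(arr)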

Validation Steps

  1. Add tensor validation:

    def validate_tensor(tensor, stage=""):
        print(f"[{stage}] Shape: {tensor.shape}, "
              f"dtype: {tensor.dtype}, "
              f"range: ({tensor.min():.3f}, {tensor.max():.3f}), "
              f"mean: {tensor.mean():.3f}")
  2. Add checkpoints in the decode pipeline:

    # Before decode
    validate_tensor(img, "pre-decode")
    # After decode
    validate_tensor(decoded, "post-decode")
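As an optional extension beyond the original steps, a hard guard could abort the run instead of only printing statistics (assert_finite is a hypothetical helper):

import torch

def assert_finite(tensor: torch.Tensor, stage: str = "") -> None:
    # Fail fast if the decode produced NaN or Inf values.
    if not torch.isfinite(tensor).all():
        raise ValueError(f"[{stage}] tensor contains NaN or Inf values")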

Testing Plan

  1. Generate validation images with the float32 VAE
  2. Compare with the bfloat16 results
  3. Add tensor validation logging
  4. Test with a small batch of known-good latents (a round-trip sketch for steps 1, 2 and 4 follows below)
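A sketch of steps 1, 2 and 4, assuming a known-good reference image at an illustrative path and using the scaling factor reported by the VAE config:

import numpy as np
import torch
from PIL import Image
from diffusers import AutoencoderKL

device = "cuda"
repo = "stabilityai/stable-diffusion-xl-base-1.0"
vae_fp32 = AutoencoderKL.from_pretrained(repo, subfolder="vae").to(device)
vae_bf16 = AutoencoderKL.from_pretrained(
    repo, subfolder="vae", torch_dtype=torch.bfloat16
).to(device)

# "known_good.png" is a placeholder for a real training image.
img = Image.open("known_good.png").convert("RGB").resize((1024, 1024))
x = torch.from_numpy(np.array(img)).permute(2, 0, 1).float()[None] / 127.5 - 1.0
x = x.to(device)

scale = vae_fp32.config.scaling_factor
with torch.no_grad():
    latents = vae_fp32.encode(x).latent_dist.sample() * scale
    dec_fp32 = vae_fp32.decode(latents / scale).sample
    dec_bf16 = vae_bf16.decode((latents / scale).to(torch.bfloat16)).sample.float()

print("fp32 reconstruction MAE vs input:", (dec_fp32 - x).abs().mean().item())
print("bf16 vs fp32 decode MAE:", (dec_bf16 - dec_fp32).abs().mean().item())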

Impact

jetjodh commented 2 days ago

Have you considered using ollin's fixed VAE or the tiny VAE implementation for the validation loop?
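For reference, a minimal sketch of either option, assuming the madebyollin/sdxl-vae-fp16-fix and madebyollin/taesdxl checkpoints and the diffusers AutoencoderTiny class:

import torch
from diffusers import AutoencoderKL, AutoencoderTiny

device = "cuda"

# Option A: ollin's fixed SDXL VAE, intended to be numerically stable in fp16.
fixed_vae = AutoencoderKL.from_pretrained(
    "madebyollin/sdxl-vae-fp16-fix", torch_dtype=torch.float16
).to(device)

# Option B: TAESD (tiny VAE), much smaller and faster for preview-quality decodes.
tiny_vae = AutoencoderTiny.from_pretrained(
    "madebyollin/taesdxl", torch_dtype=torch.float16
).to(device)

# Both expose decode(latents).sample, so either could slot into prepare_image.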