huggingface / diffusers

🤗 Diffusers: State-of-the-art diffusion models for image and audio generation in PyTorch and FLAX.
https://huggingface.co/docs/diffusers
Apache License 2.0

VAE training sample script #3726

Open zhuliyi0 opened 1 year ago

zhuliyi0 commented 1 year ago

I believe the current lack of easy access to VAE training is stopping diffusion models from disrupting even more industries.

I'm talking about consistent details on things that are underrepresented in the original training data. 64x64 resolution can only carry so much detail. Very often I get a good result in latent space (judging by the low-res intermediate image) before the final image is ruined by bad details. No prompting, finetuning, or ControlNet could solve this issue. I tried, and I know lots of other people have tried, most of them without realising that the problem cannot be solved unless the component that produces the final details can be trained on their domain data.

Right now a VAE cannot be easily trained, at least not by someone like me who is not very good at math and Python, so there is definitely a demand here. May I hope for a sample script based on diffusers to start with? I tried messing with the ones in the CompVis repo, but to no avail. Thanks in advance!

zhuliyi0 commented 1 year ago

I left some comments in the PR thread suggesting some code changes.

I ran several full UNet/TE finetuning sessions with VAEs trained at various learning rates. So far I have needed to set the LR for VAE training very low (1e-7) for the VAE to be barely workable with the UNet/TE; otherwise there are color blotches in the output image, and they get worse as finetuning goes on. Also, the workflow of fully training a VAE and then finetuning the UNet/TE with it feels cumbersome and slow, with too many hyperparameters to try. Hopefully decoder-only training will work without so much hassle.
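For the decoder-only idea mentioned above, one possible approach is to freeze every parameter except the decoder's, so the encoder (and therefore the latents the UNet sees) stays untouched. This is only a minimal sketch: `ToyVAE` and `freeze_all_but` are hypothetical names for illustration, and the toy model merely mimics the `encoder` / `decoder` attribute layout of `AutoencoderKL`.

```python
import torch
from torch import nn


class ToyVAE(nn.Module):
    """Tiny stand-in for an autoencoder with `encoder` and `decoder` submodules."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, kernel_size=3, padding=1)
        self.decoder = nn.Conv2d(4, 3, kernel_size=3, padding=1)


def freeze_all_but(model, trainable_prefixes):
    """Enable gradients only for parameters whose name starts with a given prefix."""
    prefixes = tuple(trainable_prefixes)
    trainable = []
    for name, param in model.named_parameters():
        param.requires_grad_(name.startswith(prefixes))
        if param.requires_grad:
            trainable.append(param)
    return trainable


vae = ToyVAE()
decoder_params = freeze_all_but(vae, ["decoder."])
# Only the decoder parameters are handed to the optimizer.
# The learning rate is the value quoted in the comment above, not a recommendation.
optimizer = torch.optim.AdamW(decoder_params, lr=1e-7)
```

With a real `AutoencoderKL`, other decoder-side modules may need to be added to the prefix list; that is an assumption to check against the model's own parameter names.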

aandyw commented 1 year ago

@zhuliyi0 @epitaque Thank you guys for the suggestions. For now, I've committed the torch.no_grad() and I'll look over the rest of the changes more thoroughly once I get some time. Currently occupied with exams 😬

The decoder and encoder stuff should just be a simple change but I wanna make sure the things being committed this time are a little more tidy.

ThibaultCastells commented 1 year ago

Hey! I also didn't have time this week, due to a paper submission deadline. I'll get back to this in a week.

yeonsikch commented 1 year ago

Hi guys. Thanks for your train_vae.py script!

I want to train a VAE on my own dataset, but I ran into some bugs:

  1. FP16 (optimizer)
  2. Multi-GPU (train_dataloader stops working after some steps) (GPU: (1~8)*A100)

So I'll try fixing these issues in your script. If I manage to fix them, could I send a PR against your script?

Thanks.

aandyw commented 1 year ago

Yeah, if you could fix those issues that would be awesome! And of course, any PRs are welcome. I just haven't had the time to work on this issue lately due to school + exams. But things should be wrapping up soon...

yeonsikch commented 1 year ago

I just resolved this issue (FP16 (optimizer)). I'll send a PR with this fix and some other fixed code tomorrow.

aleksmirosh commented 1 year ago

Could you please show how you fixed the issue?

aleksmirosh commented 1 year ago

Why don't you use the UNet structure for training?

yeonsikch commented 1 year ago

I fixed some issues.

  1. fp16
  2. multi-gpu
  3. ema
  4. train script
  5. huggingface upload

But I still have to fix the loss function; the current one is not right.

If anyone needs a fix for the fp16 issue, try this:

    vae, vae.encoder, vae.decoder = accelerator.prepare(vae, vae.encoder, vae.decoder)

yeonsikch commented 1 year ago

Why don't you use the UNet structure for training?

We have to train step by step:

  1. VAE (you can load someone else's checkpoint weights)
  2. CLIP (you can load someone else's checkpoint weights)
  3. UNet (you can load someone else's checkpoint weights)

ThibaultCastells commented 1 year ago

@yeonsikch thank you so much, you saved me a lot of time!

But I still have to fix the loss function; the current one is not right.

Could you tell me more about this?

@aleksmirosh

Why don't you use the UNet structure for training?

There are multiple reasons for that, but in short: it reduces the training complexity and avoids memory issues (both the VAE and the UNet already require a lot of memory, and the average user doesn't have a $10k GPU).

yeonsikch commented 1 year ago

Hi! You are in Seoul! Nice to meet you. I'm in Seoul too.

I also looked at how Stability did their fine-tuning. They said they used a loss of (MSE + 0.1*LPIPS) after (L1 + LPIPS). And I think this LPIPS is LPIPSWithDiscriminator. (Please see the reference URL below.)

LPIPSWithDiscriminator contains a discriminator model that contributes to the loss. I'm not sure, but as far as I know plain LPIPS does not include a discriminator, right?

So I think we have to train a discriminator model and use it to compute the loss, as in the original paper's code: ldm/modules/losses/contperceptual.py

If I'm wrong, please give me some advice.

Thanks!
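To make the (MSE + 0.1*LPIPS) recipe concrete, here is a minimal sketch of how the two terms could be combined. It covers only the reconstruction part and leaves out the discriminator from LPIPSWithDiscriminator. `reconstruction_loss` and `dummy_perceptual` are hypothetical names; the dummy function is just a placeholder so the example runs without downloading LPIPS weights.

```python
import torch
import torch.nn.functional as F


def reconstruction_loss(pred, target, perceptual_fn, lpips_weight=0.1):
    """Pixel-wise MSE plus a weighted perceptual term."""
    mse = F.mse_loss(pred, target)
    # perceptual_fn is expected to return one distance value per image.
    perceptual = perceptual_fn(pred, target).mean()
    return mse + lpips_weight * perceptual


def dummy_perceptual(pred, target):
    """Placeholder for an LPIPS network: mean absolute difference per image."""
    return (pred - target).abs().flatten(1).mean(dim=1)


pred = torch.zeros(2, 3, 8, 8)
target = torch.ones(2, 3, 8, 8)
loss = reconstruction_loss(pred, target, dummy_perceptual)
```

In a real run, an instance such as `lpips.LPIPS(net="vgg")` from the `lpips` package could be passed as `perceptual_fn`; it expects images scaled to [-1, 1].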

ThibaultCastells commented 1 year ago

Okay, thanks for the explanation! Honestly, I don't think I know more than you do regarding the loss. I figured things out by reading the code and some discussions here and there on the internet, and by experimenting. For me the current method seems to work, but it's worth trying other solutions and comparing them to see what gives the best results, if you have time for that 🙂

By the way, I suggest using a dataset of human faces for the tests, if you can find one, because your brain is really good at detecting anomalies in human faces, so it will be easier to compare visually.

bghira commented 1 year ago

I would try min-SNR weighted loss!

yeonsikch commented 1 year ago

Thank you for your advice! 😊

One more suggestion: I think we have to use z = posterior.sample() instead of z = posterior.mode() in the train script.
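The difference between the two calls can be sketched in plain PyTorch. This is an illustration under assumptions, not the library code: it assumes the encoder output stacks mean and log-variance along the channel axis, and `split_moments`, `sample_latent`, and `mode_latent` are hypothetical helper names.

```python
import torch


def split_moments(moments):
    """Split encoder output into mean and log-variance along the channel axis."""
    mean, logvar = torch.chunk(moments, 2, dim=1)
    return mean, logvar


def sample_latent(mean, logvar, generator=None):
    """Reparameterized draw: z = mean + std * eps, with eps ~ N(0, I)."""
    std = torch.exp(0.5 * logvar)
    eps = torch.randn(mean.shape, generator=generator, dtype=mean.dtype)
    return mean + std * eps


def mode_latent(mean):
    """The mode of a Gaussian is its mean, so no noise is involved."""
    return mean


moments = torch.zeros(1, 8, 4, 4)  # 4 mean channels + 4 log-variance channels
mean, logvar = split_moments(moments)
z_sample = sample_latent(mean, logvar, generator=torch.Generator().manual_seed(0))
z_mode = mode_latent(mean)
```

Training on `z_mode` means the decoder only ever sees the noise-free mean, while `z_sample` exposes it to the spread of the posterior.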

OwalnutO commented 1 year ago

I think the loss with a discriminator is needed. But which script is the latest one that resolves the multi-GPU and fp16 bugs?

yeonsikch commented 1 year ago

I didn't write the script. But you can try the official latent diffusion repo. The VAE training code is in that repo.

OwalnutO commented 1 year ago

I mean I have tried vae, vae.encoder, vae.decoder = accelerator.prepare(vae, vae.encoder, vae.decoder), but I still can't get either fp16 or multi-GPU to work. Could you give me some suggestions?

yeonsikch commented 1 year ago

If you want multi-GPU, you have to use vae.module.encode(batch~~) instead of vae.encode(batch~~). And for fp16:

vae, vae.encode, vae.decode, optimizer … etc = accelerator.prepare(vae, vae.encode, vae.decode, optimizer, … etc)

I'm in a car right now, so I can't look at my latest script. Sorry. But if you show me your script and the error message, I'll help you.

Thanks.
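The reason `vae.module.encode` is needed on multi-GPU is that a distributed wrapper such as `DistributedDataParallel` keeps the original model in its `.module` attribute and does not forward custom methods like `encode`. A minimal sketch of a helper that works in both the wrapped and unwrapped case follows; `unwrap`, `ToyVAE`, and `ToyWrapper` are hypothetical names used only so the example runs without GPUs.

```python
def unwrap(model):
    """Return the underlying model if it is wrapped, otherwise the model itself."""
    return model.module if hasattr(model, "module") else model


class ToyVAE:
    """Stand-in for a model with a custom `encode` method."""

    def encode(self, batch):
        return [2 * x for x in batch]


class ToyWrapper:
    """Mimics a wrapper that only exposes the inner model as `.module`."""

    def __init__(self, module):
        self.module = module


wrapped = ToyWrapper(ToyVAE())
latents = unwrap(wrapped).encode([1, 2, 3])
```

In an Accelerate script, `accelerator.unwrap_model(vae)` serves the same purpose, so the same line works on one GPU and on several.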

aleksmirosh commented 1 year ago

Sorry for a question that is not related to the thread, it just looks like people here understand the pain of training a diffusion model on non-RGB images. Has anybody found the best normalization for such data? Mean-std looks preferable, but then the data goes beyond [0, 1]; is min-max the better option?

yeonsikch commented 1 year ago

I recommend this:

    from torchvision import transforms

    transform = transforms.Compose([
        transforms.ToTensor(),               # uint8 [0, 255] -> float [0, 1]
        transforms.Normalize([0.5], [0.5]),  # [0, 1] -> [-1, 1]
    ])

Because, as far as I know, almost all Stable Diffusion-related models (VAE, UNet, CLIP) were trained with this normalization.

If you use another normalization, you have to re-scale (re-normalize) the data when passing it between the different Stable Diffusion models.
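The arithmetic behind that recommendation is small enough to write out. This is a sketch with hypothetical helper names; the functions use plain arithmetic, so they work on floats as well as on tensors. The min-max variant is one possible option for non-RGB data with a known value range, not something confirmed in this thread.

```python
def normalize(x, mean=0.5, std=0.5):
    """Same arithmetic as Normalize([0.5], [0.5]): maps [0, 1] to [-1, 1]."""
    return (x - mean) / std


def denormalize(y, mean=0.5, std=0.5):
    """Inverse mapping, e.g. before saving a decoded image: [-1, 1] to [0, 1]."""
    return y * std + mean


def min_max_to_signed_unit(x, x_min, x_max):
    """Rescale data with a known range [x_min, x_max] to [-1, 1]."""
    return 2.0 * (x - x_min) / (x_max - x_min) - 1.0
```

Whatever scaling is chosen, the decoder output has to be mapped back with the matching inverse before it is viewed or compared with the input.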

bghira commented 1 year ago

see #4636 for what happens with that normalisation range :D

yeonsikch commented 1 year ago

I compared this VAE train script with the original latent-diffusion code.

As a result, I'm sure that we need to fix our loss term. In my experience, the original latent-diffusion code is better.

I recommend using the original latent-diffusion code. If you want to use it, you will have to write your own custom dataset class.

aleksmirosh commented 1 year ago

Do they use the same autoencoder model? Why do you think it is better? I tried both and did not get good results with either.

aandyw commented 1 year ago

Can you make a commit for these changes? I'll work on the VAE loss and hopefully try to get it matching with LDM.

FrsECM commented 1 year ago

Hi Everyone ! Thanks a lot for your thread ! It helped me a lot. I've implemented my own version of the script based on your great work.

I have a question regarding the decoder training. As I understand it, it is necessary to sample from the encoder's output distribution:

src : https://towardsdatascience.com/understanding-variational-autoencoders-vaes-f70510919f73

But in your implementation you feed the "mode" directly to the decoder:

    for epoch in range(first_epoch, args.num_train_epochs):
        vae.train()
        train_loss = 0.0
        for step, batch in enumerate(train_dataloader):
            with accelerator.accumulate(vae):
                target = batch["pixel_values"].to(weight_dtype)

                # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_kl.py
                posterior = vae.encode(target).latent_dist
                z = posterior.mode()
                pred = vae.decode(z).sample

To challenge this assumption, I ran 2 trainings of 100k iterations each on CelebA-HQ.

Across my attempts, I noticed that mode seems to render better images, but I don't know how it will behave when we train an LDM on top of it. Adding uncertainty in the latent space can definitely help the subsequent diffusion process.

image

Does anybody have any insight on this? Thanks!

ThibaultCastells commented 1 year ago

Hello @FrsECM ! Thank you for sharing your results and this interesting Medium post! Could you share the implementation you tried?

I don't know how it will behave when we train an LDM on top of it. Adding uncertainty in the latent space can definitely help the subsequent diffusion process.

It's hard to tell without trying, but I think we also need to keep in mind that the Stable Diffusion performance is bounded by the VAE performance: if the VAE can only generate blurry images then Stable Diffusion will produce blurry images, no matter how well the unet is trained.

trouble-maker007 commented 1 year ago

@FrsECM Would you like to share your VAE training implementation?

FrsECM commented 1 year ago

Hi @trouble-maker007 @ThibaultCastells, I tried training an LDM on top of both VAEs, the one trained with sampling and the one without. Below are the results for the same number of iterations:

image

image

To me this confirms that it's better to train with uncertainty. That said, the issues I have getting the VAE to converge still show up in the generated images.

You can find my training script there : https://github.com/FrsECM/diffusers/blob/add-semantic-diffusion/examples/community/semantic_image_synthesis/train_vae_ldm.py

bigcornflake commented 9 months ago

@FrsECM I think your method is correct. Although the output with posterior.mode() looks better, it actually discards the randomness.

sapkun commented 8 months ago

@yeonsikch

Hello, I have tried what you suggested for the fp16 issue, but I still get an error:

    (
        vae,
        vae.encode,
        vae.decode,
        optimizer,
        train_dataloader,
        test_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        vae, vae.encode, vae.decode, optimizer, train_dataloader, test_dataloader, lr_scheduler
    )

Traceback (most recent call last):
  File "train_vae.py", line 551, in <module>
    main()
  File "train_vae.py", line 509, in main
    optimizer.step()
  File "/abdB5045/sd_model/lib/python3.8/site-packages/accelerate/optimizer.py", line 132, in step
    self.scaler.step(self.optimizer, closure)
  File "/abdB5045/sd_model/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 446, in step
    self.unscale_(optimizer)
  File "/abdB5045/sd_model/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 336, in unscale_
    optimizer_state["found_inf_per_device"] = self._unscale_grads_(
  File "/abdB5045/sd_model/lib/python3.8/site-packages/torch/cuda/amp/grad_scaler.py", line 258, in _unscale_grads_
    raise ValueError("Attempting to unscale FP16 gradients.")
ValueError: Attempting to unscale FP16 gradients.

Can you share your full script? Thanks!

- `Accelerate` version: 0.27.0
- Platform: Linux-3.10.0-1160.el7.x86_64-x86_64-with-glibc2.10
- Python version: 3.8.12
- Numpy version: 1.24.4
- PyTorch version (GPU?): 2.2.0+cu121 (True)
- PyTorch XPU available: False
- PyTorch NPU available: False
- System RAM: 1007.35 GB
- GPU type: NVIDIA A100-PCIE-40GB
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: NO
        - mixed_precision: fp16
        - use_cpu: False
        - debug: False
        - num_processes: 1
        - machine_rank: 0
        - num_machines: 1
        - gpu_ids: 0
        - rdzv_backend: static
        - same_network: True
        - main_training_function: main
        - downcast_bf16: no
        - tpu_use_cluster: False
        - tpu_use_sudo: False
        - tpu_env: []
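
The `ValueError` in the traceback above is raised by PyTorch's `GradScaler` when the gradients it is asked to unscale are themselves float16, which happens when the parameters given to the optimizer were cast to fp16 (for example with `.half()` or `.to(torch.float16)`). A common remedy, offered here as an assumption about this script rather than a confirmed fix, is to keep the trainable weights in float32 and let `mixed_precision="fp16"` handle the casting during the forward pass. The hypothetical helper below is a quick way to check for offending parameters before training starts.

```python
import torch
from torch import nn


def fp16_trainable_params(model):
    """Names of trainable parameters that are stored in float16."""
    return [
        name
        for name, param in model.named_parameters()
        if param.requires_grad and param.dtype == torch.float16
    ]


model_fp32 = nn.Linear(4, 4)         # weights left in float32
model_fp16 = nn.Linear(4, 4).half()  # weights cast to float16

offending = fp16_trainable_params(model_fp16)
```

If the returned list is non-empty for the model passed to `accelerator.prepare`, casting that model back with `.float()` (or not casting it to `weight_dtype` in the first place) is the first thing to try.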

yeonsikch commented 8 months ago

@sapkun I've been trying to train with the diffusers VAE code here, but the problem is that it doesn't have LPIPS as a loss function. Of course, the loss term is something you can customize, but if you want to replicate the original, use the latent-diffusion code. I was able to finish training a VAE using that code.

However, that VAE code isn't perfect either. I've modified it to train on my own data and will share it once it's up on git.

sapkun commented 8 months ago

Thanks for your reply. My question is about what you mentioned earlier: you used the accelerator.prepare method to prepare the various components, (vae, vae.encode, vae.decode, optimizer, train_dataloader, test_dataloader, lr_scheduler) = accelerator.prepare(vae, vae.encode, vae.decode, optimizer, train_dataloader, test_dataloader, lr_scheduler), but I was unable to resolve the issue with mixed-precision training (fp16) and ran into difficulties running the code across multiple GPUs. That is my question. FYI, I used the MSE + 0.1 * LPIPS term and obtained good results; I removed the KL term from the loss.

yeonsikch commented 8 months ago

@sapkun It's been a while since I've used that code, so it took me a while to find it. I'm posting the full code here.

"""
TODO: fix training mixed precision -- issue with AdamW optimizer
"""

import argparse
import logging
import math
import os
from pathlib import Path

import accelerate
import numpy as np
import torch
import torch.nn.functional as F
import torch.utils.checkpoint
import torchvision
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration, set_seed
from datasets import load_dataset
from huggingface_hub import create_repo, upload_folder
from packaging import version
from torchvision import transforms
from tqdm.auto import tqdm

import diffusers
from diffusers import AutoencoderKL
from diffusers.optimization import get_scheduler
from diffusers.training_utils import EMAModel
from diffusers.utils import is_wandb_available

import lpips
from PIL import Image

if is_wandb_available():
    import wandb

logger = get_logger(__name__, log_level="INFO")

@torch.no_grad()
def log_validation(args, repo_id, test_dataloader, vae, accelerator, weight_dtype, epoch):
    logger.info("Running validation... ")

    vae_model = accelerator.unwrap_model(vae)
    images = []

    for _, sample in enumerate(test_dataloader):
        x = sample["pixel_values"].to(weight_dtype)
        reconstructions = vae_model(x).sample
        images.append(
            torch.cat([sample["pixel_values"].cpu(), reconstructions.cpu()], axis=0)
        )

    for tracker in accelerator.trackers:
        if tracker.name == "tensorboard":
            np_images = np.stack([np.asarray(img) for img in images])
            tracker.writer.add_images(
                "Original (left), Reconstruction (right)", np_images, epoch
            )
        elif tracker.name == "wandb":
            tracker.log(
                {
                    "Original (left), Reconstruction (right)": [
                        wandb.Image(torchvision.utils.make_grid(image))
                        for _, image in enumerate(images)
                    ]
                }
            )
        else:
            logger.warning(f"image logging not implemented for {tracker.name}")

    if args.push_to_hub:
        try:
            save_model_card(args, repo_id, images, repo_folder=args.output_dir)
            upload_folder(
                repo_id=repo_id,
                folder_path=args.output_dir,
                commit_message="End of training",
                ignore_patterns=["step_*", "epoch_*"],
            )
        except Exception:
            logger.warning(
                "Upload to the Hugging Face Hub failed (your storage there may be limited). "
                f"The weights are saved only to the local path: {args.output_dir}"
            )

    del vae_model
    torch.cuda.empty_cache()

def make_image_grid(imgs, rows, cols):

    w, h = imgs[0].size
    grid = Image.new("RGB", size=(cols * w, rows * h))

    for i, img in enumerate(imgs):
        grid.paste(img, box=(i % cols * w, i // cols * h))
    return grid

def save_model_card(
    args,
    repo_id: str,
    images=None,
    repo_folder=None,
):
    # img_str = ""
    # if len(images) > 0:
    #     image_grid = make_image_grid(images, 1, "example")
    #     image_grid.save(os.path.join(repo_folder, "val_imgs_grid.png"))
    #     img_str += "![val_imgs_grid](./val_imgs_grid.png)\n"

    yaml = f"""
---
license: creativeml-openrail-m
base_model: {args.pretrained_model_name_or_path}
datasets:
- {args.dataset_name}
tags:
- stable-diffusion
- stable-diffusion-diffusers
- text-to-image
- diffusers
inference: true
---
    """
    model_card = f"""
# VAE finetuning - {repo_id}

This VAE was finetuned from **{args.pretrained_model_name_or_path}** on the **{args.dataset_name}** dataset.

## Training info

These are the key hyperparameters used during training:

* Epochs: {args.num_train_epochs}
* Learning rate: {args.learning_rate}
* Batch size: {args.train_batch_size}
* Gradient accumulation steps: {args.gradient_accumulation_steps}
* Image resolution: {args.resolution}
* Mixed-precision: {args.mixed_precision}

"""
    wandb_info = ""
    # Defined up front so the check below cannot raise NameError when wandb is missing.
    wandb_run_url = None
    if is_wandb_available() and wandb.run is not None:
        wandb_run_url = wandb.run.url

    if wandb_run_url is not None:
        wandb_info = f"""
More information on all the CLI arguments and the environment are available on your [`wandb` run page]({wandb_run_url}).
"""

    model_card += wandb_info

    with open(os.path.join(repo_folder, "README.md"), "w") as f:
        f.write(yaml + model_card)

def parse_args():
    parser = argparse.ArgumentParser(
        description="Simple example of a VAE training script."
    )
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=False,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--revision",
        type=str,
        default=None,
        required=False,
        help="Revision of pretrained model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--dataset_name",
        type=str,
        default=None,
        help=(
            "The name of the Dataset (from the HuggingFace hub) to train on (could be your own, possibly private,"
            " dataset). It can also be a path pointing to a local copy of a dataset in your filesystem,"
            " or to a folder containing files that 🤗 Datasets can understand."
        ),
    )
    parser.add_argument(
        "--dataset_config_name",
        type=str,
        default=None,
        help="The config of the Dataset, leave as None if there's only one config.",
    )
    parser.add_argument(
        "--train_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the training data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--test_data_dir",
        type=str,
        default=None,
        help=(
            "A folder containing the validation data. Folder contents must follow the structure described in"
            " https://huggingface.co/docs/datasets/image_dataset#imagefolder. In particular, a `metadata.jsonl` file"
            " must exist to provide the captions for the images. Ignored if `dataset_name` is specified."
        ),
    )
    parser.add_argument(
        "--image_column",
        type=str,
        default="image",
        help="The column of the dataset containing an image.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="outputs",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--huggingface_repo",
        type=str,
        default="vae-model-finetuned",
        help="Name of the Hugging Face Hub repository to upload the trained model to.",
    )
    parser.add_argument(
        "--cache_dir",
        type=str,
        default=None,
        help="The directory where the downloaded models and datasets will be stored.",
    )
    parser.add_argument(
        "--seed", type=int, default=21, help="A seed for reproducible training."
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help=(
            "The resolution for input images, all the images in the train/validation dataset will be resized to this"
            " resolution"
        ),
    )
    parser.add_argument(
        "--train_batch_size",
        type=int,
        default=1,
        help="Batch size (per device) for the training dataloader.",
    )
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--gradient_accumulation_steps",
        type=int,
        default=2,
        help="Number of updates steps to accumulate before performing a backward/update pass.",
    )
    parser.add_argument(
        "--gradient_checkpointing",
        action="store_true",
        help="Whether or not to use gradient checkpointing to save memory at the expense of slower backward pass.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1.5e-7, # Reference : Waifu-diffusion-v1-4 config
        # default=4.5e-8,
        help="Initial learning rate (after the potential warmup period) to use.",
    )
    parser.add_argument(
        "--scale_lr",
        action="store_true",
        default=True,
        help="Scale the learning rate by the number of GPUs, gradient accumulation steps, and batch size.",
    )
    parser.add_argument(
        "--lr_scheduler",
        type=str,
        default="constant",
        help=(
            'The scheduler type to use. Choose between ["linear", "cosine", "cosine_with_restarts", "polynomial",'
            ' "constant", "constant_with_warmup"]'
        ),
    )
    parser.add_argument(
        "--lr_warmup_steps",
        type=int,
        default=500,
        help="Number of steps for the warmup in the lr scheduler.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument(
        "--checkpointing_steps",
        type=int,
        default=5000,
        help=(
            "Save a checkpoint of the training state every X updates. These checkpoints are only suitable for resuming"
            " training using `--resume_from_checkpoint`."
        ),
    )
    parser.add_argument(
        "--checkpoints_total_limit",
        type=int,
        default=None,
        help=(
            "Max number of checkpoints to store. Passed as `total_limit` to the `Accelerator` `ProjectConfiguration`."
            " See Accelerator::save_state https://huggingface.co/docs/accelerate/package_reference/accelerator#accelerate.Accelerator.save_state"
            " for more docs"
        ),
    )
    parser.add_argument(
        "--resume_from_checkpoint",
        type=str,
        default=None,
        help=(
            "Whether training should be resumed from a previous checkpoint. Use a path saved by"
            ' `--checkpointing_steps`, or `"latest"` to automatically select the last available checkpoint.'
        ),
    )
    parser.add_argument(
        "--test_samples",
        type=int,
        default=20,
        help="Number of images to remove from training set to be used as validation.",
    )
    parser.add_argument(
        "--validation_epochs",
        type=int,
        default=5,
        help="Run validation every X epochs.",
    )
    parser.add_argument(
        "--tracker_project_name",
        type=str,
        default="vae-fine-tune",
        help=(
            "The `project_name` argument passed to Accelerator.init_trackers for"
            " more information see https://huggingface.co/docs/accelerate/v0.17.0/en/package_reference/accelerator#accelerate.Accelerator"
        ),
    )
    parser.add_argument(
        "--use_8bit_adam", action="store_true", help="Whether or not to use 8-bit Adam from bitsandbytes."
    )
    parser.add_argument("--use_ema", action="store_true", help="Whether to use EMA model.")
    parser.add_argument(
        "--kl_scale",
        type=float,
        default=1e-6,
        help="Scaling factor for the Kullback-Leibler divergence penalty term.",
    )
    parser.add_argument(
        "--lpips_scale",
        type=float,
        default=5e-1,
        help="Scaling factor for the LPIPS metric",
    )
    parser.add_argument(
        "--lpips_start",
        type=int,
        default=50001,
        help="Start for the LPIPS metric",
    )
    parser.add_argument("--adam_beta1", type=float, default=0.9, help="The beta1 parameter for the Adam optimizer.")
    parser.add_argument("--adam_beta2", type=float, default=0.999, help="The beta2 parameter for the Adam optimizer.")
    parser.add_argument("--adam_weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--adam_epsilon", type=float, default=1e-08, help="Epsilon value for the Adam optimizer")
    parser.add_argument("--max_grad_norm", default=1.0, type=float, help="Max gradient norm.")
    parser.add_argument("--push_to_hub", action="store_true", help="Whether or not to push the model to the Hub.")
    parser.add_argument("--hub_token", type=str, default=None, help="The token to use to push to the Model Hub.")

    args = parser.parse_args()

    # args.mixed_precision='fp16'
    # args.pretrained_model_name_or_path="runwayml/stable-diffusion-v1-5"
    # args.dataset_name="yeonsikc/sample_repeat"
    # args.seed=21
    # args.train_batch_size=1
    # args.num_train_epochs=100
    # args.learning_rate=1e-07
    # args.output_dir="/app/output_vae"
    # args.report_to='wandb'
    # args.push_to_hub=True
    # args.validation_epochs=1
    # args.resolution=128
    # args.use_8bit_adam=False

    # Sanity checks
    if args.dataset_name is None and args.train_data_dir is None:
        raise ValueError("Need either a dataset name or a training folder.")

    return args

# train_transforms = transforms.Compose(
#     [
#         transforms.Resize(
#             (128,128), interpolation=transforms.InterpolationMode.BILINEAR
#         ),
#         # transforms.RandomCrop(128),
#         transforms.ToTensor(),
#         transforms.Normalize([0.5], [0.5]),
#     ]
# )

# def preprocess(examples):
#     images = [image.convert("RGB") for image in examples["image"]]
#     examples["pixel_values"] = [train_transforms(image) for image in images]
#     return examples

# def collate_fn(examples):
#     pixel_values = torch.stack([example["pixel_values"] for example in examples])
#     pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
#     return {"pixel_values": pixel_values}

def main():
    args = parse_args()

    logging_dir = os.path.join(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(
        total_limit=args.checkpoints_total_limit,
        project_dir=args.output_dir,
        logging_dir=logging_dir,
    )

    accelerator = Accelerator(
        gradient_accumulation_steps=args.gradient_accumulation_steps,
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    # Make one log on every process with the configuration for debugging.
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO,
    )
    logger.info(accelerator.state, main_process_only=False)

    if args.seed is not None:
        set_seed(args.seed)

    # Defined up front so the later `log_validation` call does not hit a NameError without `--push_to_hub`.
    repo_id = None
    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

        if args.push_to_hub:
            repo_id = create_repo(
                repo_id=Path(args.huggingface_repo).name, exist_ok=True, token=args.hub_token
            ).repo_id

    # Load vae: try the `vae` subfolder of a full pipeline repo first, then fall back to a standalone VAE repo.
    # `from_pretrained` takes the dtype as `torch_dtype`.
    try:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, torch_dtype=torch.float32
        )
    except Exception:
        vae = AutoencoderKL.from_pretrained(
            args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=torch.float32
        )
    if args.use_ema:
        try:
            ema_vae = AutoencoderKL.from_pretrained(
                args.pretrained_model_name_or_path, subfolder="vae", revision=args.revision, torch_dtype=torch.float32
            )
        except Exception:
            ema_vae = AutoencoderKL.from_pretrained(
                args.pretrained_model_name_or_path, revision=args.revision, torch_dtype=torch.float32
            )
        ema_vae = EMAModel(ema_vae.parameters(), model_cls=AutoencoderKL, model_config=ema_vae.config)

    vae.requires_grad_(True)
    vae_params = vae.parameters()

    # `accelerate` 0.16.0 will have better support for customized saving
    if version.parse(accelerate.__version__) >= version.parse("0.16.0"):
        # create custom saving & loading hooks so that `accelerator.save_state(...)` serializes in a nice format
        def save_model_hook(models, weights, output_dir):
            if args.use_ema:
                ema_vae.save_pretrained(os.path.join(output_dir, "vae_ema"))

            # `models` holds the models passed to `accelerator.prepare`; the full VAE comes first.
            vae_model = accelerator.unwrap_model(models[0])
            vae_model.save_pretrained(os.path.join(output_dir, "vae"))

        def load_model_hook(models, input_dir):
            if args.use_ema:
                load_model = EMAModel.from_pretrained(os.path.join(input_dir, "vae_ema"), AutoencoderKL)
                ema_vae.load_state_dict(load_model.state_dict())
                ema_vae.to(accelerator.device)
                del load_model

            # load diffusers style into the full VAE (first entry of `models`, which is a list)
            vae_model = accelerator.unwrap_model(models[0])
            load_model = AutoencoderKL.from_pretrained(input_dir, subfolder="vae")
            vae_model.register_to_config(**load_model.config)

            vae_model.load_state_dict(load_model.state_dict())
            del load_model

        accelerator.register_save_state_pre_hook(save_model_hook)
        accelerator.register_load_state_pre_hook(load_model_hook)

    if args.gradient_checkpointing:
        vae.enable_gradient_checkpointing()

    if args.scale_lr:
        args.learning_rate = (
            args.learning_rate
            * args.gradient_accumulation_steps
            * args.train_batch_size
            * accelerator.num_processes
        )

    # Initialize the optimizer
    if args.use_8bit_adam:
        try:
            import bitsandbytes as bnb
        except ImportError:
            raise ImportError(
                "Please install bitsandbytes to use 8-bit Adam. You can do so by running `pip install bitsandbytes` or `pip install bitsandbytes-windows` for Windows"
            )

        optimizer_cls = bnb.optim.AdamW8bit
    else:
        optimizer_cls = torch.optim.AdamW

    optimizer = optimizer_cls(
        vae.parameters(),
        lr=args.learning_rate,
        betas=(args.adam_beta1, args.adam_beta2),
        weight_decay=args.adam_weight_decay,
        eps=args.adam_epsilon,
    )

    # Get the datasets: you can either provide your own training and evaluation files (see below)
    # or specify a Dataset from the hub (the dataset will be downloaded automatically from the datasets Hub).

    # In distributed training, the load_dataset function guarantees that only one local process can concurrently
    # download the dataset.
    if args.dataset_name is not None:
        # Downloading and loading a dataset from the hub.
        dataset = load_dataset(
            args.dataset_name,
            args.dataset_config_name,
            cache_dir=args.cache_dir,
        )
    else:
        data_files = {}
        if args.train_data_dir is not None:
            data_files["train"] = os.path.join(args.train_data_dir, "**")
        dataset = load_dataset(
            "imagefolder",
            data_files=data_files,
            cache_dir=args.cache_dir,
        )

    column_names = dataset["train"].column_names
    if args.image_column is None:
        image_column = column_names[0]
    else:
        image_column = args.image_column
        if image_column not in column_names:
            raise ValueError(
                f"--image_column' value '{args.image_column}' needs to be one of: {', '.join(column_names)}"
            )

    train_transforms = transforms.Compose(
        [
            transforms.Resize(
                args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
            ),
            transforms.RandomCrop(args.resolution),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ]
    )

    # test_transforms = transforms.Compose(
    #     [
    #         transforms.Resize(
    #             args.resolution, interpolation=transforms.InterpolationMode.BILINEAR
    #         ),
    #         transforms.CenterCrop(args.resolution),
    #         transforms.ToTensor(),
    #         transforms.Normalize([0.5], [0.5]),
    #     ]
    # )

    def preprocess(examples):
        images = [image.convert("RGB") for image in examples[image_column]]
        examples["pixel_values"] = [train_transforms(image) for image in images]
        return examples

    # def test_preprocess(examples):
    #     images = [image.convert("RGB") for image in examples[image_column]]
    #     examples["pixel_values"] = [test_transforms(image) for image in images]
    #     return examples

    with accelerator.main_process_first():
        # Load test data from test_data_dir
        if(args.test_data_dir is not None and args.train_data_dir is not None):
            logger.info(f"load test data from {args.test_data_dir}")
            test_dir = os.path.join(args.test_data_dir, "**")        
            test_dataset = load_dataset(
                "imagefolder",
                data_files=test_dir,
                cache_dir=args.cache_dir,
            )
            # Set the training transforms
            train_dataset = dataset["train"].with_transform(preprocess)
            test_dataset = test_dataset["train"].with_transform(preprocess)
        # Load train/test data from train_data_dir
        elif "test" in dataset.keys():
            train_dataset = dataset["train"].with_transform(preprocess)
            test_dataset = dataset["test"].with_transform(preprocess)
        # Split into train/test
        else:
            dataset = dataset["train"].train_test_split(test_size=args.test_samples)        
            # Set the training transforms
            train_dataset = dataset["train"].with_transform(preprocess)
            test_dataset = dataset["test"].with_transform(preprocess)

    def collate_fn(examples):
        pixel_values = torch.stack([example["pixel_values"] for example in examples])
        pixel_values = pixel_values.to(memory_format=torch.contiguous_format).float()
        return {"pixel_values": pixel_values}

    # DataLoaders creation:
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.train_batch_size*accelerator.num_processes,
    )

    test_dataloader = torch.utils.data.DataLoader(
        test_dataset, shuffle=False, collate_fn=collate_fn, batch_size=args.train_batch_size, num_workers=1,#args.train_batch_size*accelerator.num_processes,
    )

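    # The scheduler needs the total number of training steps, not the number of epochs.
    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    lr_scheduler = get_scheduler(
        args.lr_scheduler,
        optimizer=optimizer,
        num_warmup_steps=args.lr_warmup_steps * args.gradient_accumulation_steps,
        num_training_steps=args.num_train_epochs * num_update_steps_per_epoch * args.gradient_accumulation_steps,
    )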

    # Prepare everything with our `accelerator`.

    (vae,
        vae.encoder,
        vae.decoder,
        optimizer,
        train_dataloader,
        test_dataloader,
        lr_scheduler,
    ) = accelerator.prepare(
        vae,vae.encoder, vae.decoder, optimizer, train_dataloader, test_dataloader, lr_scheduler
    )

    if args.use_ema:
        ema_vae.to(accelerator.device)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16

    # We need to initialize the trackers we use, and also store our configuration.
    # The trackers initializes automatically on the main process.
    if accelerator.is_main_process:
        tracker_config = dict(vars(args))
        accelerator.init_trackers(args.tracker_project_name, tracker_config)

    num_update_steps_per_epoch = math.ceil(
        len(train_dataloader) / args.gradient_accumulation_steps
    )
    args.max_train_steps = args.num_train_epochs * num_update_steps_per_epoch

    # ------------------------------ TRAIN ------------------------------ #
    total_batch_size = (
        args.train_batch_size
        * accelerator.num_processes
        * args.gradient_accumulation_steps
    )

    logger.info("***** Running training *****")
    logger.info(f"  Num examples = {len(train_dataset)}")
    logger.info(f"  Num test samples = {len(test_dataset)}")
    logger.info(f"  Num Epochs = {args.num_train_epochs}")
    logger.info(f"  Instantaneous batch size per device = {args.train_batch_size}")
    logger.info(
        f"  Total train batch size (w. parallel, distributed & accumulation) = {total_batch_size}"
    )
    logger.info(f"  Gradient Accumulation steps = {args.gradient_accumulation_steps}")
    global_step = 0
    first_epoch = 0

    # Potentially load in the weights and states from a previous save
    if args.resume_from_checkpoint:
        if args.resume_from_checkpoint != "latest":
            # An explicit checkpoint directory was given: load it from that path as-is.
            path = args.resume_from_checkpoint
        else:
            # Get the most recent `checkpoint-<step>` directory in the output directory
            dirs = os.listdir(args.output_dir)
            dirs = [d for d in dirs if d.startswith("checkpoint")]
            dirs = sorted(dirs, key=lambda x: int(x.split("-")[1]))
            path = os.path.join(args.output_dir, dirs[-1]) if len(dirs) > 0 else None

        if path is None:
            print(
                f"Checkpoint '{args.resume_from_checkpoint}' does not exist. Starting a new training run."
            )
            args.resume_from_checkpoint = None
        else:
            print(f"Resuming from checkpoint {path}")
            accelerator.load_state(path)
            # The directory name is `checkpoint-<global_step>`
            global_step = int(os.path.basename(os.path.normpath(path)).split("-")[1])

            resume_global_step = global_step * args.gradient_accumulation_steps
            first_epoch = global_step // num_update_steps_per_epoch
            resume_step = resume_global_step % (num_update_steps_per_epoch * args.gradient_accumulation_steps)

    progress_bar = tqdm(
        range(global_step, args.max_train_steps),
        disable=not accelerator.is_local_main_process,
    )
    progress_bar.set_description("Steps")

    lpips_loss_fn = lpips.LPIPS(net="alex").to(accelerator.device, dtype=weight_dtype)
    lpips_loss_fn.requires_grad_(False)

    for epoch in range(first_epoch, args.num_train_epochs):
        vae.train()
        train_loss = 0.0
        logger.info(f"{epoch = }")

        for step, batch in enumerate(train_dataloader):
            # Skip steps until we reach the resumed step
            if args.resume_from_checkpoint and epoch == first_epoch and step < resume_step:
                if step % args.gradient_accumulation_steps == 0:
                    progress_bar.update(1)
                continue
            with accelerator.accumulate(vae):
                target = batch["pixel_values"].to(weight_dtype)

                # https://github.com/huggingface/diffusers/blob/main/src/diffusers/models/autoencoder_kl.py
                if accelerator.num_processes > 1:
                    posterior = vae.module.encode(target).latent_dist
                else:
                    posterior = vae.encode(target).latent_dist

                # z = mean                  if posterior.mode()
                # z = mean + std * epsilon  if posterior.sample()
                z = posterior.sample() # Not mode()
                if accelerator.num_processes > 1:
                    pred = vae.module.decode(z).sample
                else:
                    pred = vae.decode(z).sample

                kl_loss = posterior.kl().mean()

                # if global_step > args.mse_start:
                #     pixel_loss = F.mse_loss(pred.float(), target.float(), reduction="mean")
                # else:
                #     pixel_loss = F.mse_loss(pred.float(), target.float(), reduction="mean")

                mse_loss = F.mse_loss(pred.float(), target.float(), reduction="mean")

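                # The LPIPS weights are frozen above, but gradients must still flow through `pred`.
                # Under `torch.no_grad()` this term would add nothing to the update, so it is computed with grad.
                lpips_loss = lpips_loss_fn(pred.to(dtype=weight_dtype), target).mean()
                if not torch.isfinite(lpips_loss):
                    lpips_loss = torch.zeros((), device=pred.device)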

                loss = (
                    mse_loss + args.lpips_scale * lpips_loss + args.kl_scale * kl_loss
                )

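                if not torch.isfinite(loss):
                    pred_mean = pred.mean()
                    target_mean = target.mean()
                    # Training is not stopped here; the means are logged to help debugging.
                    logger.info(
                        f"\nWARNING: non-finite loss (pred mean = {pred_mean.item()}, target mean = {target_mean.item()})"
                    )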

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean()
                train_loss += avg_loss.item() / args.gradient_accumulation_steps

                accelerator.backward(loss)
                if accelerator.sync_gradients:
                    accelerator.clip_grad_norm_(vae.parameters(), args.max_grad_norm)
                optimizer.step()
                lr_scheduler.step()
                optimizer.zero_grad()

            # Checks if the accelerator has performed an optimization step behind the scenes
            if accelerator.sync_gradients:
                if args.use_ema:
                    ema_vae.step(vae.parameters())
                progress_bar.update(1)
                global_step += 1
                accelerator.log({"train_loss": train_loss}, step=global_step)
                train_loss = 0.0

                if global_step % args.checkpointing_steps == 0:
                    if accelerator.is_main_process:
                        save_path = os.path.join(
                            args.output_dir, f"checkpoint-{global_step}"
                        )
                        accelerator.save_state(save_path)
                        logger.info(f"Saved state to {save_path}")

            logs = {
                "step_loss": loss.detach().item(),
                "lr": lr_scheduler.get_last_lr()[0],
                "mse": mse_loss.detach().item(),
                "lpips": lpips_loss.detach().item(),
                "kl": kl_loss.detach().item(),
            }
            accelerator.log(logs)
            progress_bar.set_postfix(**logs)

        if accelerator.is_main_process:
            if epoch % args.validation_epochs == 0:
                with torch.no_grad():
                    log_validation(args, repo_id, test_dataloader, vae, accelerator, weight_dtype, epoch)

    # Create the pipeline using the trained modules and save it.
    accelerator.wait_for_everyone()
    if accelerator.is_main_process:
        vae = accelerator.unwrap_model(vae)
        if args.use_ema:
            ema_vae.copy_to(vae.parameters())
        vae.save_pretrained(args.output_dir)

    accelerator.end_training()

if __name__ == "__main__":
    # torch.autograd.set_detect_anomaly(True)
    main()
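
A note on the resume logic in the script above: the `latest` lookup and the step arithmetic can be checked in isolation. The sketch below is a minimal stand-alone version, not the script's own code. `find_latest_checkpoint` and `resume_position` are hypothetical helper names, and it assumes checkpoints are saved as `checkpoint-<global_step>` directories, as the training loop does.

```python
import os
import tempfile


def find_latest_checkpoint(output_dir):
    """Return the path of the `checkpoint-<step>` directory with the highest step, or None."""
    dirs = [d for d in os.listdir(output_dir) if d.startswith("checkpoint-")]
    if not dirs:
        return None
    # Compare step numbers as integers; a plain string sort would put 5000 after 15000.
    latest = max(dirs, key=lambda d: int(d.split("-")[1]))
    return os.path.join(output_dir, latest)


def resume_position(global_step, num_update_steps_per_epoch, gradient_accumulation_steps):
    """Mirror the script's arithmetic: which epoch to restart in and how many batches to skip."""
    first_epoch = global_step // num_update_steps_per_epoch
    resume_global_step = global_step * gradient_accumulation_steps
    resume_step = resume_global_step % (num_update_steps_per_epoch * gradient_accumulation_steps)
    return first_epoch, resume_step


with tempfile.TemporaryDirectory() as tmp:
    for step in (5000, 10000, 15000):
        os.makedirs(os.path.join(tmp, f"checkpoint-{step}"))
    latest_name = os.path.basename(find_latest_checkpoint(tmp))

# 15000 optimizer steps done, 4000 optimizer steps per epoch, 2 accumulation steps
first_epoch, resume_step = resume_position(15000, 4000, 2)
```

Returning the full path from the lookup means the same value can be passed straight to `accelerator.load_state`, whether the checkpoint was found automatically or given explicitly.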
yeonsikch commented 8 months ago

@sapkun

Also, if you train a VAE to use with Stable Diffusion, you should definitely train only the decoder.
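
One possible way to do decoder-only training is to freeze everything and re-enable gradients on the decoder side before building the optimizer. The sketch below uses a toy module so it runs without downloading weights: `ToyVAE` and `freeze_all_but_decoder` are hypothetical names, and the attribute names (`encoder`, `quant_conv`, `post_quant_conv`, `decoder`) are assumed to match the `AutoencoderKL` version you use. Whether `post_quant_conv` should also be trained is a judgment call; it is unfrozen here because it sits on the decoder side.

```python
import torch
from torch import nn


class ToyVAE(nn.Module):
    """Hypothetical stand-in for AutoencoderKL; only the attribute names matter here."""

    def __init__(self):
        super().__init__()
        self.encoder = nn.Conv2d(3, 8, 3, padding=1)
        self.quant_conv = nn.Conv2d(8, 8, 1)
        self.post_quant_conv = nn.Conv2d(4, 4, 1)
        self.decoder = nn.Conv2d(4, 3, 3, padding=1)


def freeze_all_but_decoder(vae):
    """Freeze the whole model, then unfreeze the decoder-side modules."""
    vae.requires_grad_(False)
    vae.decoder.requires_grad_(True)
    vae.post_quant_conv.requires_grad_(True)
    # Only parameters that still require grad are handed to the optimizer.
    return [p for p in vae.parameters() if p.requires_grad]


vae = ToyVAE()
trainable = freeze_all_but_decoder(vae)
optimizer = torch.optim.AdamW(trainable, lr=1.5e-7)
```

With the encoder frozen, the latent distribution stays the one the UNet was trained on, which is the reason for training only the decoder.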

jiangyuhangcn commented 8 months ago

Reference

Hi! Can you fine-tune the VAE with fp16? Thanks!

linnanwang commented 7 months ago

@yeonsikch Thanks for the great work. However, running your example code above fails with an error about creating a tensor with a negative dimension (screenshot attached).

Any potential solutions? Thanks.

humanely commented 6 months ago

I am a bit confused by the HF AutoencoderKL after training the CompVis autoencoder. There is no tokenizer in HF, whereas CompVis uses one. I am training on a new language, so I need to use a custom tokenizer. Why do the two autoencoders differ?

yeonsikch commented 6 months ago

@linnanwang I'm sorry, I don't know either. I suspect it is a version issue (the torch version).

yeonsikch commented 6 months ago

@humanely AutoencoderKL encodes images into a latent space. Therefore, you don't need a tokenizer when you train AutoencoderKL (= the VAE).

GiilDe commented 3 months ago

Hi @trouble-maker007 @ThibaultCastells, I tried training an LDM on each of the two VAEs, one trained with sampling and one trained without. Below are the results for the same number of iterations:

image

image

For me, this confirms that it's better to train with uncertainty. However, the issues I faced in getting the VAE to converge still show up in the generated images.

You can find my training script here: https://github.com/FrsECM/diffusers/blob/add-semantic-diffusion/examples/community/semantic_image_synthesis/train_vae_ldm.py

Can you elaborate on this? Is the left image the result of sampling from the diffusion model and feeding the latents to the decoder trained with noise, and the right image the result of feeding them to the decoder trained without noise? And is scale = 7.5 the CFG scale?

kukaiN commented 3 months ago

@linnanwang @jiangyuhangcn I made a "fork" of the code here: https://github.com/kukaiN/vae_finetune/tree/main

I also ran into the same issues with mixed precision (I wanted to use bf16) and the negative dimension (caused by mismatching precisions), so I made some modifications. I also added xformers to the code. The changes are listed in the readme, but tl;dr: force-initializing the trainable weights in float32 and using autocast in the training loop lets the code run with mixed precision.

KimbingNg commented 3 months ago

@kukaiN Thanks! Good fix. Can you elaborate on the cause (mismatching precisions)? My code doesn't use accelerator. Both the model weights and the inputs are in float32, and the forward pass is called under `with torch.cuda.amp.autocast()`, but the same error still occurs. I also found that removing the attention blocks in the mid_block makes my code run without errors, so I believe it is the attention operations in the mid_block that lead to this error (only during mixed precision training). Can you help me with that?

kukaiN commented 3 months ago

@KimbingNg I just want to confirm if you reloaded the weights to float32 after the initial loading and the autocast scope contains the forward up to the backpropagation, like the snippet below.

My suspicion is that the mixed precision error happens because part of the model is not properly cast. I made the changes based on the linked question/discussion, but I didn't pinpoint which layer is causing the problem.

# line 413 ~ 432:
# we load it with float32 here, but we cast it again right after
vae = AutoencoderKL.from_pretrained(model_path, ..., torch_dtype=torch.float32)

vae.requires_grad_(True)

# https://stackoverflow.com/questions/75802877/issues-when-using-huggingface-accelerate-with-fp16
# load params with fp32, which is auto casted later to mixed precision, may be needed for ema
#
# from stackoverflow's answer, it links to diffuser's sdxl training script example and in that code there's another link
# which points to https://github.com/huggingface/diffusers/pull/6514#discussion_r1447020705
# which may suggest we need to do all this casting before passing the learnable params to the optimizer

for param in vae.parameters():
    if param.requires_grad:
        param.data = param.to(torch.float32)

...

vae.to(dtype=weight_dtype) #weight_dtype is fp16 or bf16

...

# training loop:
#  with autocast(): 
#     forward process
#      ...
#     backpropagate the loss
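
To make the pattern above concrete, here is a minimal runnable sketch of "float32 weights plus autocast around the forward pass". It is not the code from the fork. It uses a toy `nn.Sequential` in place of the VAE and CPU autocast with bf16 so it runs anywhere; on a GPU you would presumably use `device_type="cuda"` instead, and fp16 usually also needs a gradient scaler, which is left out here.

```python
import torch
import torch.nn.functional as F
from torch import nn

# Toy stand-in for the VAE (an assumption for illustration); the weights stay in float32.
model = nn.Sequential(nn.Linear(16, 8), nn.ReLU(), nn.Linear(8, 16)).to(torch.float32)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)

x = torch.randn(4, 16)  # the inputs are left in float32 as well

# The forward pass and the loss run inside the autocast scope, where eligible ops
# may run in the lower precision. Neither the weights nor the inputs are cast by hand.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    pred = model(x)
    loss = F.mse_loss(pred.float(), x)

loss.backward()

# Collect dtypes before the optimizer step clears the gradients.
param_dtypes = {p.dtype for p in model.parameters()}
grad_dtypes = {p.grad.dtype for p in model.parameters()}

optimizer.step()
optimizer.zero_grad()
```

The point of the sketch is that the optimizer only ever sees float32 parameters and gradients, while the reduced precision is confined to the autocast scope. Casting the inputs or the whole model to `weight_dtype` by hand is what can produce a dtype mismatch between weights and activations.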