sangminkim-99 / Sketch-Guided-Text-To-Image

Unofficial implementation of Sketch-Guided Text-to-Image Diffusion Models

Issue with Training LEP with batch_size > 1 #4

Open rs2125 opened 7 months ago

rs2125 commented 7 months ago

Why is `batch_size * 2` passed as a parameter in the following lines instead of `batch_size`?

In train_LEP.py:

noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size * 2, num_train_timestep)

pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b c h w", b=batch_size * 2, h=latent_edge.shape[2], w=latent_edge.shape[3])

Moreover, why is `torch.cat([latent_image] * 2)` passed instead of just `latent_image` in the following line:

pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)
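
For context, the doubling appears to follow the classifier-free-guidance convention: `encode_text` in `internals/diffusion_utils.py` (quoted later in this thread) returns the unconditional and conditional embeddings concatenated, so each caption produces two embedding rows, and the latents are duplicated to match. A minimal shape trace (illustrative only; the tensor sizes are assumptions for SD v1.4 with 512x512 inputs):

import torch

batch_size = 2
latent_image = torch.randn(batch_size, 4, 64, 64)         # VAE latents (assumed 512x512 inputs)
# encode_text returns torch.cat([uncond_embeddings, text_embeddings]) -> 2 rows per caption
caption_embedding = torch.randn(batch_size * 2, 77, 768)
unet_input = torch.cat([latent_image] * 2)                # (batch_size * 2, 4, 64, 64)
assert unet_input.shape[0] == caption_embedding.shape[0]  # batch dimensions now agree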

rs2125 commented 7 months ago

Made some changes to the scripts which solve the size mismatch between the `pred_edge_map` and `latent_edge` variables. Now we only need to pass `batch_size` as a parameter to the functions instead of `batch_size * 2`. **The issue regarding passing `torch.cat([latent_image] * 2)` instead of just `latent_image` to the `unet` still remains unresolved. The `batch_size > 1` issue also remains unresolved.**

Updated train_LEP.py

import os
import math
from diffusers import StableDiffusionPipeline
from einops import rearrange
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import typer
from typing import List
from typing_extensions import Annotated

from internals.diffusion_utils import encode_img, encode_text, hook_unet, noisy_latent
from internals.latent_edge_predictor import LatentEdgePredictor
from internals.LEP_dataset import LEPDataset

def train_LEP(
    model_id: Annotated[str, typer.Option()] = "CompVis/stable-diffusion-v1-4",
    device: Annotated[str, typer.Option()] = "cuda:1",
    dataset_dir: Annotated[str, typer.Option(help="path to the parent directory of image data")] = "./data/imagenet/imagenet_images",
    edge_map_dir: Annotated[str, typer.Option(help="path to the parent directory of edge map data")] = "./data/imagenet/edge_maps",
    save_path: Annotated[str, typer.Option(help="path to save LEP model")] = "./output/LEP.pt",
    batch_size: Annotated[int, typer.Option(help="batch size for training LEP. Decrease this if OOM occurs.")] = 1,
    training_step: Annotated[int, typer.Option()] = 4633,
    lr: Annotated[float, typer.Option()] = 1e-4, # not specified in the paper
    num_train_timestep: Annotated[int, typer.Option(help="maximum diffusion timestep")] = 250, # not specified in the paper
):
    '''
    Train the Latent Edge Predictor.
    '''
    # create output folder
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # create dataset & loader
    dataset = LEPDataset(dataset_dir, edge_map_dir)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # initialize stable diffusion pipeline.
    # the paper uses stable-diffusion-v1.4
    pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None, requires_safety_checker = False).to(device)

    unet = pipe.unet
    unet.enable_xformers_memory_efficient_attention()

    # hook the feature_blocks of unet
    feature_blocks = hook_unet(pipe.unet)       

    # initialize LEP
    LEP = LatentEdgePredictor(input_dim=9324, output_dim=4, num_layers=10).to(device)

    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()

    # freeze the pretrained weights so that only the LEP receives gradients
    pipe.unet.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    LEP.requires_grad_(True)

    # load clip tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    optimizer = torch.optim.Adam(LEP.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_epochs = 10
    max_train_steps = train_epochs * len(dataloader)
    num_update_steps_per_epoch = len(dataloader)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)
    progress_bar = tqdm(
        range(1, max_train_steps),
        smoothing=0,
        desc="steps",
        position=0, leave=True
    )

    for epoch in range(num_train_epochs):
        progress_bar.set_description_str(f"Epoch {epoch+1}/{num_train_epochs}")
        loss_total = 0
        for step, batch in enumerate(dataloader):
            image, edge_map, caption = batch[0], batch[1], batch[2]
            optimizer.zero_grad()

            # image to latent
            latent_image = encode_img(pipe.vae, image)
            latent_edge = encode_img(pipe.vae, edge_map)
            latent_edge = latent_edge.transpose(1,3)

            caption_embedding = torch.cat([encode_text(pipe.text_encoder, tokenizer, c) for c in caption])
            noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size, num_train_timestep)

            # one reverse step to get the feature blocks
            pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)

            activations = []
            for block in feature_blocks:
                activations.append(block.output)
                block.output = None

            features = activations

            assert all([isinstance(acts, torch.Tensor) for acts in features])
            size = latent_image.shape[2:]
            resized_activations = []
            for acts in features:
                acts = torch.nn.functional.interpolate(acts, size=size, mode="bilinear")
                acts = acts[:1]
                acts = acts.transpose(1,3)
                resized_activations.append(acts)

            intermediate_result = torch.cat(resized_activations, dim=3)
            intermediate_result = intermediate_result.transpose(1,3)

            pred_edge_map = LEP(intermediate_result, noise_level)
            pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=batch_size, h=latent_edge.shape[1], w=latent_edge.shape[2])

            # calculate MSE loss
            loss = criterion(pred_edge_map, latent_edge)
            loss.backward()

            optimizer.step()

            current_loss = loss.detach().item()
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)

            if step % 10 == 0:
                progress_bar.set_description(f"Loss: {avr_loss:.3f}")
            progress_bar.update(1)  # advance the bar; it is created above but otherwise never stepped

            if step >= max_train_steps:
                break

            step += 1

        if step >= training_step:
            path = "./output/LEP-" + str(epoch+1) + ".pt"
            print(f'Finished optimizing. Saving checkpoint to {path} (epoch {epoch+1}).')
            torch.save(LEP.state_dict(), path)
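
For completeness, a minimal sketch of reloading one of the per-epoch checkpoints that the loop above saves (the path here is hypothetical and simply follows the pattern in the save call):

import torch

from internals.latent_edge_predictor import LatentEdgePredictor

# rebuild the predictor with the same dimensions used during training
LEP = LatentEdgePredictor(input_dim=9324, output_dim=4, num_layers=10)
LEP.load_state_dict(torch.load("./output/LEP-1.pt"))  # hypothetical epoch-1 checkpoint
LEP.eval()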

Updated internals/diffusion_utils.py

from diffusers import AutoencoderKL, UNet2DConditionModel
import torch
from transformers.models.clip import CLIPTextModel, CLIPTokenizer

def encode_img(vae: AutoencoderKL, image: torch.Tensor):
    generator = torch.Generator(vae.device).manual_seed(0)
    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample(generator=generator)
    latents = latents * 0.18215  # SD v1 VAE latent scaling factor
    return latents

def encode_text(text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text):
    text_input = tokenizer([text], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(text_encoder.device))[0]
    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(text_encoder.device))[0]   
    # return torch.cat([uncond_embeddings, text_embeddings]).unsqueeze(0)
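    # note: the unconditional and conditional embeddings are concatenated, so each
    # caption yields two rows; this is why train_LEP.py doubles the latent batch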
    return torch.cat([uncond_embeddings, text_embeddings])

def noisy_latent(image, noise_scheduler, batch_size, num_train_timestep):
    timesteps = torch.randint(0, num_train_timestep, (batch_size,), dtype=torch.int64, device=image.device).long()
    noise = torch.randn_like(image, device=image.device)

    alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps.cpu()].to(image.device)
    # print("alpha_prod = ", alphas_cumprod)
    sqrt_alpha_prod = alphas_cumprod ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()

    while len(sqrt_alpha_prod.shape) < len(image.shape):
            sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()

    while len(sqrt_one_minus_alpha_prod.shape) < len(image.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    noisy_samples = sqrt_alpha_prod * image + sqrt_one_minus_alpha_prod * noise
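    # note: algebraically noise_level == sqrt_one_minus_alpha_prod * noise, i.e. the
    # exact scaled-noise component that was added to the clean latents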
    noise_level = noisy_samples - (sqrt_alpha_prod * image)

    return noisy_samples, noise_level, timesteps

def hook_unet(unet: UNet2DConditionModel):
    blocks_idx = [0, 1, 2]
    feature_blocks = []
    def hook(module, input, output):
        if isinstance(output, tuple):
            output = output[0]

        if isinstance(output, torch.Tensor):
            feature = output.float()
            setattr(module, "output", feature)
        elif isinstance(output, dict):  # diffusers ModelOutput is dict-like and exposes .sample
            feature = output.sample.float()
            setattr(module, "output", feature)
        else:
            feature = output.float()
            setattr(module, "output", feature)

    # TODO: Check below lines are correct

    # 0, 1, 2 -> (ldm-down) 2, 4, 8
    for idx, block in enumerate(unet.down_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block) 

    # ldm-mid 0, 1, 2
    for block in unet.mid_block.attentions + unet.mid_block.resnets:
        block.register_forward_hook(hook)
        feature_blocks.append(block) 

    # 0, 1, 2 -> (ldm-up) 2, 4, 8
    for idx, block in enumerate(unet.up_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block)

    return feature_blocks
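
A quick sanity check of the hook wiring (a sketch; the expected count of 9 is an assumption based on SD v1.4's UNet having three hooked down blocks, one mid-block attention plus two mid-block resnets, and three hooked up blocks):

from diffusers import UNet2DConditionModel

from internals.diffusion_utils import hook_unet

unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
feature_blocks = hook_unet(unet)
print(len(feature_blocks))  # 9, if the block-count assumption above holds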
shaoke317 commented 4 months ago

Hello, what batch size are you able to train the LEP with now? I saw that batch_size is set to 16 in the author's code, but with batch_size = 16 I get CUDA out of memory. I tried changing it to 8, but then I get a dimension mismatch. How should I change it? Can you give me some suggestions? (two screenshots attached)