Open rs2125 opened 7 months ago
Made some changes to the scripts which solve the size mismatch between the `pred_edge_map` and `latent_edge` variables. Now we only need to pass `batch_size` as a parameter to the functions instead of `batch_size*2`. **The issue of passing `torch.cat([latent_image] * 2)` instead of just `latent_image` into the `unet` still remains unresolved.** **The `batch_size > 1` issue also remains unresolved.**
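For context, a minimal, self-contained sketch of the shape fix (the 64×64×4 latent size and the flattened LEP output shape are illustrative assumptions; the real sizes come from the VAE latents): rearranging the LEP output with `b=batch_size` instead of `b=batch_size*2` makes `pred_edge_map` line up with `latent_edge`.

```python
# minimal shape sketch of the fix (illustrative sizes: batch 1, 4x64x64 latents)
import torch
from einops import rearrange

b, h, w, c = 1, 64, 64, 4
latent_edge = torch.randn(b, c, h, w).transpose(1, 3)   # (b, w, h, c), as in train_LEP.py
pred_edge_map = torch.randn(b * w * h, c)               # LEP output: one row per spatial position

# passing b=batch_size (not batch_size*2) makes both tensors (1, 64, 64, 4)
pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=b, h=h, w=w)
assert pred_edge_map.shape == latent_edge.shape
```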
UPDATED `train_LEP.py`:

```python
import os
import math

from diffusers import StableDiffusionPipeline
from einops import rearrange
import numpy as np
import torch
from tqdm import tqdm
from transformers import CLIPTokenizer
import typer
from typing import List
from typing_extensions import Annotated

from internals.diffusion_utils import encode_img, encode_text, hook_unet, noisy_latent
from internals.latent_edge_predictor import LatentEdgePredictor
from internals.LEP_dataset import LEPDataset


def train_LEP(
    model_id: Annotated[str, typer.Option()] = "CompVis/stable-diffusion-v1-4",
    device: Annotated[str, typer.Option()] = "cuda:1",
    dataset_dir: Annotated[str, typer.Option(help="path to the parent directory of image data")] = "./data/imagenet/imagenet_images",
    edge_map_dir: Annotated[str, typer.Option(help="path to the parent directory of edge map data")] = "./data/imagenet/edge_maps",
    save_path: Annotated[str, typer.Option(help="path to save LEP model")] = "./output/LEP.pt",
    batch_size: Annotated[int, typer.Option(help="batch size for training LEP. Decrease this if OOM occurs.")] = 1,
    training_step: Annotated[int, typer.Option()] = 4633,
    lr: Annotated[float, typer.Option()] = 1e-4,  # not specified in the paper
    num_train_timestep: Annotated[int, typer.Option(help="maximum diffusion timestep")] = 250,  # not specified in the paper
):
    '''
    Train the Latent Edge Predictor.
    '''
    # create output folder
    os.makedirs(os.path.dirname(save_path), exist_ok=True)

    # create dataset & loader
    dataset = LEPDataset(dataset_dir, edge_map_dir)
    dataloader = torch.utils.data.DataLoader(dataset, batch_size=batch_size, shuffle=True)

    # initialize stable diffusion pipeline.
    # the paper uses stable-diffusion-v1.4
    pipe = StableDiffusionPipeline.from_pretrained(model_id, safety_checker=None, requires_safety_checker=False).to(device)
    unet = pipe.unet
    unet.enable_xformers_memory_efficient_attention()

    # hook the feature_blocks of the unet
    feature_blocks = hook_unet(pipe.unet)

    # initialize LEP
    LEP = LatentEdgePredictor(input_dim=9324, output_dim=4, num_layers=10).to(device)

    pipe.unet.eval()
    pipe.vae.eval()
    pipe.text_encoder.eval()

    # need these lines?
    pipe.unet.requires_grad_(False)
    pipe.text_encoder.requires_grad_(False)
    LEP.requires_grad_(True)

    # load clip tokenizer
    tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")

    optimizer = torch.optim.Adam(LEP.parameters(), lr=lr)
    criterion = torch.nn.MSELoss()

    train_epochs = 10
    max_train_steps = train_epochs * len(dataloader)
    num_update_steps_per_epoch = len(dataloader)
    num_train_epochs = math.ceil(max_train_steps / num_update_steps_per_epoch)

    progress_bar = tqdm(
        range(1, max_train_steps),
        smoothing=0,
        desc="steps",
        position=0, leave=True
    )

    for epoch in range(num_train_epochs):
        progress_bar.set_description_str(f"Epoch {epoch+1}/{num_train_epochs}")
        loss_total = 0

        for step, batch in enumerate(dataloader):
            image, edge_map, caption = batch[0], batch[1], batch[2]
            optimizer.zero_grad()

            # image to latent
            latent_image = encode_img(pipe.vae, image)
            latent_edge = encode_img(pipe.vae, edge_map)
            latent_edge = latent_edge.transpose(1, 3)

            caption_embedding = torch.cat([encode_text(pipe.text_encoder, tokenizer, c) for c in caption])

            noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size, num_train_timestep)

            # one reverse step to get the feature blocks
            pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)

            activations = []
            for block in feature_blocks:
                activations.append(block.output)
                block.output = None

            features = activations
            assert all([isinstance(acts, torch.Tensor) for acts in features])

            size = latent_image.shape[2:]
            resized_activations = []
            for acts in features:
                acts = torch.nn.functional.interpolate(acts, size=size, mode="bilinear")
                acts = acts[:1]
                acts = acts.transpose(1, 3)
                resized_activations.append(acts)

            intermediate_result = torch.cat(resized_activations, dim=3)
            intermediate_result = intermediate_result.transpose(1, 3)

            pred_edge_map = LEP(intermediate_result, noise_level)
            pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b h w c", b=batch_size, h=latent_edge.shape[1], w=latent_edge.shape[2])

            # calculate MSE loss
            loss = criterion(pred_edge_map, latent_edge)
            loss.backward()
            optimizer.step()

            current_loss = loss.detach().item()
            loss_total += current_loss
            avr_loss = loss_total / (step + 1)
            if step % 10 == 0:
                progress_bar.set_description(f"Loss: {avr_loss:.3f}")
            if step >= max_train_steps:
                break
            step += 1

        if step >= training_step:
            print(f'Finish to optimize. Save file to {save_path}, Epoch = {epoch+1}')
            path = "./output/LEP-" + str(epoch+1) + ".pt"
            torch.save(LEP.state_dict(), path)
```
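For completeness, a minimal sketch of launching a run by calling the function directly (the option values are the defaults from the signature above; a `typer.run(train_LEP)` entry point for CLI use is assumed but not shown here):

```python
# minimal usage sketch: call the updated training function directly
# (paths are the defaults from the signature; adjust to your local layout)
from train_LEP import train_LEP

train_LEP(
    model_id="CompVis/stable-diffusion-v1-4",
    device="cuda:0",   # pick whichever GPU is free
    dataset_dir="./data/imagenet/imagenet_images",
    edge_map_dir="./data/imagenet/edge_maps",
    save_path="./output/LEP.pt",
    batch_size=1,      # batch_size > 1 is still unresolved, so keep it at 1 for now
)
```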
Updated `internals/diffusion_utils.py`:

```python
from diffusers import AutoencoderKL, UNet2DConditionModel
import torch
from transformers.models.clip import CLIPTextModel, CLIPTokenizer


def encode_img(vae: AutoencoderKL, image: torch.Tensor):
    generator = torch.Generator(vae.device).manual_seed(0)
    latents = vae.encode(image.to(device=vae.device, dtype=vae.dtype)).latent_dist.sample(generator=generator)
    latents = latents * 0.18215
    return latents


def encode_text(text_encoder: CLIPTextModel, tokenizer: CLIPTokenizer, text):
    text_input = tokenizer([text], padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
    with torch.no_grad():
        text_embeddings = text_encoder(text_input.input_ids.to(text_encoder.device))[0]

    max_length = text_input.input_ids.shape[-1]
    uncond_input = tokenizer([""], padding="max_length", max_length=max_length, return_tensors="pt")
    with torch.no_grad():
        uncond_embeddings = text_encoder(uncond_input.input_ids.to(text_encoder.device))[0]

    # return torch.cat([uncond_embeddings, text_embeddings]).unsqueeze(0)
    return torch.cat([uncond_embeddings, text_embeddings])


def noisy_latent(image, noise_scheduler, batch_size, num_train_timestep):
    timesteps = torch.randint(0, num_train_timestep, (batch_size,), dtype=torch.int64, device=image.device).long()
    noise = torch.randn_like(image, device=image.device)

    alphas_cumprod = noise_scheduler.alphas_cumprod[timesteps.cpu()].to(image.device)
    # print("alpha_prod = ", alphas_cumprod)
    sqrt_alpha_prod = alphas_cumprod ** 0.5
    sqrt_alpha_prod = sqrt_alpha_prod.flatten()
    while len(sqrt_alpha_prod.shape) < len(image.shape):
        sqrt_alpha_prod = sqrt_alpha_prod.unsqueeze(-1)

    sqrt_one_minus_alpha_prod = (1 - alphas_cumprod) ** 0.5
    sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.flatten()
    while len(sqrt_one_minus_alpha_prod.shape) < len(image.shape):
        sqrt_one_minus_alpha_prod = sqrt_one_minus_alpha_prod.unsqueeze(-1)

    noisy_samples = sqrt_alpha_prod * image + sqrt_one_minus_alpha_prod * noise
    noise_level = noisy_samples - (sqrt_alpha_prod * image)
    return noisy_samples, noise_level, timesteps


def hook_unet(unet: UNet2DConditionModel):
    blocks_idx = [0, 1, 2]
    feature_blocks = []

    def hook(module, input, output):
        if isinstance(output, tuple):
            output = output[0]

        if isinstance(output, torch.Tensor):
            feature = output.float()
            setattr(module, "output", feature)
        elif isinstance(output, dict):
            feature = output.sample.float()
            setattr(module, "output", feature)
        else:
            feature = output.float()
            setattr(module, "output", feature)

    # TODO: Check below lines are correct
    # 0, 1, 2 -> (ldm-down) 2, 4, 8
    for idx, block in enumerate(unet.down_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block)

    # ldm-mid 0, 1, 2
    for block in unet.mid_block.attentions + unet.mid_block.resnets:
        block.register_forward_hook(hook)
        feature_blocks.append(block)

    # 0, 1, 2 -> (ldm-up) 2, 4, 8
    for idx, block in enumerate(unet.up_blocks):
        if idx in blocks_idx:
            block.register_forward_hook(hook)
            feature_blocks.append(block)

    return feature_blocks
```
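As a quick sanity check of the updated `noisy_latent` signature, here is a hedged sketch (a standalone `DDPMScheduler` stands in for the pipeline's scheduler, and the 1×4×64×64 latent is only an illustrative size):

```python
# sanity check: noisy_latent now takes batch_size (not batch_size*2)
import torch
from diffusers import DDPMScheduler
from internals.diffusion_utils import noisy_latent

scheduler = DDPMScheduler(num_train_timesteps=1000)   # stand-in for pipe.scheduler
latent = torch.randn(1, 4, 64, 64)                    # one latent, i.e. batch_size = 1

noisy, noise_level, timesteps = noisy_latent(latent, scheduler, batch_size=1, num_train_timestep=250)
print(noisy.shape, noise_level.shape, timesteps.shape)
# torch.Size([1, 4, 64, 64]) torch.Size([1, 4, 64, 64]) torch.Size([1])
```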
Hello, how many batches can you use to train LEP now? I saw in the author's code that the batch size is set to 16. When I set the batch size to 16, I got CUDA out of memory. I tried to change it to 8, but it showed a dimension mismatch. How should we change it? Can you give me some suggestions?
Why is `batch_size*2` passed as a parameter in the following lines instead of `batch_size`? In `train_LEP.py`:

```python
noisy_image, noise_level, timesteps = noisy_latent(latent_image, pipe.scheduler, batch_size * 2, num_train_timestep)
pred_edge_map = rearrange(pred_edge_map, "(b w h) c -> b c h w", b=batch_size * 2, h=latent_edge.shape[2], w=latent_edge.shape[3])
```
Moreover, why is `torch.cat([latent_image] * 2)` passed instead of just `latent_image` in the following line:

```python
pipe.unet(torch.cat([latent_image] * 2), timesteps, encoder_hidden_states=caption_embedding)
```
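For context, a minimal shape check of `encode_text` (model IDs as in the scripts above, with a placeholder caption): it returns the unconditional and conditional embeddings concatenated, so the UNet sample batch presumably has to be doubled to match, which would explain the `torch.cat([latent_image] * 2)`.

```python
# shape check: encode_text() returns [uncond, cond] stacked along the batch dim
import torch
from transformers import CLIPTextModel, CLIPTokenizer
from internals.diffusion_utils import encode_text

tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")

emb = encode_text(text_encoder, tokenizer, "a photo of a dog")
print(emb.shape)   # torch.Size([2, 77, 768]) -> two embeddings per caption,
                   # so the unet's latent input is duplicated to keep batch dims equal
```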