lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License
11.17k stars 1.09k forks source link

The resulting images are of poor quality #316

Open xiaomei2002 opened 7 months ago

xiaomei2002 commented 7 months ago

First, I tried training on the CUB-200 dataset, but the generated images were blurry and showed no bird shape. So I tried training on only one sample, and the result was not good either.

xiaomei2002 commented 7 months ago

I used the following code to train on this sunflower picture with its label "flower", and at inference I also entered the prompt "flower". After 1000 training rounds the result was still unsatisfactory. Could someone please help me look into the problem?

import torch
from torch.utils.data import DataLoader
from torchvision.transforms import ToPILImage
from dalle2_pytorch.tokenizer import SimpleTokenizer
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer, Unet, Decoder, \
    OpenAIClipAdapter, DecoderTrainer
from typing import List, Dict
from torchvision.utils import make_grid
import torchvision.transforms as T
from torchvision.utils import save_image
from PIL import Image
from datetime import datetime
import os
import torch.utils.data as data
import json
# OpenAI pretrained CLIP wrapped by dalle2_pytorch. The model name is left at
# the adapter's default (ViT-B/32 per the original note); presumably its 512-d
# embedding is why the prior/decoder below use dim / *_embed_dim = 512.
# This single instance is shared by the diffusion prior and the decoder.
clip = OpenAIClipAdapter()

# # mock data
# # (unused; random tokens/images for a quick smoke test without a dataset)
# text = torch.randint(0, 49408, (4, 256)).cuda()
# images = torch.randn(4, 3, 256, 256).cuda()

def read_metadata(text: str) -> List[Dict]:
    """Parse JSON Lines text into a list of records.

    Every non-blank line of *text* must be one standalone JSON document (for
    this script, objects with ``url`` and ``caption`` keys). Blank lines and
    whitespace-only lines -- such as the lone carriage returns left behind by
    Windows line endings -- are skipped rather than passed to ``json.loads``,
    which would reject them.

    Args:
        text: Full contents of a ``.jsonl`` file.

    Returns:
        The decoded records, in file order.

    Raises:
        json.JSONDecodeError: If a non-blank line is not valid JSON.
    """
    # splitlines() understands '\n', '\r\n' and '\r'; strip() filters out
    # lines that contain nothing but whitespace.
    return [json.loads(line) for line in text.splitlines() if line.strip()]

class ImgTextDataset(data.Dataset):
    """Image/caption pairs listed in a JSON Lines metadata file.

    Every record in the metadata file must carry a ``url`` key (path of an
    image file) and a ``caption`` key (its text). Relative ``url`` values are
    resolved against the directory containing the metadata file; absolute
    paths are used unchanged.

    Each item is ``(image, caption)``: *image* is the ``ToTensor`` output for
    the picture converted to RGB and resized to 256 x 256, and *caption* is
    the raw caption string.
    """

    def __init__(self, fp: str, device="cuda"):
        """Load the metadata file and build the image transform.

        Args:
            fp: Path to the ``.jsonl`` metadata file.
            device: Device each image tensor is moved to in ``__getitem__``.
                Defaults to ``"cuda"`` (the previous hard-coded behaviour);
                pass ``None`` to return CPU tensors, e.g. when the DataLoader
                uses worker processes.

        Raises:
            KeyError: If a record lacks ``url`` or ``caption``.
        """
        self.fp = fp
        self.device = device
        # Image paths are resolved relative to the metadata file's folder.
        # Joining onto ``fp`` itself (the old behaviour) yields
        # ".../output2.jsonl/<image>", which can never exist.
        self.root = os.path.dirname(fp)

        with open(fp, 'r', encoding='utf-8') as file:
            metadata = read_metadata(file.read())

        # Plain indexing (not .get) so that a record missing either key fails
        # here with a KeyError; the two lists therefore always stay aligned.
        self.img_paths = [record['url'] for record in metadata]
        self.captions = [record['caption'] for record in metadata]

        # The decoder in this script is built with image_size = 256, so every
        # image is converted to 3-channel RGB and resized to 256 x 256.
        self.image_tranform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((256, 256)),
            T.ToTensor()
        ])

    def __len__(self):
        """Return the number of image/caption pairs."""
        return len(self.img_paths)

    def __getitem__(self, idx):
        """Return ``(image_tensor, caption)`` for record *idx*."""
        image_path = os.path.join(self.root, self.img_paths[idx])
        caption = self.captions[idx]

        # Close the file handle PIL opens; the transform runs inside the
        # ``with`` because PIL decodes pixel data lazily.
        with Image.open(image_path) as image:
            image_pt = self.image_tranform(image)

        if self.device is not None:
            image_pt = image_pt.to(self.device)
        return image_pt, caption

# ---------------------------------------------------------------------------
# Hyper-parameters and data
# ---------------------------------------------------------------------------
image_size = 256  # Image side length; matches the dataset's Resize((256, 256))
batch_size = 1  # Batch size for training, adjust based on GPU memory
learning_rate = 3e-4  # Learning rate for the diffusion-prior optimizer
decoder_learning_rate = 5e-4  # Learning rate for the decoder optimizer
num_epochs = 1000  # Passes over the dataset for EACH stage (prior, then decoder)
log_image_interval = 100  # Epoch interval for decoder sample logging (commented out below)
save_dir = "./log_images"  # Directory to save log images
os.makedirs(save_dir, exist_ok=True)  # Create save directory if it doesn't exist

device = torch.device("cuda")  # Not recommended to train on cpu

# NOTE(review): both networks are trained from scratch. Each stage performs
# num_epochs * len(dataloader) optimizer steps -- only 1000 with a single
# training image -- which is very little for a 256 x 256 Unet decoder using
# 1000 diffusion timesteps, so blurry / shapeless samples are the expected
# outcome of this configuration rather than a code bug. With
# ema_update_after_step = 1000 the EMA copies presumably never leave warm-up
# either; the DALLE2 wrapper at the end uses the online (non-EMA) modules.

# Define your image-text dataset
dataset = ImgTextDataset('/root/autodl-tmp/dalle2-train/data/output2.jsonl')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

t = SimpleTokenizer()

# ---------------------------------------------------------------------------
# Stage 1: diffusion prior (text -> CLIP image embedding), with a transformer
# ---------------------------------------------------------------------------
prior_network = DiffusionPriorNetwork(
    dim = 512,  # presumably must equal the CLIP embedding size (512 for ViT-B/32)
    depth = 6,
    dim_head = 64,
    heads = 8
).to(device)

diffusion_prior = DiffusionPrior(
    net = prior_network,
    clip = clip,
    timesteps = 100,
    cond_drop_prob = 0.2  # conditioning dropout, needed for classifier-free guidance
).to(device)

diffusion_prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr = learning_rate,
    wd = 1e-2,
    ema_beta = 0.99,
    ema_update_after_step = 1000,
    ema_update_every = 10,
)

for epoch in range(num_epochs):
    for prior_batch_idx, (images, texts) in enumerate(dataloader):
        loss = diffusion_prior_trainer(
            image=images.to(device),
            text=t.tokenize(texts).to(device),
            max_batch_size=4
        )
        diffusion_prior_trainer.update()
        if prior_batch_idx % 500 == 0:
            print(f"prior epoch {epoch}, step {prior_batch_idx}, loss {loss}")
# Distinct file name so the decoder checkpoint below cannot overwrite it.
#torch.save(diffusion_prior_trainer.state_dict(), f'prior_{epoch}.pth')

# ---------------------------------------------------------------------------
# Stage 2: decoder (CLIP image embedding + text encodings -> image), a Unet
# ---------------------------------------------------------------------------
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True,  # decoder is also conditioned on text, so `text` is passed below
).to(device)

decoder = Decoder(
    unet=unet1,
    image_size=image_size,
    clip=clip,
    timesteps=1000
).to(device)

decoder_trainer = DecoderTrainer(
    decoder,
    lr=decoder_learning_rate,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).to(device)

# Training loop: iterate over the dataloader, passing image tensors and
# tokenized text to the training wrapper, num_epochs times.
for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        loss = decoder_trainer(
            images.to(device),
            text=t.tokenize(texts).to(device),
            unet_number=1,
            max_batch_size=4
        )
        decoder_trainer.update(1)
        if batch_idx % 500 == 0:
            print(f"decoder epoch {epoch}, step {batch_idx}, loss {loss}")
    # Optional sample logging -- expensive: each call runs a full decoder
    # sampling pass. Assumes clip.embed_image returns a tuple whose first
    # item is the image embedding (TODO confirm against dalle2_pytorch).
    # if (epoch+1) % log_image_interval == 0 :
    #     image_embed = clip.embed_image(images.to(device))
    #     sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).to(device))
    #     save_image(sample, os.path.join(save_dir, f'{epoch}_{batch_idx}.png'))
    # Periodically save the model.
#torch.save(decoder_trainer.state_dict(), f'decoder_{epoch}.pth')

# ---------------------------------------------------------------------------
# Inference: text -> prior -> decoder
# ---------------------------------------------------------------------------
dalle2 = DALLE2(
    prior = diffusion_prior,
    decoder = decoder
).to(device)

gen_images = dalle2(
    ['flower'],
    cond_scale = 2. # classifier free guidance strength (> 1 would strengthen the condition)
).to(device)

print(gen_images.shape)

# save_image is already imported with the other torchvision utilities at the top.
save_image(gen_images, 'image.png')

# NOTE: show() needs a display; on a headless server, inspect image.png instead.
for img in gen_images:
    img = ToPILImage()(img)
    img.show()

(In the original issue, the training data — the flower image — and the generated image were attached here.) Could it be my hyperparameter settings? I hope to receive a reply. Thanks! ^ ^

Siddharth-Latthe-07 commented 5 months ago

Try using a Stable Diffusion model instead. It is more powerful than DALL-E 2: thanks to its noising and denoising process, the generated images are clearer, and it also runs on an NVIDIA GPU. You can refer to this notebook: https://colab.research.google.com/github/huggingface/notebooks/blob/main/diffusers/stable_diffusion.ipynb

chenrxi commented 3 months ago

What was your final training loss? I think the cause is training from scratch, because the CUB-200 dataset is small (maybe about 6,000 images in total). @xiaomei2002

pankil25 commented 2 months ago

@xiaomei2002 were you able to solve it?