lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch

Setup Steps #279

bahmanian opened 1 year ago

bahmanian commented 1 year ago

Could someone please explain exactly how to use this?

u1ug commented 1 year ago

Here is some sample code for training:

import json
import os

import torch
import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from dalle2_pytorch import OpenAIClipAdapter, Unet, Decoder, DecoderTrainer
from dalle2_pytorch.tokenizer import SimpleTokenizer

def read_metadata(text: str) -> list[dict]:
    data = []
    for line in text.split('\n'):
        if not line:
            continue
        line_json = json.loads(line)
        data.append(line_json)
    return data

# ImgTextDataset returns an image tensor and its text caption, mirroring the
# Hugging Face text-to-image dataset layout. __init__ reads the dataset info
# from a `metadata.jsonl` file in which image paths and captions are specified:
# {"file_name": "/path/1.png", "text": "sample text 1"}
# {"file_name": "/path/2.png", "text": "sample text 2"}
# ...
# To make a custom dataset, inherit from data.Dataset and implement the
# __len__ and __getitem__ methods.
class ImgTextDataset(data.Dataset):
    def __init__(self, fp: str):
        self.fp = fp
        with open(os.path.join(fp, 'metadata.jsonl'), 'r') as file:
            metadata = read_metadata(file.read())

        self.img_paths = []
        self.captions = []

        for line in metadata:
            self.img_paths.append(line['file_name'])
            self.captions.append(line['text'])

        # Make sure that each image is captioned
        assert len(self.img_paths) == len(self.captions)
        # Apply the required image transforms: this model expects 256 x 256 RGB images.
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((256, 256)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = os.path.join(self.fp, self.img_paths[idx])
        caption = self.captions[idx]

        image = Image.open(image_path)

        # Keep the tensor on the CPU here; the training loop moves each batch to the GPU.
        image_pt = self.image_transform(image)
        return image_pt, caption

# Parameters
image_size = 256  # Image dimension
batch_size = 1  # Batch size for training; adjust based on GPU memory
learning_rate = 3e-4  # Learning rate for the optimizer
num_epochs = 50  # Number of epochs for training
log_image_interval = 5000  # Interval (in steps) between logged sample images
save_dir = "./log_images"  # Directory to save log images
os.makedirs(save_dir, exist_ok=True)  # Create save directory if it doesn't exist

# Setup device. A CUDA GPU is required in practice; training this model on CPU is not feasible.
device = torch.device("cuda")

# Define your image-text dataset
dataset = ImgTextDataset('path to folder with metadata.jsonl file')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# Initialize OpenAI CLIP model adapter
clip = OpenAIClipAdapter()
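# Note (assumption based on the library defaults): this loads OpenAI's ViT-B/32 CLIP,
# whose image and text embeddings are 512-dimensional, matching image_embed_dim
# and text_embed_dim below.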
# Create models for training
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True,
).cuda()

decoder = Decoder(
    unet=unet1,
    image_size=image_size,
    clip=clip,
    timesteps=1000
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=learning_rate,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()

# Use the built-in tokenizer. You can substitute others, e.g. GPT-2's BPE tokenizer or YTTM.
t = SimpleTokenizer()
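# t.tokenize(list_of_strings) returns a LongTensor of token ids, padded to the
# tokenizer's context length, ready to be passed as `text` below.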

# Training loop.
# Iterate over the dataloader, passing image tensors and tokenized captions
# to the training wrapper, and repeat for num_epochs epochs.

for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        loss = decoder_trainer(
            images.cuda(),
            text=t.tokenize(texts).cuda(),
            unet_number=1,
            max_batch_size=4
        )
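        # The forward pass above computes the loss and runs backprop internally,
        # splitting the batch into chunks of at most max_batch_size; update()
        # then steps the optimizer and the exponential-moving-average copy of unet 1.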
        decoder_trainer.update(1)
        if batch_idx % 100 == 0:
            print(f"epoch {epoch}, step {batch_idx}, loss {loss:.4f}")
        if batch_idx % log_image_interval == 0 and batch_idx != 0:
            # clip.embed_image returns (image_embed, image_encodings); index 0 is the embedding
            image_embed = clip.embed_image(images.cuda())
            sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).cuda())
            save_image(sample, os.path.join(save_dir, f'{epoch}_{batch_idx}.png'))
    # Save a checkpoint at the end of each epoch.
    torch.save(decoder_trainer.state_dict(), f'model_{epoch}.pt')
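
Once a diffusion prior has been trained as well (see the sketch further down), text-to-image sampling goes through the DALLE2 wrapper. Here is a minimal sketch following the repository README; the caption and the cond_scale value are just illustrative:

from dalle2_pytorch import DALLE2
from torchvision.utils import save_image

dalle2 = DALLE2(
    prior=diffusion_prior,  # a trained DiffusionPrior
    decoder=decoder         # the trained Decoder from above
)

# cond_scale > 1 applies classifier-free guidance
images = dalle2(['a red sports car in the rain'], cond_scale=2.)
save_image(images, 'sample.png')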
Felix-FN commented 7 months ago

Hi @u1ug, do you also have a similar example for training a prior?
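
For reference, the repository README contains a prior-training example. Below is a minimal sketch adapted from it, where `images` is a (batch, 3, 256, 256) tensor on the GPU and `tokens` are tokenized captions from a dataloader like the one above; the network hyperparameters are the README's illustrative defaults, not tuned values:

from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, OpenAIClipAdapter
from dalle2_pytorch.trainer import DiffusionPriorTrainer

clip = OpenAIClipAdapter()

prior_network = DiffusionPriorNetwork(
    dim=512,
    depth=6,
    dim_head=64,
    heads=8
)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2
).cuda()

prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr=3e-4,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
)

# tokens and images are assumed to already be on the GPU
loss = prior_trainer(tokens, images, max_batch_size=4)
prior_trainer.update()  # steps the optimizer and the EMA copy of the prior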