bahmanian opened 1 year ago
Here is sample code for training the decoder:
import json
import os

import torch
import torch.utils.data as data
import torchvision.transforms as T
from PIL import Image
from torch.utils.data import DataLoader
from torchvision.utils import save_image

from dalle2_pytorch import OpenAIClipAdapter, Unet, Decoder, DecoderTrainer
from dalle2_pytorch.tokenizer import SimpleTokenizer
def read_metadata(text: str) -> list[dict]:
    data = []
    for line in text.split('\n'):
        if not line:
            continue
        data.append(json.loads(line))
    return data
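# A quick worked example of read_metadata (hypothetical file names):
#   read_metadata('{"file_name": "1.png", "text": "a"}\n{"file_name": "2.png", "text": "b"}')
#   -> [{'file_name': '1.png', 'text': 'a'}, {'file_name': '2.png', 'text': 'b'}]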
# ImgTextDataset returns an image tensor and its text caption. It works like the huggingface text-to-image dataset implementation.
# __init__ reads dataset info from a `metadata.jsonl` file in which the image paths and captions are specified:
# {"file_name": "/path/1.png", "text": "sample text 1"}
# {"file_name": "/path/2.png", "text": "sample text 2"}
# ...
# To build a custom dataset, inherit from data.Dataset and implement the __len__ and __getitem__ methods.
class ImgTextDataset(data.Dataset):
    def __init__(self, fp: str):
        self.fp = fp
        with open(os.path.join(fp, 'metadata.jsonl'), 'r') as file:
            metadata = read_metadata(file.read())
        self.img_paths = []
        self.captions = []
        for line in metadata:
            self.img_paths.append(line['file_name'])
            self.captions.append(line['text'])
        # Make sure that each image is captioned
        assert len(self.img_paths) == len(self.captions)
        # Apply the required image transforms. For my model I need 256 x 256 RGB images.
        self.image_transform = T.Compose([
            T.Lambda(lambda img: img.convert('RGB') if img.mode != 'RGB' else img),
            T.Resize((256, 256)),
            T.ToTensor()
        ])

    def __len__(self):
        return len(self.img_paths)

    def __getitem__(self, idx):
        image_path = os.path.join(self.fp, self.img_paths[idx])
        caption = self.captions[idx]
        image = Image.open(image_path)
        # Note: moving tensors to the GPU here only works with the DataLoader default num_workers=0
        image_pt = self.image_transform(image).cuda()
        return image_pt, caption
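# Quick sanity check of the dataset above (the path is hypothetical; uncomment to run):
# dataset_check = ImgTextDataset('/data/my_dataset')
# image, caption = dataset_check[0]
# print(image.shape)  # expected: torch.Size([3, 256, 256]), on the GPU
# print(caption)      # the corresponding caption string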
# Parameters
image_size = 256           # Image dimension
batch_size = 1             # Batch size for training, adjust based on GPU memory
learning_rate = 3e-4       # Learning rate for the trainer below
num_epochs = 50            # Number of epochs for training
log_image_interval = 5000  # Interval (in steps) for logging sample images
save_dir = "./log_images"  # Directory to save log images
os.makedirs(save_dir, exist_ok=True)  # Create the save directory if it doesn't exist

# Setup device
device = torch.device("cuda")  # Training on CPU is not recommended
# Define your image-text dataset
dataset = ImgTextDataset('path to folder with metadata.jsonl file')
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
# Initialize OpenAI CLIP model adapter
clip = OpenAIClipAdapter()
# Create models for training
unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    text_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8),
    cond_on_text_encodings=True,
).cuda()

decoder = Decoder(
    unet=unet1,
    image_size=image_size,
    clip=clip,
    timesteps=1000
).cuda()

decoder_trainer = DecoderTrainer(
    decoder,
    lr=learning_rate,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
).cuda()
# Use the built-in tokenizer. You can use others, such as GPT-2 or YTTM.
t = SimpleTokenizer()
# Training loop.
# Iterate over the dataloader and pass image tensors and tokenized text to the training wrapper.
# Repeat the process for num_epochs epochs.
for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        loss = decoder_trainer(
            images.cuda(),
            text=t.tokenize(texts).cuda(),
            unet_number=1,
            max_batch_size=4
        )
        decoder_trainer.update(1)  # update unet 1 (and its EMA copy)
        if batch_idx % 100 == 0:
            print(f"epoch {epoch}, step {batch_idx}, loss {loss}")
        if batch_idx % log_image_interval == 0 and batch_idx != 0:
            # clip.embed_image returns (image_embed, image_encodings); index 0 is the embedding
            image_embed = clip.embed_image(images.cuda())
            sample = decoder_trainer.sample(image_embed=image_embed[0], text=t.tokenize(texts).cuda())
            save_image(sample, os.path.join(save_dir, f'{epoch}_{batch_idx}.png'))
    # Save the model at the end of each epoch.
    torch.save(decoder_trainer.state_dict(), f'model_{epoch}.pt')
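To resume training from one of these checkpoints, the saved state can be loaded back into a freshly constructed trainer. A minimal sketch using the standard nn.Module API (model_4.pt is a hypothetical checkpoint name):

# Resume from a saved checkpoint ('model_4.pt' is a hypothetical file name)
decoder_trainer.load_state_dict(torch.load('model_4.pt'))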
Hi @u1ug, do you also have a similar example for training a prior?
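For what it's worth, the dalle2-pytorch README wraps prior training in a DiffusionPriorTrainer that mirrors the DecoderTrainer above. Below is a minimal sketch reusing the dataset, dataloader, tokenizer t, and CLIP adapter from the code above; the hyperparameters are the README defaults, not tuned values:

from dalle2_pytorch import DiffusionPriorNetwork, DiffusionPrior, DiffusionPriorTrainer

prior_network = DiffusionPriorNetwork(
    dim=512,
    depth=6,
    dim_head=64,
    heads=8
)

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,  # the same OpenAIClipAdapter as above
    timesteps=1000,
    cond_drop_prob=0.2
).cuda()

prior_trainer = DiffusionPriorTrainer(
    diffusion_prior,
    lr=3e-4,
    wd=1e-2,
    ema_beta=0.99,
    ema_update_after_step=1000,
    ema_update_every=10,
)

# The prior learns to predict CLIP image embeddings from tokenized text,
# so the loop looks just like the decoder loop above.
for epoch in range(num_epochs):
    for batch_idx, (images, texts) in enumerate(dataloader):
        loss = prior_trainer(t.tokenize(texts).cuda(), images.cuda(), max_batch_size=4)
        prior_trainer.update()
        if batch_idx % 100 == 0:
            print(f"prior epoch {epoch}, step {batch_idx}, loss {loss}")
    torch.save(prior_trainer.state_dict(), f'prior_{epoch}.pt')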
Could someone please explain exactly how to use it?
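Once a prior and a decoder are both trained, the README composes them into the DALLE2 wrapper for end-to-end text-to-image sampling. A minimal sketch, assuming diffusion_prior and decoder are the trained models from the code above (the prompt is just an example):

from dalle2_pytorch import DALLE2
from torchvision.utils import save_image

dalle2 = DALLE2(
    prior=diffusion_prior,
    decoder=decoder
)

# Generate images from text; cond_scale is the classifier-free guidance scale
images = dalle2(['a sample caption'], cond_scale=2.)
save_image(images, 'generated.png')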