lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

NameError: name 'text_mask' is not defined when running default example (m1, cpu) #28

Closed lobziq closed 2 years ago

lobziq commented 2 years ago

I took the example code and modified it to run on the CPU instead of CUDA (M1 Mac):

import torch
from dalle2_pytorch import DALLE2, DiffusionPriorNetwork, DiffusionPrior, Unet, Decoder, CLIP

clip = CLIP(
    dim_text=512,
    dim_image=512,
    dim_latent=512,
    num_text_tokens=49408,
    text_enc_depth=6,
    text_seq_len=256,
    text_heads=8,
    visual_enc_depth=6,
    visual_image_size=256,
    visual_patch_size=32,
    visual_heads=8
).cpu()

# mock data

text = torch.randint(0, 49408, (4, 256)).cpu()
images = torch.randn(4, 3, 256, 256).cpu()

# train

loss = clip(
    text,
    images,
    return_loss=True
)

loss.backward()

# do above for many steps ...

# prior networks (with transformer)

prior_network = DiffusionPriorNetwork(
    dim=512,
    depth=6,
    dim_head=64,
    heads=8
).cpu()

diffusion_prior = DiffusionPrior(
    net=prior_network,
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2
).cpu()

loss = diffusion_prior(text, images)
loss.backward()

# do above for many steps ...

# decoder (with unet)

unet1 = Unet(
    dim=128,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8)
).cpu()

unet2 = Unet(
    dim=16,
    image_embed_dim=512,
    cond_dim=128,
    channels=3,
    dim_mults=(1, 2, 4, 8, 16)
).cpu()

decoder = Decoder(
    unet=(unet1, unet2),
    image_sizes=(128, 256),
    clip=clip,
    timesteps=100,
    cond_drop_prob=0.2,
    condition_on_text_encodings=False  # set this to True if you wish to condition on text during training and sampling
).cpu()

for unet_number in (1, 2):
    # this can optionally be decoder(images, text) if you wish to condition on the
    # text encodings as well, though it was hinted in the paper it didn't do much
    loss = decoder(images, unet_number=unet_number)
    loss.backward()

# do above for many steps

dalle2 = DALLE2(
    prior=diffusion_prior,
    decoder=decoder
)

images = dalle2(
    ['cute puppy chasing after a squirrel'],
    cond_scale=2.)

# save your image (in this example, of size 256x256)

I expected it to run, but there is an error:

File "dalle2_pytorch.py", line 746, in sample
    text_cond = {**text_cond, 'text_encodings': text_encodings, 'mask': text_mask}
NameError: name 'text_mask' is not defined
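For readers unfamiliar with this class of error, here is a minimal pure-Python sketch of how a `NameError` like the one above can arise and one plausible way to resolve it. It is not the library's actual code: `build_cond_buggy`, `build_cond_fixed`, and the `pad_id=0` padding convention are hypothetical names and assumptions chosen for illustration only.

```python
def build_cond_buggy(text_tokens):
    # Hypothetical stand-in for a sampling step, not the library's code.
    text_cond = {"text_embed": [t * 2 for t in text_tokens]}
    # 'text_mask' is referenced here but was never assigned in any enclosing
    # scope, so Python raises NameError when this line executes.
    return {**text_cond, "text_encodings": list(text_tokens), "mask": text_mask}


def build_cond_fixed(text_tokens, pad_id=0):
    # Assumption: padding positions are marked by pad_id, so the mask is
    # True wherever a real token is present.
    text_mask = [t != pad_id for t in text_tokens]
    text_cond = {"text_embed": [t * 2 for t in text_tokens]}
    return {**text_cond, "text_encodings": list(text_tokens), "mask": text_mask}


tokens = [5, 7, 0]

try:
    build_cond_buggy(tokens)
except NameError as err:
    print("buggy version failed:", err)

print("fixed version mask:", build_cond_fixed(tokens)["mask"])
```

The point of the sketch is only that the name has to be assigned before the dictionary is built. The actual fix is whatever the maintainer's linked commit contains.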

lucidrains commented 2 years ago

@lobziq oops, fixed here https://github.com/lucidrains/DALLE2-pytorch/commit/625ce23f6b5e91b4fc75464d63ab5ee6a5b7c011 :pray: