lucidrains / DALLE2-pytorch

Implementation of DALL-E 2, OpenAI's updated text-to-image synthesis neural network, in Pytorch
MIT License

Error related to text mask #203

Closed: rom1504 closed this issue 2 years ago

rom1504 commented 2 years ago
File "/fsx/dalle2/.dalle_env_38/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 2340, in p_losses
return forward_call(*input, **kwargs)
File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 1817, in forward
model_output = unet(
File "/fsx/dalle2/.dalle_env_38/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (26) at non-singleton dimension 0
return forward_call(*input, **kwargs)
File "/fsx/dalle2/DALLE2-pytorch/dalle2_pytorch/dalle2_pytorch.py", line 1817, in forward
text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (26) at non-singleton dimension 0
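
For anyone who wants a minimal repro of the broadcast failure, the reported shapes are enough on their own (a sketch; the tensor names mirror the traceback, and the exact sizes are taken from the logs later in this thread):

```python
import torch

text_mask = torch.ones(0, 256, 1, dtype=torch.bool)      # batch dimension collapsed to 0
text_keep_mask = torch.ones(26, 1, 1, dtype=torch.bool)  # per-sample keep mask, batch 26

text_mask & text_keep_mask
# RuntimeError: The size of tensor a (0) must match the size of tensor b (26)
# at non-singleton dimension 0
```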

Recently introduced, but I'm not sure by which commit

lucidrains commented 2 years ago

@rom1504 oh oops, might be me :cry: https://github.com/lucidrains/DALLE2-pytorch/commit/e0835acca97ba75b226d6eb2023200a293929fbb may or may not fix it haha

rom1504 commented 2 years ago
Traceback (most recent call last):
  File "/fsx/dalle2/upsampler/train_decoder.py", line 610, in <module>
    main()
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 1130, in __call__
    return self.main(*args, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 1055, in main
    rv = self.invoke(ctx)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 1404, in invoke
    return ctx.invoke(self.callback, **ctx.params)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 760, in invoke
    return __callback(*args, **kwargs)
  File "/fsx/dalle2/upsampler/train_decoder.py", line 607, in main
    initialize_training(config, config_path=config_file_path)
  File "/fsx/dalle2/upsampler/train_decoder.py", line 593, in initialize_training
    train(dataloaders, decoder, accelerator,
  File "/fsx/dalle2/upsampler/train_decoder.py", line 362, in train
    loss = trainer.forward(img, **forward_params, unet_number=unet)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 107, in inner
    out = fn(model, *args, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 721, in forward
    loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/accelerate/utils/operations.py", line 487, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/amp/autocast_mode.py", line 12, in decorate_autocast
    return func(*args, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2543, in forward
    losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2358, in p_losses
    model_output = unet(
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1835, in forward
    text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (60) at non-singleton dimension 0

didn't fix it, hmm

lucidrains commented 2 years ago

https://github.com/lucidrains/DALLE2-pytorch/commit/1ec4dbe64f5d462c5607465b1470d6b3c9fef311 that's weird, it seems to suggest it is receiving a text mask of 0 elements

rom1504 commented 2 years ago
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2362, in p_losses
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2362, in p_losses
        model_output = unet(model_output = unet(

  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    model_output = unet(
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    model_output = unet(
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)    
return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
    return forward_call(*input, **kwargs)    
return forward_call(*input, **kwargs)    
    return forward_call(*input, **kwargs)  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward

  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
        assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'

    assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
    AssertionErrorassert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'AssertionError:     
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]): 
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'AssertionError
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]): 
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])AssertionError
: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
AssertionError: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
        assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'

AssertionError: AssertionErrortext_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]): 
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
    assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
AssertionError: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
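
For readers following along: in the unet's classifier-free guidance path, a keep mask is drawn per sample and then combined with the per-token text mask, so the two have to agree on the batch dimension. A rough sketch of the intended shapes (not the library's exact code; names follow the assert, and cond_drop_prob is an assumed parameter):

```python
import torch

batch, seq_len, cond_drop_prob = 60, 77, 0.2

# classifier-free guidance: randomly drop text conditioning for part of the batch
text_keep_mask = (torch.rand(batch) > cond_drop_prob)[:, None, None]  # (60, 1, 1)

# per-token text mask, one row per sample in the batch
text_mask = torch.ones(batch, seq_len, dtype=torch.bool)[:, :, None]  # (60, 77, 1)

combined = text_mask & text_keep_mask  # broadcasts to (60, 77, 1) when batch dims agree
```

In the failing run, text_mask instead arrives with a batch dimension of 0 (torch.Size([0, 256, 1])), so nothing can line up.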
Veldrovive commented 2 years ago

> @rom1504 so i'm looking through the train_decoder.py and i don't think the text mask is passed in anywhere

I was also looking at that. A mask isn't passed in because a mask isn't saved with the embeddings.

But in any case it should not be an issue here, because the mask is being generated by clip here: https://github.com/lucidrains/DALLE2-pytorch/blob/1ec4dbe64f5d462c5607465b1470d6b3c9fef311/dalle2_pytorch/dalle2_pytorch.py#L2524
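
For context, that mask is typically just derived from the token IDs, along these lines (a sketch; pad id 0 matches CLIP's tokenizer but is an assumption here):

```python
import torch

def text_mask_from_tokens(text_tokens, pad_id=0):
    # True at real token positions, False at padding positions
    return text_tokens != pad_id

tokens = torch.tensor([[49406, 320, 1125, 49407, 0, 0]])  # CLIP-style: BOS ... EOS pad pad
print(text_mask_from_tokens(tokens))
# tensor([[ True,  True,  True,  True, False, False]])
```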

rom1504 commented 2 years ago

so that's here https://github.com/lucidrains/DALLE2-pytorch/blob/1f1557c614700cf38e6666a984083e45865ae8e8/dalle2_pytorch/dalle2_pytorch.py#L217

does that look right?

lucidrains commented 2 years ago

@Veldrovive i think going forward i'm just going to remove text masks altogether

it is too confusing for non-researchers to figure out haha (and it is confusing even for researchers too 😆)

lucidrains commented 2 years ago

https://github.com/lucidrains/DALLE2-pytorch/commit/bb3ff0ac679f4b7540a5611382af5dbaba1f5139

lucidrains commented 2 years ago

let me remove text masks altogether now and do a minor version update

lucidrains commented 2 years ago

ok done, no masks no worries 😄

rom1504 commented 2 years ago
  File "/fsx/dalle2/upsampler/train_decoder.py", line 362, in train
    loss = trainer.forward(img, **forward_params, unet_number=unet)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 107, in inner
    out = fn(model, *args, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 721, in forward
    loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/accelerate/utils/operations.py", line 487, in __call__
    return convert_to_fp32(self.model_forward(*args, **kwargs))
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/amp/autocast_mode.py", line 12, in decorate_autocast
    return func(*args, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
    output = self._run_ddp_forward(*inputs, **kwargs)
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
    return module_to_run(*inputs[0], **kwargs[0])
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2537, in forward
    losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2356, in p_losses
    model_output = unet(
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1833, in forward
    assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
AssertionError: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]). text encoding is of shape torch.Size([0, 77, 768])

rom1504 commented 2 years ago

haha sorry too tired to debug myself today, I'll take a longer look tomorrow

rom1504 commented 2 years ago

feels like some confusion in dimensions happening somewhere

rom1504 commented 2 years ago

256 looks like the image size, 60 the batch size, 77 the clip token length and 768 the clip embedding dim

lucidrains commented 2 years ago

ahh, for some reason, the text encoding being passed in has dimensions

text encoding is of shape torch.Size([0, 77, 768])

batch size of 0

lucidrains commented 2 years ago

let me add a few more asserts

lucidrains commented 2 years ago

so the error has to do with the text encodings being passed in from somewhere

lucidrains commented 2 years ago

well, regardless, getting rid of the masks was on my agenda 😆

lucidrains commented 2 years ago

yea, besides adding the asserts, i can't really help with this one unless it is reproducible

rom1504 commented 2 years ago
    model_output = unet(
  File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1815, in forward
    assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
AssertionError: the text encodings being passed into the unet does not have the proper batch size - text encoding shape torch.Size([0, 77, 768]) - required batch size is 60

yeah I can probably add a test to reproduce it

lucidrains commented 2 years ago

@rom1504 all of the scripts i have in the readme are running fine

in what context is the decoder being called?

Veldrovive commented 2 years ago

I can also debug this if @rom1504 posts his config. It's probable that this is a dataloading or data-wrangling issue.

rom1504 commented 2 years ago

I'm running train decoder with this config https://wandb.ai/rom1504/dalle2_train_decoder/runs/sm4jhcai/files/decoder_config.json

and I think I might have figured it out: I had "cond_on_text_encodings": false for the first unet (the one that is not getting trained); trying with it set to true

Veldrovive commented 2 years ago

Ah, if you are on main I don't think training more than one unet will work.

rom1504 commented 2 years ago

nope, still getting AssertionError: the text encodings being passed into the unet does not have the proper batch size - text encoding shape torch.Size([0, 77, 768]) - required batch size is 60

rom1504 commented 2 years ago

I'm trying to train only one unet, the upsampler

rom1504 commented 2 years ago
    "unet_training_mask": [false, true]

this should do it, right?

Veldrovive commented 2 years ago

I'm not even sure if that feature is on main yet. I would have to check.

Veldrovive commented 2 years ago

It might get ignored or misused by the script at the moment.

Veldrovive commented 2 years ago

And sampling will 100% not work because I had to make some pretty substantial changes on the upsampling branch to get that working.

Veldrovive commented 2 years ago

I can bring that fork up to date with main if you want to try some fixes for upsampling.

rom1504 commented 2 years ago

yeah that'd be great

nousr commented 2 years ago

@lucidrains I was getting NaNs on the prior's .forward when I rebased to main.

Reverting to the following commit's lines fixed the issue for me, just thought i'd throw it here

https://github.com/lucidrains/DALLE2-pytorch/blob/3dae43fa0ef2c5a29df9c275b1ebeb1a65c3bac5/dalle2_pytorch/dalle2_pytorch.py#L310-L318

https://github.com/lucidrains/DALLE2-pytorch/blob/3dae43fa0ef2c5a29df9c275b1ebeb1a65c3bac5/dalle2_pytorch/dalle2_pytorch.py#L866-L867

https://github.com/lucidrains/DALLE2-pytorch/blob/3dae43fa0ef2c5a29df9c275b1ebeb1a65c3bac5/dalle2_pytorch/dalle2_pytorch.py#L1177-L1178

lucidrains commented 2 years ago

@nousr ohh that's odd, i thought you weren't even using mask in your code

nousr commented 2 years ago

@lucidrains I don't specify one, but it was getting calculated on the fly by embed_text, if i understand correctly. I'm still trying to isolate exactly what was causing it--but I got it working again, so I just wanted to document it before I got too deep in the sauce.

(this may be a tangential issue)

lucidrains commented 2 years ago

@nousr oh ok! yea i just brought back the old behavior as the default https://github.com/lucidrains/DALLE2-pytorch/commit/775abc4df655c2945987274a0ab5a19a9cc4d45a you can switch it off for the "correct" behavior

although in the DL field, what empirically works is "correct" lol

lucidrains commented 2 years ago

@nousr so you mean the main branch is working? or reverting it to the old commit is working?

lucidrains commented 2 years ago

@nousr ohh, so the clip is instantiated in the training code, and embedding the text at runtime? i thought everything was being pre-encoded

nousr commented 2 years ago

> @nousr so you mean the main branch is working? or reverting it to the old commit is working?

old commit

> @nousr ohh, so the clip is instantiated in the training code, and embedding the text at runtime? i thought everything was being pre-encoded

correct. we initially used pre-encoded stuff, but eventually went to feeding pre-computed image embeddings while embedding text at runtime

lucidrains commented 2 years ago

@nousr i think we should just follow the new masking rule, which is that any text encoding must be padded with zeros during the dataloading collation
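
A minimal sketch of a collate function following that rule (the 77-token length and the helper name are assumptions based on the CLIP shapes in this thread):

```python
import torch

def collate_text_encodings(encodings, max_len=77):
    # zero-pad each (seq, dim) text encoding to max_len and stack into a batch;
    # downstream code can then recover the mask as "any non-zero feature"
    dim = encodings[0].shape[-1]
    out = torch.zeros(len(encodings), max_len, dim)
    for i, enc in enumerate(encodings):
        n = min(enc.shape[0], max_len)
        out[i, :n] = enc[:n]
    return out
```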

lucidrains commented 2 years ago

@nousr ohh ok, that sounds fine, if you are still using clip to embed the text

i'm masking out any padding here https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L324
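
presumably recovering the mask from the zero padding, roughly like so (a sketch, not the exact line at that link):

```python
import torch

def mask_from_encodings(text_encodings):
    # a position counts as real if any feature in its encoding vector is non-zero
    return text_encodings.abs().sum(dim=-1) > 0  # (batch, seq) boolean mask
```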

nousr commented 2 years ago

> @nousr i think we should just follow the new masking rule, which is that any text encoding must be padded with zeros during the dataloading collation

sure, I don't really see any problem with that. I'm more just trying to get to the bottom of what's causing the NaN values...

I'm a few minutes away from just running this on my local machine and stepping through with a debugger to see where the tensor is becoming NaN (subsequently heating the apartment)

lucidrains commented 2 years ago

@nousr yea, it could just be usual transformers instability

lucidrains commented 2 years ago

@nousr decided to remove masking altogether from the causal transformer too, using a strategy from the original dalle1 paper

nousr commented 2 years ago

> @nousr yea, it could just be usual transformers instability

maybe, though it's more that all i get is NaN loss--it's not just a few

lucidrains commented 2 years ago

@nousr hmm, maybe it has to do with the fact that there are padding tokens detected at the very beginning of the text encodings (which should not be there but who knows). that would result in a row in the attention matrix that is completely masked out

i think the latest change in 0.23.0 should address that
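
(for anyone wondering why a fully masked row yields NaNs: softmax over a row that is entirely -inf is 0/0. A two-line demonstration:)

```python
import torch

scores = torch.full((1, 4), float('-inf'))  # every key masked out for this query
print(torch.softmax(scores, dim=-1))        # tensor([[nan, nan, nan, nan]])
```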

lucidrains commented 2 years ago

@nousr threw in yet another stability measure https://github.com/lucidrains/DALLE2-pytorch/commit/349aaca56fe66a6f0fc6720c91ee5d7fa1e36f93

worst comes to worst, we can always rollback

thanks for letting me know about the dreaded NaNs!

nousr commented 2 years ago

awesome! I'll give it a shot tomorrow, I'm gonna sign off for the afternoon.

rom1504 commented 2 years ago

this is fixed now