Closed rom1504 closed 2 years ago
@rom1504 oh oops, might be me :cry: https://github.com/lucidrains/DALLE2-pytorch/commit/e0835acca97ba75b226d6eb2023200a293929fbb may or may not fix it haha
Traceback (most recent call last):
File "/fsx/dalle2/upsampler/train_decoder.py", line 610, in <module>
main()
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 1130, in __call__
return self.main(*args, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 1055, in main
rv = self.invoke(ctx)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 1404, in invoke
return ctx.invoke(self.callback, **ctx.params)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/click/core.py", line 760, in invoke
return __callback(*args, **kwargs)
File "/fsx/dalle2/upsampler/train_decoder.py", line 607, in main
initialize_training(config, config_path=config_file_path)
File "/fsx/dalle2/upsampler/train_decoder.py", line 593, in initialize_training
train(dataloaders, decoder, accelerator,
File "/fsx/dalle2/upsampler/train_decoder.py", line 362, in train
loss = trainer.forward(img, **forward_params, unet_number=unet)
File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 107, in inner
out = fn(model, *args, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 721, in forward
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/accelerate/utils/operations.py", line 487, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/amp/autocast_mode.py", line 12, in decorate_autocast
return func(*args, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2543, in forward
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, text_mask = text_mask, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2358, in p_losses
model_output = unet(
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1835, in forward
text_keep_mask = text_mask & text_keep_mask
RuntimeError: The size of tensor a (0) must match the size of tensor b (60) at non-singleton dimension 0
didn't fix hmm
https://github.com/lucidrains/DALLE2-pytorch/commit/1ec4dbe64f5d462c5607465b1470d6b3c9fef311 that's weird, it seems to suggest it is receiving a text mask of 0 elements
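For context, the failing line is a plain boolean AND between the two masks, and a zero-sized batch dimension cannot broadcast against a batch of 60. A minimal standalone sketch with plain tensors (shapes copied from the error above, nothing from the actual dataloader):

```python
import torch

# shapes taken from the error messages in this thread
text_mask = torch.zeros(0, 256, 1, dtype=torch.bool)      # empty batch of text masks
text_keep_mask = torch.zeros(60, 1, 1, dtype=torch.bool)  # real batch of 60

# mirrors `text_keep_mask = text_mask & text_keep_mask` in the unet forward:
# RuntimeError: The size of tensor a (0) must match the size of tensor b (60)
# at non-singleton dimension 0
text_keep_mask = text_mask & text_keep_mask
```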
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2362, in p_losses
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2362, in p_losses
model_output = unet(model_output = unet(
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
model_output = unet(
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
model_output = unet(
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
return forward_call(*input, **kwargs)
return forward_call(*input, **kwargs)
return forward_call(*input, **kwargs) File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1838, in forward
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
AssertionErrorassert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'AssertionError:
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]):
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'AssertionError
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]):
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])AssertionError
: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
AssertionError: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
AssertionError: AssertionErrortext_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]):
text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}'
AssertionError: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1])
@rom1504 so i'm looking through the train_decoder.py and i don't think the text mask is passed in anywhere
I was also looking at that. A mask isn't passed in because a mask isn't saved with the embeddings.
But that should not be an issue in this case, because the mask is being generated by clip here: https://github.com/lucidrains/DALLE2-pytorch/blob/1ec4dbe64f5d462c5607465b1470d6b3c9fef311/dalle2_pytorch/dalle2_pytorch.py#L2524
so that's here https://github.com/lucidrains/DALLE2-pytorch/blob/1f1557c614700cf38e6666a984083e45865ae8e8/dalle2_pytorch/dalle2_pytorch.py#L217
does that look right?
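For illustration, if the mask really is built straight from the token ids (zero meaning padding), as the linked lines suggest, then an empty batch of tokens would also produce an empty mask. A minimal sketch with made-up token ids, not the library code:

```python
import torch

# hypothetical batch of 2 tokenized captions, 0 used as the padding id
text_tokens = torch.tensor([
    [49406, 320, 1125, 49407, 0, 0],
    [49406, 320, 2368, 269, 49407, 0],
])

# a position counts as real text wherever the token id is non-zero
text_mask = text_tokens != 0          # shape (2, 6), dtype bool

# an empty token batch yields an empty mask, which is what the traceback shows
empty_mask = torch.empty(0, 6, dtype=torch.long) != 0
print(text_mask.shape, empty_mask.shape)   # torch.Size([2, 6]) torch.Size([0, 6])
```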
@Veldrovive i think going forward i'm just going to remove text masks altogether
it is too confusing for non-researchers to figure out haha (and it is confusing even for researchers too)
let me remove text masks altogether now and do a minor version update
ok done, no masks no worries
File "/fsx/dalle2/upsampler/train_decoder.py", line 362, in train
loss = trainer.forward(img, **forward_params, unet_number=unet)
File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 107, in inner
out = fn(model, *args, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/trainer.py", line 721, in forward
loss = self.decoder(*chunked_args, unet_number = unet_number, **chunked_kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/accelerate/utils/operations.py", line 487, in __call__
return convert_to_fp32(self.model_forward(*args, **kwargs))
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/amp/autocast_mode.py", line 12, in decorate_autocast
return func(*args, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1008, in forward
output = self._run_ddp_forward(*inputs, **kwargs)
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/parallel/distributed.py", line 969, in _run_ddp_forward
return module_to_run(*inputs[0], **kwargs[0])
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2537, in forward
losses = self.p_losses(unet, image, times, image_embed = image_embed, text_encodings = text_encodings, lowres_cond_img = lowres_cond_img, predict_x_start = predict_x_start, learned_variance = learned_variance, is_latent_diffusion = is_latent_diffusion, noise_scheduler = noise_scheduler)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 2356, in p_losses
model_output = unet(
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1833, in forward
assert text_mask.shape[0] == text_keep_mask.shape[0], f'text_mask has shape of {text_mask.shape} while text_keep_mask has shape {text_keep_mask.shape}. text encoding is of shape {text_encodings.shape}'
AssertionError: text_mask has shape of torch.Size([0, 256, 1]) while text_keep_mask has shape torch.Size([60, 1, 1]). text encoding is of shape torch.Size([0, 77, 768])
haha sorry too tired to debug myself today, I'll take a longer look tomorrow
feels like some confusion in dimensions happening somewhere
256 looks like the image size, 60 the batch size, 77 the clip token length and 768 the clip embedding dim
ahh, for some reason, the text encoding being passed in has dimensions
text encoding is of shape torch.Size([0, 77, 768])
batch size of 0
let me add a few more asserts
so the error has to do with the text encodings being passed in from somewhere
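One way to surface this earlier (my sketch only; the helper name and call site are hypothetical) is to validate the batch right where it leaves the dataloader instead of deep inside the unet forward:

```python
# hypothetical guard for the training loop; `img` and `text_encodings` are the
# tensors handed to trainer.forward in train_decoder.py
def check_batch(img, text_encodings=None):
    assert img.shape[0] > 0, 'empty image batch'
    if text_encodings is not None:
        assert text_encodings.shape[0] == img.shape[0], (
            f'text encodings batch {text_encodings.shape[0]} does not match '
            f'image batch {img.shape[0]}'
        )
```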
well, regardless, getting rid of the masks was on my agenda
yea, besides adding the asserts, i can't really help with this one unless it is reproducible
(again interleaved across workers; deduplicated:)
model_output = unet(
File "/fsx/dalle2/.upsampler/lib64/python3.8/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
return forward_call(*input, **kwargs)
File "/fsx/dalle2/upsampler/dalle2_pytorch/dalle2_pytorch.py", line 1815, in forward
assert text_encodings.shape[0] == batch_size, f'the text encodings being passed into the unet does not have the proper batch size - text encoding shape {text_encodings.shape} - required batch size is {batch_size}'
AssertionError: the text encodings being passed into the unet does not have the proper batch size - text encoding shape torch.Size([0, 77, 768]) - required batch size is 60
yeah I can probably add a test to reproduce it
@rom1504 all of the scripts i have in the readme are running fine
in what context is the decoder being called?
I can also debug this if @rom1504 posts his config. It's probable that this is a dataloading or data wrangling issue.
I'm running train decoder with this config https://wandb.ai/rom1504/dalle2_train_decoder/runs/sm4jhcai/files/decoder_config.json
and I think I might have figured it out, I had "cond_on_text_encodings": false for the first unet (the one not getting trained); trying with that set to true
Ah, if you are on main I don't think training more than one unet will work.
nope still getting AssertionError: the text encodings being passed into the unet does not have the proper batch size - text encoding shape torch.Size([0, 77, 768]) - required batch size is 60
I'm trying to train only one unet, the upsampler
"unet_training_mask": [false, true]
this should do it right?
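For reference, the two settings under discussion would sit in decoder_config.json roughly like the fragment below (written as a Python dict; the surrounding structure is my guess, only "cond_on_text_encodings" and "unet_training_mask" are quoted from this thread):

```python
# partial, hypothetical fragment of decoder_config.json, shown as a Python dict
decoder_config_fragment = {
    "decoder": {
        "unets": [
            {"cond_on_text_encodings": True},  # first (base) unet, frozen here
            {"cond_on_text_encodings": True},  # second unet: the upsampler being trained
        ],
    },
    "train": {
        "unet_training_mask": [False, True],   # train only the second unet
    },
}
```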
I'm not even sure if that feature is on main yet. I would have to check.
It might get ignored or misused by the script at the moment.
And sampling will 100% not work because I had to make some pretty substantial changes on the upsampling branch to get that working.
I can bring that fork up to date with main if you want to try some fixes for upsampling.
yeah that'd be great
@lucidrains I was getting NaNs on the prior's .forward when I rebased to main.
Reverting to the following commit's lines fixed the issue for me, just thought i'd throw it here
@nousr ohh that's odd, i thought you weren't even using a mask in your code
@lucidrains I don't specify one, but it was getting calculated on the fly by the embed_text
if i understand correctly. I'm still trying to isolate exactly what was causing it--but I got it working again, so I just wanted to document it before I got too deep in the sauce.
(this may be a tangential issue)
@nousr oh ok! yea i just brought back the old behavior as the default https://github.com/lucidrains/DALLE2-pytorch/commit/775abc4df655c2945987274a0ab5a19a9cc4d45a you can switch it off for the "correct" behavior
although in the DL field, what empirically works is "correct" lol
@nousr so you mean the main branch is working? or reverting it to the old commit is working?
@nousr ohh, so the clip is instantiated in the training code, and embedding the text at runtime? i thought everything was being pre-encoded
> @nousr so you mean the main branch is working? or reverting it to the old commit is working?
old commit
> @nousr ohh, so the clip is instantiated in the training code, and embedding the text at runtime? i thought everything was being pre-encoded
correct. we initially used pre-encoded stuff, but eventually went to feeding pre-computed image embeddings while embedding text at runtime
@nousr i think we should just follow the new masking rule, which is that any text encoding must be padded with zeros during the dataloading collation
@nousr ohh ok, that sounds fine, if you are still using clip to embed the text
i'm masking out any padding here https://github.com/lucidrains/DALLE2-pytorch/blob/main/dalle2_pytorch/dalle2_pytorch.py#L324
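To illustrate the proposed rule (my sketch, not the linked code): if the collation pads text encodings with zeros, the mask can be recovered inside the model by checking which sequence positions are entirely zero.

```python
import torch

# hypothetical batch of 2 text encodings, padded to seq_len 6 with feature dim 4
text_encodings = torch.randn(2, 6, 4)
text_encodings[0, 4:] = 0.   # sample 0: 4 real tokens, 2 padding positions
text_encodings[1, 2:] = 0.   # sample 1: 2 real tokens, 4 padding positions

# a position is real text if any of its features is non-zero
text_mask = (text_encodings != 0).any(dim=-1)   # shape (2, 6), dtype bool
print(text_mask)
```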
> @nousr i think we should just follow the new masking rule, which is that any text encoding must be padded with zeros during the dataloading collation
sure, I don't really see any problem with that. I'm more just trying to get to the bottom of what's causing the NaN values...
I'm a few minutes away from just running this on my local machine and stepping through with a debugger to see where the tensor is becoming NaN (subsequently heating the apartment)
@nousr yea, it could just be usual transformers instability
@nousr decided to remove masking altogether from the causal transformer too, using a strategy from the original dalle1 paper
> @nousr yea, it could just be usual transformers instability
maybe, it's more like all i get is NaN loss though, it's not just a few
@nousr hmm, maybe it has to do with the fact that there are padding tokens detected at the very beginning of the text encodings (which should not be there but who knows). that would result in a row in the attention matrix that is completely masked out
i think the latest change in 0.23.0 should address that
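For what it's worth, a fully masked-out attention row is a classic NaN source: every logit becomes -inf, so the softmax divides zero by zero. A standalone illustration, not the library's attention code:

```python
import torch
import torch.nn.functional as F

# one query attending over 4 keys that are all masked out
attn_logits = torch.randn(1, 4)
key_mask = torch.zeros(1, 4, dtype=torch.bool)   # nothing left to attend to

attn_logits = attn_logits.masked_fill(~key_mask, float('-inf'))
attn = F.softmax(attn_logits, dim=-1)
print(attn)   # tensor([[nan, nan, nan, nan]]) -> NaNs propagate to the loss
```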
@nousr threw in yet another stability measure https://github.com/lucidrains/DALLE2-pytorch/commit/349aaca56fe66a6f0fc6720c91ee5d7fa1e36f93
worst comes to worst, we can always rollback
thanks for letting me know about the dreaded NaNs!
awesome! i'll give it a shot tomorrow, i'm gonna sign off for the afternoon.
this is fixed now
Recently introduced but not sure which commit