deep-floyd / IF


Very limited number of tokens for prompting (77), why? #108

Open phalexo opened 1 year ago

phalexo commented 1 year ago

Token indices sequence length is longer than the specified maximum sequence length for this model (102 > 77). Running this sequence through the model will result in indexing errors.

The following part of your input was truncated because CLIP can only handle sequences up to 77 tokens: [': 1 5 0 cfg scale : 7. 5 sampler : k euler :: very detailed, rim - light']
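The warning above comes from the tokenizer dropping everything past its fixed context length. A minimal sketch of that behavior (pure Python, no real vocabulary — the token ids are placeholders, not CLIP's):

```python
MAX_LENGTH = 77  # the context length the warning refers to

def truncate(token_ids, max_length=MAX_LENGTH):
    """Keep at most max_length tokens; report what was cut."""
    kept, dropped = token_ids[:max_length], token_ids[max_length:]
    if dropped:
        print(f"sequence length {len(token_ids)} > {max_length}; "
              f"truncating {len(dropped)} tokens")
    return kept

ids = list(range(102))   # a 102-token prompt, as in the warning
kept = truncate(ids)     # prints a truncation notice, keeps 77 ids
```

Everything after token 77 (here, the trailing Stable Diffusion parameter residue in the prompt) is silently ignored by the model.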

tildebyte commented 1 year ago

AFAICT, 77 is baked in — not by T5 itself (HuggingFace's T5 has no such limit), but as the hard-coded max_length in 't5.py'.

Bumping that number up (by manually editing 't5.py'; I tried '144') simply results in:

File /usr/local/lib/python3.10/dist-packages/deepfloyd_if/modules/base.py:166, in IFBaseModule.embeddings_to_image(self, t5_embs, low_res, style_t5_embs, positive_t5_embs, negative_t5_embs, batch_repeat, dynamic_thresholding_p, sample_loop, sample_timestep_respacing, dynamic_thresholding_c, guidance_scale, aug_level, positive_mixer, blur_sigma, img_size, img_scale, aspect_ratio, progress, seed, sample_fn, support_noise, support_noise_less_qsample_steps, inpainting_mask, **kwargs)
    161 else:
    162     list_text_emb.append(
    163         self.zero_emb.unsqueeze(0).repeat(batch_size, 1, 1).to(self.device, dtype=self.model.dtype))
    165 model_kwargs = dict(
--> 166     text_emb=torch.cat(list_text_emb, dim=0).to(self.device, dtype=self.model.dtype),
    167     timestep_text_emb=timestep_text_emb,
    168     use_cache=True,
    169 )
    170 if low_res is not None:
    171     if blur_sigma is not None:

RuntimeError: Sizes of tensors must match except in dimension 0. Expected size 144 but got size 77 for tensor number 1 in the list.

Tracing the code above shows that, in this instance (because of the prompt and parameters I'm using), list_text_emb holds the prompt embedding plus the model's zero embedding. So something is still fixed at the old size: we set the prompt's sequence length to 144 via max_length, but torch.cat() received a 77-wide tensor for "tensor number 1 in the list" — presumably self.zero_emb, which is a stored tensor and isn't resized by editing max_length.
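The RuntimeError is just torch.cat's shape check: every tensor must match in all dimensions except the concatenation axis. A pure-Python sketch of that check, with hypothetical embedding shapes mirroring the error message:

```python
def cat_dim0(shapes):
    """Return the shape torch.cat(dim=0) would produce, or raise
    the same way it does when non-concat dimensions disagree."""
    ref = shapes[0]
    for i, s in enumerate(shapes[1:], start=1):
        if s[1:] != ref[1:]:
            raise RuntimeError(
                f"Sizes of tensors must match except in dimension 0. "
                f"Expected size {ref[1]} but got size {s[1]} "
                f"for tensor number {i} in the list.")
    return (sum(s[0] for s in shapes),) + ref[1:]

prompt_emb = (1, 144, 4096)  # prompt re-tokenized with max_length=144
zero_emb   = (1, 77, 4096)   # zero embedding, still at the old length

try:
    cat_dim0([prompt_emb, zero_emb])
except RuntimeError as e:
    print(e)  # same shape mismatch as in the traceback above
```

Which is consistent with the diagnosis: bumping max_length only resizes the prompt embedding, so the two tensors no longer agree in dimension 1.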

OTOH, I'm probably missing something obvious 😁