THUDM / Inf-DiT

Official implementation of Inf-DiT: Upsampling Any-Resolution Image with Memory-Efficient Diffusion Transformer
Apache License 2.0

Input image size #16

Open elismasilva opened 1 week ago

elismasilva commented 1 week ago

Does my input image always need to be 512x512? If I use an image of size 1080x780, for example, I get a shape error.

yzy-thu commented 6 days ago

What error did you get, and how is inference_type set? I don't think the input is limited to 512.

elismasilva commented 5 days ago

> What error did you get, and how is inference_type set? I don't think the input is limited to 512.

I am using these args:

"--input-type", "cli",
"--input-path",".\\test\\band.png",
"--inference_type","full",
"--block_batch", "4", 
"--experiment-name", "generate",
"--mode", "inference",
"--inference-batch-size", "1",
"--image-size", "512",
"--input-time", "adaln",
"--nogate",
"--no-crossmask",
"--bf16",
"--num-layers", "28",
"--vocab-size", "1",
"--hidden-size", "1280",
"--num-attention-heads", "16",
"--hidden-dropout", "0.",
"--attention-dropout", "0.",
"--in-channels", "6",
"--out-channels", "3",
"--cross-attn-hidden-size","640",
"--patch-size","4",
"--config-path","configs/text2image-sr.yaml",
"--max-sequence-length", "256",
"--layernorm-epsilon", "1e-6",
"--layernorm-order", "pre",
"--model-parallel-size", "1",
"--tokenizer-type", "fake",
"--random-position",
"--qk-ln",
"--out-dir", "samples",
"--network", "ckpt/mp_rank_00_model_states.pt",
"--round", "32",
"--init_noise",
"--image-condition",
"--vector-dim", "768",
"--re-position", 
"--cross-lr",
"--seed", "1",
"--infer_sr_scale", "2",

My image is 1080x780, and I got this error:

Exception has occurred: EinopsError
 Error while processing rearrange-reduction pattern "b (x l) (y w) h d -> b x y (l w) h d".
 Input tensor shape: torch.Size([1, 384, 528, 16, 80]). Additional info: {'l': 32, 'w': 32}.
 Shape mismatch, can't divide axis of length 528 in chunks of 32
einops.EinopsError: Shape mismatch, can't divide axis of length 528 in chunks of 32

During handling of the above exception, another exception occurred:

  File "F:\Projetos\Inf-DiT\dit\model.py", line 231, in transform
    x = rearrange(x, 'b (x l) (y w) h d -> b x y (l w) h d', l=block_size, w=block_size)
  File "F:\Projetos\Inf-DiT\dit\model.py", line 270, in attention_forward
    query_layer = transform(query_layer)
  File "F:\Projetos\Inf-DiT\dit\model.py", line 414, in layer_forward
    attention_output = layer.attention(attention_input, mask, do_concat=do_concat, **kwargs)
  File "F:\Projetos\Inf-DiT\dit\model.py", line 610, in model_forward
    return super().forward(*args, **kwargs)
  File "F:\Projetos\Inf-DiT\dit\model.py", line 744, in precond_forward
    output, *output_per_layers = self.model_forward(*args, hw=[h//self.patch_size, w//self.patch_size], rope_position_ids=rope_position_ids, lr_imgs=lr_imgs, **kwargs)
  File "F:\Projetos\Inf-DiT\dit\model.py", line 859, in <lambda>
    denoiser = lambda images, sigmas, rope_position_ids, cond, sample_step: self.precond_forward(images=images,
  File "F:\Projetos\Inf-DiT\dit\sampling\samplers.py", line 60, in denoise
    denoised = denoiser(images, sigmas, rope_position_ids, cond, sample_step)
  File "F:\Projetos\Inf-DiT\dit\sampling\samplers.py", line 103, in sampler_step
    denoised = self.denoise(x, denoiser, sigma_hat, cond, uc, rope_position_ids, sample_step)
  File "F:\Projetos\Inf-DiT\dit\sampling\samplers.py", line 209, in __call__
    images = self.sampler_step(
  File "F:\Projetos\Inf-DiT\dit\model.py", line 865, in sample
    samples = self.sampler(denoiser=denoiser, x=None, cond=cond, uc=uncond, num_steps=num_steps, rope_position_ids=rope_position_ids, init_noise=init_noise)
  File "F:\Projetos\Inf-DiT\generate_t2i_sr.py", line 199, in main
    samples = net.sample(shape=concat_lr_image.shape, images=concat_lr_image, lr_imgs=lr_image, dtype=concat_lr_image.dtype, device=device, init_noise=args.init_noise, do_concat=not args.no_concat)
  File "F:\Projetos\Inf-DiT\generate_t2i_sr.py", line 246, in <module>
    main(args)
einops.EinopsError:  Error while processing rearrange-reduction pattern "b (x l) (y w) h d -> b x y (l w) h d".
 Input tensor shape: torch.Size([1, 384, 528, 16, 80]). Additional info: {'l': 32, 'w': 32}.
 Shape mismatch, can't divide axis of length 528 in chunks of 32
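
For reference, the tensor shape in the traceback can be reproduced with simple arithmetic. The sketch below assumes (this is not confirmed from the repository code) that `--round` floors each side of the input to a multiple of its value, and that the 32 in `{'l': 32, 'w': 32}` is the attention block size in tokens. `token_grid` is a hypothetical helper, not part of Inf-DiT.

```python
# Assumption: --round floors each image side to a multiple of `round_to`.
# Assumption: attention blocks are 32 tokens per side (from the error's l=32, w=32).

def token_grid(height, width, round_to=32, sr_scale=2, patch_size=4):
    """Return the (rows, cols) token grid for an input of height x width pixels."""
    h = height // round_to * round_to      # floor to a multiple of round_to
    w = width // round_to * round_to
    # Upscale by sr_scale, then split into patches of patch_size pixels.
    return h * sr_scale // patch_size, w * sr_scale // patch_size

rows, cols = token_grid(780, 1080)
print(rows, cols)            # 384 528, the shape reported by einops
print(rows % 32, cols % 32)  # 0 16, so 528 cannot be split into blocks of 32
```

Under the same assumptions, a 709x512 input gives a 256 x 352 grid, and both values are multiples of 32, which would explain why that size works.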

If I resize the image to 709x512, it works. The width must be 512; if only the height is 512, I get the same error. I also tried changing the image_size parameter to 780, but then I got this error:

> Exception has occurred: RuntimeError
> Error(s) in loading state_dict for DiffusionEngine:
>   size mismatch for mixins.adaln_layer.rope.freqs_cos: copying a param with shape torch.Size([1024, 1024, 80]) from checkpoint, the shape in current model is torch.Size([1560, 1560, 80]).
>   size mismatch for mixins.adaln_layer.rope.freqs_sin: copying a param with shape torch.Size([1024, 1024, 80]) from checkpoint, the shape in current model is torch.Size([1560, 1560, 80]).
>   File "F:\Projetos\Inf-DiT\generate_t2i_sr.py", line 152, in main
>     net.load_state_dict(data['module'], strict=False)
>   File "F:\Projetos\Inf-DiT\generate_t2i_sr.py", line 246, in <module>
>     main(args)
> RuntimeError: Error(s) in loading state_dict for DiffusionEngine:
>   size mismatch for mixins.adaln_layer.rope.freqs_cos: copying a param with shape torch.Size([1024, 1024, 80]) from checkpoint, the shape in current model is torch.Size([1560, 1560, 80]).
>   size mismatch for mixins.adaln_layer.rope.freqs_sin: copying a param with shape torch.Size([1024, 1024, 80]) from checkpoint, the shape in current model is torch.Size([1560, 1560, 80]).

elismasilva commented 5 days ago

This is my test image (band.png).

EH-FT commented 5 days ago

Try setting round=64. If you use infer_sr_scale=2, each side of the input should be divisible by 64. We will support arbitrary-resolution input in the next version.
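
The value 64 is consistent with the arguments posted earlier in the thread: with `--patch-size 4` and blocks of 32 tokens, one block covers 128 output pixels, or 64 input pixels at scale 2. Below is a minimal sketch of that calculation, assuming an integer scale factor and a block size of 32 (read off the error message, not confirmed from the code). `required_multiple` and `snap_down` are hypothetical helpers for pre-checking or pre-resizing an image.

```python
import math

def required_multiple(sr_scale=2, patch_size=4, block_size=32):
    """Smallest input-side multiple that yields whole attention blocks.

    Assumes the token count per side is size * sr_scale / patch_size and
    that it must be divisible by block_size.
    """
    unit = block_size * patch_size          # output pixels per block side
    return unit // math.gcd(unit, sr_scale)

def snap_down(size, multiple):
    """Largest multiple of `multiple` that does not exceed `size`."""
    return size // multiple * multiple

m = required_multiple()                      # 64 for scale 2
print(m, snap_down(1080, m), snap_down(780, m))  # 64 1024 768
```

With these numbers, a 1080x780 input would be reduced to 1024x768, which corresponds to a 512 x 384 token grid at scale 2, and both sides split evenly into blocks of 32.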

elismasilva commented 4 days ago

> Try setting round=64. If you use infer_sr_scale=2, each side of the input should be divisible by 64. We will support arbitrary-resolution input in the next version.

OK, this fixed it. However, for a 1024x1024 image at a scale of 2, the process as a whole uses a lot of memory (more than 22 GB).