lucidrains / imagen-pytorch

Implementation of Imagen, Google's Text-to-Image Neural Network, in Pytorch
MIT License

CUDA out of memory with `max_batch_size=1` using unconditional image-to-image #64

Closed · sgbaird closed this 2 years ago

sgbaird commented 2 years ago

Based on the README usage instructions, except with max_batch_size=1, running on Windows:

import torch
from imagen_pytorch import Imagen, ImagenTrainer, SRUnet256, Unet

# unets for unconditional imagen

unet1 = Unet(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=3,
    layer_attns=(False, True, True),
    layer_cross_attns=(False, True, True),
    use_linear_attn=True,
)

unet2 = SRUnet256(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
)

# imagen, which contains the unets above (base unet and super-resolution ones)

imagen = Imagen(
    condition_on_text=False,  # this must be set to False for unconditional Imagen
    unets=(unet1, unet2),
    image_sizes=(64, 128),
    timesteps=1000,
)

trainer = ImagenTrainer(imagen).cuda()

# now get a ton of images and feed them through the Imagen trainer

training_images = torch.randn(4, 3, 256, 256).cuda()

# train each unet in concert, or separately (recommended) to completion

for u in (1, 2):
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
    trainer.update(unet_number=u)

# do the above for many many many many steps
# now you can sample images unconditionally from the cascading unet(s)

images = trainer.sample(batch_size=16)  # (16, 3, 128, 128)

The OOM error occurs during the SRUnet256 pass (I set a breakpoint and checked).
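A minimal sketch of checking this programmatically as well, using PyTorch's peak-memory counters (assumes the trainer and training_images from the script above):

import torch

# Reset the peak-memory counter before each unet, then report the peak after.
# Assumes `trainer` and `training_images` from the script above.
for u in (1, 2):
    torch.cuda.reset_peak_memory_stats()
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
    trainer.update(unet_number=u)
    peak_gib = torch.cuda.max_memory_allocated() / 2**30
    print(f"unet {u}: peak allocated {peak_gib:.2f} GiB")
    # unet 2 raises the OOM below before its print is reached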

Exception has occurred: RuntimeError       (note: full exception trace is shown but execution is paused at: _run_module_as_main)
CUDA out of memory. Tried to allocate 20.00 MiB (GPU 0; 6.00 GiB total capacity; 4.26 GiB already allocated; 0 bytes free; 4.31 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation.  See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\autograd\__init__.py", line 173, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\_tensor.py", line 363, in backward
    torch.autograd.backward(self, gradient, retain_graph, create_graph, inputs=inputs)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\imagen_pytorch\trainer.py", line 508, in forward
    self.scale(loss, unet_number = unet_number).backward()
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\imagen_pytorch\trainer.py", line 98, in inner
    out = fn(model, *args, **kwargs)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\site-packages\torch\nn\modules\module.py", line 1110, in _call_impl
    return forward_call(*input, **kwargs)
  File "C:\Users\sterg\Documents\GitHub\sparks-baird\xtal2png\scripts\imagen_pytorch_example.py", line 41, in <module>
    loss = trainer(training_images, unet_number=u, max_batch_size=1)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 97, in _run_module_code
    _run_code(code, mod_globals, init_globals,
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 268, in run_path
    return _run_module_code(code, init_globals, run_name,
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "C:\Users\sterg\miniconda3\envs\xtal2png-docs\Lib\runpy.py", line 197, in _run_module_as_main (Current frame)
    return _run_code(code, main_globals, None,
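Per the error message's suggestion, max_split_size_mb can be set through the PYTORCH_CUDA_ALLOC_CONF environment variable before CUDA is initialized. A minimal sketch (128 MiB is an arbitrary example value; this only mitigates fragmentation, it won't conjure memory the model genuinely needs):

import os

# Must be set before the first CUDA allocation, so set it before importing torch.
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # noqa: E402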

I'm using an NVIDIA GeForce RTX 2060:

GPU Architecture: Turing
RTX-OPS: 37T
Boost Clock: 1680 MHz
Frame Buffer: 6 GB GDDR6
Memory Speed: 14 Gbps
See also #12

sgbaird commented 2 years ago

Works fine if I drop the SRUnet256 and replace it with the regular Unet. Still playing around with the SRUnet256 parameters to see if I can get it under 6 GB, at least for prototyping. See the sketch below for what the swap looks like.
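Concretely, "replace it with the regular Unet" means something like the following for the second stage (same hyperparameters as the SRUnet256 block above; a sketch of the workaround, not a tuned config):

from imagen_pytorch import Unet

# Second stage as a plain Unet, so the small dim/dim_mults are actually respected
unet2 = Unet(
    dim=32,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 4, 8),
    layer_attns=(False, False, True),
    layer_cross_attns=(False, False, True),
)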

sgbaird commented 2 years ago

@lucidrains so SRUnet256 overrides many of the parameters? https://github.com/lucidrains/imagen-pytorch/blob/ccf848cfd1f8306e654112d61e95376180a67a90/imagen_pytorch/imagen_pytorch.py#L1419-L1431

Seems to be the case after inspecting the values at a breakpoint.
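If so, the behavior is consistent with the subclass merging its preset kwargs after the caller's, so the presets silently win. A hypothetical sketch of that pattern (not the library's verbatim source, see the link above for that; the preset values here are read off the printed architecture further down):

# Illustrative only: if the preset dict is merged last, it overrides the caller
class SRUnet256(Unet):
    def __init__(self, *args, **kwargs):
        preset = dict(
            dim=128,
            dim_mults=(1, 2, 4, 8),
            num_resnet_blocks=(2, 4, 8, 8),
            layer_attns=(False, False, False, True),
            layer_cross_attns=(False, False, False, True),
        )
        super().__init__(*args, **{**kwargs, **preset})  # preset wins on conflicts

That would explain why the construction below ignores every small value passed in: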

unet2 = SRUnet256(
    dim=4,
    dim_mults=(1, 2),
    num_resnet_blocks=(2, 2),
    layer_attns=(False, False),
    layer_cross_attns=(False, False),
    attn_heads=1,
)
print(unet2)

SRUnet256(
  (init_conv): CrossEmbedLayer(
    (convs): ModuleList(
      (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
      (1): Conv2d(3, 32, kernel_size=(7, 7), stride=(1, 1), padding=(3, 3))
      (2): Conv2d(3, 32, kernel_size=(15, 15), stride=(1, 1), padding=(7, 7))
    )
  )
  (to_time_hiddens): Sequential(
    (0): LearnedSinusoidalPosEmb()
    (1): Linear(in_features=17, out_features=512, bias=True)
    (2): SiLU()
  )
  (to_time_cond): Sequential(
    (0): Linear(in_features=512, out_features=512, bias=True)
  )
  (to_time_tokens): Sequential(
    (0): Linear(in_features=512, out_features=256, bias=True)
    (1): Rearrange('b (r d) -> b r d', r=2)
  )
  (norm_cond): LayerNorm((128,), eps=1e-05, elementwise_affine=True)
  ...

[rest of the printout elided for length: the downs/ups/mid blocks run over four resolution stages with channel widths 128 / 256 / 512 / 1024 and resnet block counts (2, 4, 8, 8), with a TransformerBlock and cross-attention only in the deepest stage, and final_conv mapping 128 -> 3 channels. In other words, the instantiated model uses the SRUnet256 defaults (dim=128, dim_mults=(1, 2, 4, 8), num_resnet_blocks=(2, 4, 8, 8), layer_attns=(False, False, False, True)) rather than the dim=4, dim_mults=(1, 2), num_resnet_blocks=(2, 2) arguments above. The paste also captured the debugger's variables pane (channels: 3, cond_on_text: True, lowres_cond: False, max_text_len: 256, ...) interleaved with the repr.]
(cross_attn): EinopsToAndFrom( (fn): CrossAttention( (norm): LayerNorm() (norm_context): Identity() (to_q): Linear(in_features=1024, out_features=512, bias=False) (to_kv): Linear(in_features=128, out_features=1024, bias=False) (to_out): Sequential( (0): Linear(in_features=512, out_features=1024, bias=False) (1): LayerNorm() ) ) ) (block1): Block( (groupnorm): GroupNorm(8, 1024, eps=1e-05, affine=True) (activation): SiLU() (project): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (block2): Block( (groupnorm): GroupNorm(8, 1024, eps=1e-05, affine=True) (activation): SiLU() (project): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (gca): Identity() (res_conv): Scale() ) mid_block2: ResnetBlock( (time_mlp): Sequential( (0): SiLU() (1): Linear(in_features=512, out_features=2048, bias=True) ) (cross_attn): EinopsToAndFrom( (fn): CrossAttention( (norm): LayerNorm() (norm_context): Identity() (to_q): Linear(in_features=1024, out_features=512, bias=False) (to_kv): Linear(in_features=128, out_features=1024, bias=False) (to_out): Sequential( (0): Linear(in_features=512, out_features=1024, bias=False) (1): LayerNorm() ) ) ) (block1): Block( (groupnorm): GroupNorm(8, 1024, eps=1e-05, affine=True) (activation): SiLU() (project): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (block2): Block( (groupnorm): GroupNorm(8, 1024, eps=1e-05, affine=True) (activation): SiLU() (project): Conv2d(1024, 1024, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)) ) (gca): Identity() (res_conv): Scale() ) TRUNCATED
samedii commented 2 years ago

I don't think you should expect to be able to run inference on 6GB, much less train something

sgbaird commented 2 years ago

@samedii

Still playing around with the SRUnet256 parameters to see if I can get it under 6 GB, at least for prototyping (something like the sketch below).
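
For reference, a minimal sketch of the kind of slimmed-down config I mean. The dim=16 and the reduced block/attention settings are my own guesses for prototyping, not tested recommendations:

from imagen_pytorch import SRUnet256

# narrower base width, fewer resnet blocks, and no attention layers,
# all to cut activation memory; the effect on sample quality is untested
unet2_small = SRUnet256(
    dim=16,
    dim_mults=(1, 2, 4),
    num_resnet_blocks=(2, 2, 4),
    layer_attns=(False, False, False),
    layer_cross_attns=(False, False, False),
)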

Will probably do production runs via Slurm submissions on my university's HPC, which will give me much more than 6 GB.

lupinetine commented 2 years ago

I'm hitting the same issue with 12 GB, and have also run into it on a 40 GB A100 that I used to verify. I have trained and inferred successfully on 8-16 GB cards on versions up until 0.7. I jumped from 0.3 to 0.7, so I'll have to backtrack to find the last version that worked on my equipment.

sgbaird commented 2 years ago

@lupinetine this is my first time using the library. If you figure out where the change occurred, would love to know! cc @lucidrains

sgbaird commented 2 years ago

Also, I realize max_batch_size=1 is a silly choice; I was just trying to see if I could get anything at all to run without the OOM error: https://github.com/lucidrains/imagen-pytorch/issues/24#issuecomment-1142306411
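
For anyone landing here: as I understand it, max_batch_size only chunks the forward/backward pass (gradient accumulation), so it lowers peak activation memory but does nothing for the memory held by the unet weights and optimizer state. The OOM message itself also suggests tuning the allocator; a cheap thing to try, though untested whether it helps in this case, is setting it before torch initializes CUDA:

import os

# the max_split_size_mb knob comes straight from the OOM message; it only
# mitigates fragmentation, it can't free memory that is genuinely in use
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "max_split_size_mb:128"

import torch  # import after setting the env var so the allocator picks it up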

Might be fixed with some of the new releases.

lucidrains commented 2 years ago

should be better now that only one unet is loaded into memory at any given time
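
ex. a minimal sketch, reusing the imagen and training_images from the original snippet (assumes the only_train_unet_number keyword is available in the version you're on):

from imagen_pytorch import ImagenTrainer

# one trainer per unet, so only that unet's optimizer state gets built;
# only_train_unet_number is assumed here, check your installed version
trainer = ImagenTrainer(imagen, only_train_unet_number=1).cuda()

for _ in range(100000):  # many many steps
    loss = trainer(training_images, unet_number=1, max_batch_size=1)
    trainer.update(unet_number=1)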

if it still OOMs, you should buy a better graphics card