andreas128 / RePaint

Official PyTorch Code and Models of "RePaint: Inpainting using Denoising Diffusion Probabilistic Models", CVPR 2022

Guided-diffusion pretrained model can't sample in RePaint #22

Open xiongsua opened 1 year ago

xiongsua commented 1 year ago

I get an error when sampling with RePaint: I trained a model with guided-diffusion and then used the EMA checkpoint (ema_pt) for sampling. Does the guided-diffusion code need to be changed, and if so, what do I need to change in it?

drakiez13 commented 1 year ago

Same problem. Do you have any solution?

andreas128 commented 1 year ago

It seems that the model architecture has changed.

Try to use the model code from your training by copying it here.

Does that work?
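A small, hedged diagnostic that can help confirm such an architecture mismatch is to diff the key sets before calling load_state_dict. In the sketch below, the checkpoint filename is a placeholder and `model` is whatever your conf builds:

```python
import torch as th

def diff_checkpoint_keys(model, ckpt_path):
    """Print where the instantiated model and the checkpoint disagree."""
    ckpt = th.load(ckpt_path, map_location="cpu")
    model_keys = set(model.state_dict().keys())
    ckpt_keys = set(ckpt.keys())
    print("in checkpoint but not in model:", sorted(ckpt_keys - model_keys))
    print("in model but not in checkpoint:", sorted(model_keys - ckpt_keys))

# e.g. diff_checkpoint_keys(model, "ema_0.9999_100000.pt")  # placeholder filename
```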

zhangbaijin commented 1 year ago

Hi, did you solve the problem?

zhangbaijin commented 1 year ago

> I get an error when sampling with RePaint: I trained a model with guided-diffusion and then used the EMA checkpoint (ema_pt) for sampling. Does the guided-diffusion code need to be changed, and if so, what do I need to change in it?

xyz-xdx commented 1 year ago

@xiongsua Hi, did you solve the problem?

fangtun commented 1 year ago

I ran into the same problem: the values in the checkpoint do not match the network:

```
RuntimeError: Error(s) in loading state_dict for UNetModel:
    Missing key(s) in state_dict: "input_blocks.4.0.in_layers.0.weight", "input_blocks.4.0.in_layers.0.bias", "input_blocks.4.0.in_layers.2.weight", "input_blocks.4.0.in_layers.2.bias", "input_blocks.4.0.emb_layers.1.weight", "input_blocks.4.0.emb_layers.1.bias", "input_blocks.4.0.out_layers.0.weight", "input_blocks.4.0.out_layers.0.bias", "input_blocks.4.0.out_layers.3.weight", "input_blocks.4.0.out_layers.3.bias", "input_blocks.5.1.norm.weight", "input_blocks.5.1.norm.bias", "input_blocks.5.1.qkv.weight", "input_blocks.5.1.qkv.bias", "input_blocks.5.1.proj_out.weight", "input_blocks.5.1.proj_out.bias", "input_blocks.6.1.norm.weight", "input_blocks.6.1.norm.bias", "input_blocks.6.1.qkv.weight", "input_blocks.6.1.qkv.bias", "input_blocks.6.1.proj_out.weight", "input_blocks.6.1.proj_out.bias", "input_blocks.7.1.norm.weight", "input_blocks.7.1.norm.bias", "input_blocks.7.1.qkv.weight", "input_blocks.7.1.qkv.bias", "input_blocks.7.1.proj_out.weight", "input_blocks.7.1.proj_out.bias", "input_blocks.8.0.in_layers.0.weight", "input_blocks.8.0.in_layers.0.bias", "input_blocks.8.0.in_layers.2.weight", "input_blocks.8.0.in_layers.2.bias", "input_blocks.8.0.emb_layers.1.weight", "input_blocks.8.0.emb_layers.1.bias", "input_blocks.8.0.out_layers.0.weight", "input_blocks.8.0.out_layers.0.bias", "input_blocks.8.0.out_layers.3.weight", "input_blocks.8.0.out_layers.3.bias", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "output_blocks.3.2.in_layers.0.weight", "output_blocks.3.2.in_layers.0.bias", "output_blocks.3.2.in_layers.2.weight", "output_blocks.3.2.in_layers.2.bias", "output_blocks.3.2.emb_layers.1.weight", "output_blocks.3.2.emb_layers.1.bias", "output_blocks.3.2.out_layers.0.weight", "output_blocks.3.2.out_layers.0.bias", "output_blocks.3.2.out_layers.3.weight", "output_blocks.3.2.out_layers.3.bias", "output_blocks.7.2.in_layers.0.weight", "output_blocks.7.2.in_layers.0.bias", "output_blocks.7.2.in_layers.2.weight", "output_blocks.7.2.in_layers.2.bias", "output_blocks.7.2.emb_layers.1.weight", "output_blocks.7.2.emb_layers.1.bias", "output_blocks.7.2.out_layers.0.weight", "output_blocks.7.2.out_layers.0.bias", "output_blocks.7.2.out_layers.3.weight", "output_blocks.7.2.out_layers.3.bias", "output_blocks.8.1.norm.weight", "output_blocks.8.1.norm.bias", "output_blocks.8.1.qkv.weight", "output_blocks.8.1.qkv.bias", "output_blocks.8.1.proj_out.weight", "output_blocks.8.1.proj_out.bias", "output_blocks.9.1.norm.weight", "output_blocks.9.1.norm.bias", "output_blocks.9.1.qkv.weight", "output_blocks.9.1.qkv.bias", "output_blocks.9.1.proj_out.weight", "output_blocks.9.1.proj_out.bias", "output_blocks.10.1.norm.weight", "output_blocks.10.1.norm.bias", "output_blocks.10.1.qkv.weight", "output_blocks.10.1.qkv.bias", "output_blocks.10.1.proj_out.weight", "output_blocks.10.1.proj_out.bias", "output_blocks.11.1.norm.weight", "output_blocks.11.1.norm.bias", "output_blocks.11.1.qkv.weight", "output_blocks.11.1.qkv.bias", "output_blocks.11.1.proj_out.weight", "output_blocks.11.1.proj_out.bias", "output_blocks.11.2.in_layers.0.weight", "output_blocks.11.2.in_layers.0.bias", "output_blocks.11.2.in_layers.2.weight", "output_blocks.11.2.in_layers.2.bias", "output_blocks.11.2.emb_layers.1.weight", "output_blocks.11.2.emb_layers.1.bias", "output_blocks.11.2.out_layers.0.weight", "output_blocks.11.2.out_layers.0.bias", "output_blocks.11.2.out_layers.3.weight", "output_blocks.11.2.out_layers.3.bias".
    Unexpected key(s) in state_dict: "input_blocks.4.0.op.weight", "input_blocks.4.0.op.bias", "input_blocks.8.0.op.weight", "input_blocks.8.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "output_blocks.3.2.conv.weight", "output_blocks.3.2.conv.bias", "output_blocks.7.2.conv.weight", "output_blocks.7.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias".
    size mismatch for out.2.weight: copying a param with shape torch.Size([3, 128, 3, 3]) from checkpoint, the shape in current model is torch.Size([6, 128, 3, 3]).
    size mismatch for out.2.bias: copying a param with shape torch.Size([3]) from checkpoint, the shape in current model is torch.Size([6]).
```
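Purely as a hedged reading of the key names above (not verified against anyone's actual configs): ResBlock-style keys that the model expects at the down/upsampling positions, while the checkpoint instead contains plain "...op.*" / "...conv.*" keys, typically points to a resblock_updown mismatch between training and sampling; the 6-versus-3 channel mismatch on out.2 typically points to a learn_sigma mismatch; missing attention keys can come from different attention_resolutions. In guided-diffusion terms, the model flags used at sampling time have to match the ones used for training. A sketch of the flags worth diffing (values below are placeholders, not recommendations):

```python
# Hypothetical example: the point is only that these flags must be identical
# between the guided-diffusion training command and RePaint's sampling conf.
model_flags_used_for_training = dict(
    image_size=256,
    num_channels=128,
    num_res_blocks=2,
    attention_resolutions="16,8",   # controls where AttentionBlocks appear
    learn_sigma=False,              # False -> 3 output channels, True -> 6
    resblock_updown=False,          # False -> plain Downsample/Upsample "op"/"conv" keys
    use_scale_shift_norm=True,
    class_cond=False,
)
```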

handsomecong001 commented 1 year ago

> I ran into the same problem: the values in the checkpoint do not match the network. RuntimeError: Error(s) in loading state_dict for UNetModel: […]

@fangtun Hi, I have the same issue. Have you solved it?

seungwooham commented 7 months ago

Hi, I also have the same issue.

```
Traceback (most recent call last):
  File "test.py", line 180, in <module>
    main(conf_arg)
  File "test.py", line 69, in main
    model.load_state_dict(
  File "/usr/local/lib/python3.8/dist-packages/torch/nn/modules/module.py", line 2001, in load_state_dict
    raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
RuntimeError: Error(s) in loading state_dict for UNetModel:
    Missing key(s) in state_dict: "input_blocks.3.0.in_layers.0.weight", "input_blocks.3.0.in_layers.0.bias", "input_blocks.3.0.in_layers.2.weight", "input_blocks.3.0.in_layers.2.bias", "input_blocks.3.0.emb_layers.1.weight", "input_blocks.3.0.emb_layers.1.bias", "input_blocks.3.0.out_layers.0.weight", "input_blocks.3.0.out_layers.0.bias", "input_blocks.3.0.out_layers.3.weight", "input_blocks.3.0.out_layers.3.bias", "input_blocks.6.0.in_layers.0.weight", "input_blocks.6.0.in_layers.0.bias", "input_blocks.6.0.in_layers.2.weight", "input_blocks.6.0.in_layers.2.bias", "input_blocks.6.0.emb_layers.1.weight", "input_blocks.6.0.emb_layers.1.bias", "input_blocks.6.0.out_layers.0.weight", "input_blocks.6.0.out_layers.0.bias", "input_blocks.6.0.out_layers.3.weight", "input_blocks.6.0.out_layers.3.bias", "input_blocks.9.0.in_layers.0.weight", "input_blocks.9.0.in_layers.0.bias", "input_blocks.9.0.in_layers.2.weight", "input_blocks.9.0.in_layers.2.bias", "input_blocks.9.0.emb_layers.1.weight", "input_blocks.9.0.emb_layers.1.bias", "input_blocks.9.0.out_layers.0.weight", "input_blocks.9.0.out_layers.0.bias", "input_blocks.9.0.out_layers.3.weight", "input_blocks.9.0.out_layers.3.bias", "input_blocks.12.0.in_layers.0.weight", "input_blocks.12.0.in_layers.0.bias", "input_blocks.12.0.in_layers.2.weight", "input_blocks.12.0.in_layers.2.bias", "input_blocks.12.0.emb_layers.1.weight", "input_blocks.12.0.emb_layers.1.bias", "input_blocks.12.0.out_layers.0.weight", "input_blocks.12.0.out_layers.0.bias", "input_blocks.12.0.out_layers.3.weight", "input_blocks.12.0.out_layers.3.bias", "input_blocks.15.0.in_layers.0.weight", "input_blocks.15.0.in_layers.0.bias", "input_blocks.15.0.in_layers.2.weight", "input_blocks.15.0.in_layers.2.bias", "input_blocks.15.0.emb_layers.1.weight", "input_blocks.15.0.emb_layers.1.bias", "input_blocks.15.0.out_layers.0.weight", "input_blocks.15.0.out_layers.0.bias", "input_blocks.15.0.out_layers.3.weight", "input_blocks.15.0.out_layers.3.bias", "output_blocks.2.2.in_layers.0.weight", "output_blocks.2.2.in_layers.0.bias", "output_blocks.2.2.in_layers.2.weight", "output_blocks.2.2.in_layers.2.bias", "output_blocks.2.2.emb_layers.1.weight", "output_blocks.2.2.emb_layers.1.bias", "output_blocks.2.2.out_layers.0.weight", "output_blocks.2.2.out_layers.0.bias", "output_blocks.2.2.out_layers.3.weight", "output_blocks.2.2.out_layers.3.bias", "output_blocks.5.2.in_layers.0.weight", "output_blocks.5.2.in_layers.0.bias", "output_blocks.5.2.in_layers.2.weight", "output_blocks.5.2.in_layers.2.bias", "output_blocks.5.2.emb_layers.1.weight", "output_blocks.5.2.emb_layers.1.bias", "output_blocks.5.2.out_layers.0.weight", "output_blocks.5.2.out_layers.0.bias", "output_blocks.5.2.out_layers.3.weight", "output_blocks.5.2.out_layers.3.bias", "output_blocks.8.2.in_layers.0.weight", "output_blocks.8.2.in_layers.0.bias", "output_blocks.8.2.in_layers.2.weight", "output_blocks.8.2.in_layers.2.bias", "output_blocks.8.2.emb_layers.1.weight", "output_blocks.8.2.emb_layers.1.bias", "output_blocks.8.2.out_layers.0.weight", "output_blocks.8.2.out_layers.0.bias", "output_blocks.8.2.out_layers.3.weight", "output_blocks.8.2.out_layers.3.bias", "output_blocks.11.1.in_layers.0.weight", "output_blocks.11.1.in_layers.0.bias", "output_blocks.11.1.in_layers.2.weight", "output_blocks.11.1.in_layers.2.bias", "output_blocks.11.1.emb_layers.1.weight", "output_blocks.11.1.emb_layers.1.bias", "output_blocks.11.1.out_layers.0.weight", "output_blocks.11.1.out_layers.0.bias", "output_blocks.11.1.out_layers.3.weight", "output_blocks.11.1.out_layers.3.bias", "output_blocks.14.1.in_layers.0.weight", "output_blocks.14.1.in_layers.0.bias", "output_blocks.14.1.in_layers.2.weight", "output_blocks.14.1.in_layers.2.bias", "output_blocks.14.1.emb_layers.1.weight", "output_blocks.14.1.emb_layers.1.bias", "output_blocks.14.1.out_layers.0.weight", "output_blocks.14.1.out_layers.0.bias", "output_blocks.14.1.out_layers.3.weight", "output_blocks.14.1.out_layers.3.bias".
    Unexpected key(s) in state_dict: "input_blocks.3.0.op.weight", "input_blocks.3.0.op.bias", "input_blocks.6.0.op.weight", "input_blocks.6.0.op.bias", "input_blocks.9.0.op.weight", "input_blocks.9.0.op.bias", "input_blocks.12.0.op.weight", "input_blocks.12.0.op.bias", "input_blocks.15.0.op.weight", "input_blocks.15.0.op.bias", "output_blocks.2.2.conv.weight", "output_blocks.2.2.conv.bias", "output_blocks.5.2.conv.weight", "output_blocks.5.2.conv.bias", "output_blocks.8.2.conv.weight", "output_blocks.8.2.conv.bias", "output_blocks.11.1.conv.weight", "output_blocks.11.1.conv.bias", "output_blocks.14.1.conv.weight", "output_blocks.14.1.conv.bias".
```

@handsomecong001 Have you tried it?

seungwooham commented 7 months ago

It might not work perfectly, depending on your arguments. However, in my case, updating UNetModel and adding SiLU to the import section worked.
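On the "adding SiLU in the import section" part: depending on which nn.py version your guided_diffusion package has, SiLU may not be exported there. A hedged sketch of what that import could look like (newer PyTorch ships nn.SiLU; older versions need the small fallback class):

```python
import torch as th
import torch.nn as nn

try:
    from torch.nn import SiLU  # available in PyTorch >= 1.7
except ImportError:
    class SiLU(nn.Module):
        """Fallback SiLU activation: x * sigmoid(x)."""
        def forward(self, x):
            return x * th.sigmoid(x)
```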

```python
# Replacement UNetModel / SuperResModel for guided_diffusion/unet.py (as posted above).
# It relies on the helpers that file already defines or imports:
# TimestepEmbedSequential, ResBlock, AttentionBlock, Upsample, Downsample,
# conv_nd, linear, normalization, zero_module, timestep_embedding,
# convert_module_to_f16, convert_module_to_f32, th (torch), nn, F,
# plus SiLU (see the import note above).


class UNetModel(nn.Module):
    """
    The full UNet model with attention and timestep embedding.

    :param in_channels: channels in the input Tensor.
    :param model_channels: base channel count for the model.
    :param out_channels: channels in the output Tensor.
    :param num_res_blocks: number of residual blocks per downsample.
    :param attention_resolutions: a collection of downsample rates at which
        attention will take place. May be a set, list, or tuple.
        For example, if this contains 4, then at 4x downsampling, attention
        will be used.
    :param dropout: the dropout probability.
    :param channel_mult: channel multiplier for each level of the UNet.
    :param conv_resample: if True, use learned convolutions for upsampling and
        downsampling.
    :param dims: determines if the signal is 1D, 2D, or 3D.
    :param num_classes: if specified (as an int), then this model will be
        class-conditional with `num_classes` classes.
    :param use_checkpoint: use gradient checkpointing to reduce memory usage.
    :param num_heads: the number of attention heads in each attention layer.
    """

    def __init__(
        self,
        image_size,
        in_channels,
        model_channels,
        out_channels,
        num_res_blocks,
        attention_resolutions,
        dropout=0,
        channel_mult=(1, 2, 4, 8),
        conv_resample=True,
        dims=2,
        num_classes=None,
        use_checkpoint=False,
        use_fp16=False,
        num_heads=1,
        num_head_channels=-1,
        num_heads_upsample=-1,
        use_scale_shift_norm=False,
        resblock_updown=False,
        use_new_attention_order=False,
        conf=None,
    ):
        super().__init__()

        if num_heads_upsample == -1:
            num_heads_upsample = num_heads

        self.image_size = image_size
        self.in_channels = in_channels
        self.model_channels = model_channels
        self.out_channels = out_channels
        self.num_res_blocks = num_res_blocks
        self.attention_resolutions = attention_resolutions
        self.dropout = dropout
        self.channel_mult = channel_mult
        self.conv_resample = conv_resample
        self.num_classes = num_classes
        self.use_checkpoint = use_checkpoint
        self.dtype = th.float16 if use_fp16 else th.float32
        self.num_heads = num_heads
        self.num_head_channels = num_head_channels
        self.num_heads_upsample = num_heads_upsample
        self.conf = conf

        # Timestep (and optional class) embedding.
        time_embed_dim = model_channels * 4
        self.time_embed = nn.Sequential(
            linear(model_channels, time_embed_dim),
            SiLU(),
            linear(time_embed_dim, time_embed_dim),
        )

        if self.num_classes is not None:
            self.label_emb = nn.Embedding(num_classes, time_embed_dim)

        # Downsampling trunk.
        self.input_blocks = nn.ModuleList(
            [
                TimestepEmbedSequential(
                    conv_nd(dims, in_channels, model_channels, 3, padding=1)
                )
            ]
        )
        input_block_chans = [model_channels]
        ch = model_channels
        ds = 1
        for level, mult in enumerate(channel_mult):
            for _ in range(num_res_blocks):
                layers = [
                    ResBlock(
                        ch,
                        time_embed_dim,
                        dropout,
                        out_channels=mult * model_channels,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = mult * model_channels
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                self.input_blocks.append(TimestepEmbedSequential(*layers))
                input_block_chans.append(ch)
            if level != len(channel_mult) - 1:
                self.input_blocks.append(
                    TimestepEmbedSequential(Downsample(ch, conv_resample, dims=dims))
                )
                input_block_chans.append(ch)
                ds *= 2

        # Bottleneck.
        self.middle_block = TimestepEmbedSequential(
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
            AttentionBlock(
                ch,
                use_checkpoint=use_checkpoint,
                num_heads=num_heads,
                num_head_channels=num_head_channels,
                use_new_attention_order=use_new_attention_order,
            ),
            ResBlock(
                ch,
                time_embed_dim,
                dropout,
                dims=dims,
                use_checkpoint=use_checkpoint,
                use_scale_shift_norm=use_scale_shift_norm,
            ),
        )

        # Upsampling trunk with skip connections.
        self.output_blocks = nn.ModuleList([])
        for level, mult in list(enumerate(channel_mult))[::-1]:
            for i in range(num_res_blocks + 1):
                layers = [
                    ResBlock(
                        ch + input_block_chans.pop(),
                        time_embed_dim,
                        dropout,
                        out_channels=model_channels * mult,
                        dims=dims,
                        use_checkpoint=use_checkpoint,
                        use_scale_shift_norm=use_scale_shift_norm,
                    )
                ]
                ch = model_channels * mult
                if ds in attention_resolutions:
                    layers.append(
                        AttentionBlock(
                            ch,
                            use_checkpoint=use_checkpoint,
                            num_heads=num_heads_upsample,
                            num_head_channels=num_head_channels,
                            use_new_attention_order=use_new_attention_order,
                        )
                    )
                if level and i == num_res_blocks:
                    layers.append(Upsample(ch, conv_resample, dims=dims))
                    ds //= 2
                self.output_blocks.append(TimestepEmbedSequential(*layers))

        # Final projection back to image space.
        self.out = nn.Sequential(
            normalization(ch),
            SiLU(),
            zero_module(conv_nd(dims, model_channels, out_channels, 3, padding=1)),
        )

    def convert_to_fp16(self):
        """
        Convert the torso of the model to float16.
        """
        self.input_blocks.apply(convert_module_to_f16)
        self.middle_block.apply(convert_module_to_f16)
        self.output_blocks.apply(convert_module_to_f16)

    def convert_to_fp32(self):
        """
        Convert the torso of the model to float32.
        """
        self.input_blocks.apply(convert_module_to_f32)
        self.middle_block.apply(convert_module_to_f32)
        self.output_blocks.apply(convert_module_to_f32)

    @property
    def inner_dtype(self):
        """
        Get the dtype used by the torso of the model.
        """
        return next(self.input_blocks.parameters()).dtype

    def forward(self, x, timesteps, y=None, gt=None, **kwargs):
        """
        Apply the model to an input batch.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: an [N x C x ...] Tensor of outputs.
        """
        assert (y is not None) == (
            self.num_classes is not None
        ), "must specify y if and only if the model is class-conditional"

        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))

        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)

        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
        h = self.middle_block(h, emb)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
        h = h.type(x.dtype)
        return self.out(h)

    def get_feature_vectors(self, x, timesteps, y=None):
        """
        Apply the model and return all of the intermediate tensors.

        :param x: an [N x C x ...] Tensor of inputs.
        :param timesteps: a 1-D batch of timesteps.
        :param y: an [N] Tensor of labels, if class-conditional.
        :return: a dict with the following keys:
                 - 'down': a list of hidden state tensors from downsampling.
                 - 'middle': the tensor of the output of the lowest-resolution
                             block in the model.
                 - 'up': a list of hidden state tensors from upsampling.
        """
        hs = []
        emb = self.time_embed(timestep_embedding(timesteps, self.model_channels))
        if self.num_classes is not None:
            assert y.shape == (x.shape[0],)
            emb = emb + self.label_emb(y)
        result = dict(down=[], up=[])
        h = x.type(self.inner_dtype)
        for module in self.input_blocks:
            h = module(h, emb)
            hs.append(h)
            result["down"].append(h.type(x.dtype))
        h = self.middle_block(h, emb)
        result["middle"] = h.type(x.dtype)
        for module in self.output_blocks:
            cat_in = th.cat([h, hs.pop()], dim=1)
            h = module(cat_in, emb)
            result["up"].append(h.type(x.dtype))
        return result


class SuperResModel(UNetModel):
    """
    A UNetModel that performs super-resolution.

    Expects an extra kwarg `low_res` to condition on a low-resolution image.
    """

    def __init__(self, image_size, in_channels, *args, **kwargs):
        super().__init__(image_size, in_channels * 2, *args, **kwargs)

    def forward(self, x, timesteps, low_res=None, **kwargs):
        _, _, new_height, new_width = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().forward(x, timesteps, **kwargs)

    def get_feature_vectors(self, x, timesteps, low_res=None, **kwargs):
        _, new_height, new_width, _ = x.shape
        upsampled = F.interpolate(low_res, (new_height, new_width), mode="bilinear")
        x = th.cat([x, upsampled], dim=1)
        return super().get_feature_vectors(x, timesteps, **kwargs)
```
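If the updated UNetModel now matches the training-time architecture, loading the EMA weights should succeed. Below is a minimal, hedged sanity check; the import path, hyperparameter values, and checkpoint filename are placeholders for your own setup:

```python
import torch as th

from guided_diffusion.unet import UNetModel  # assumes this repo's package layout

# Hypothetical hyperparameters: these must mirror the flags used for training.
model = UNetModel(
    image_size=256,
    in_channels=3,
    model_channels=128,
    out_channels=3,  # 6 if learn_sigma was used during training
    num_res_blocks=2,
    attention_resolutions=(16,),
    channel_mult=(1, 1, 2, 2, 4, 4),
)

state = th.load("ema_0.9999_100000.pt", map_location="cpu")  # placeholder filename
result = model.load_state_dict(state, strict=False)
print("missing keys:", result.missing_keys)
print("unexpected keys:", result.unexpected_keys)
```

With strict=False the call reports, rather than raises on, any remaining mismatches, so empty lists here suggest the architecture and checkpoint finally agree.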