DaiShiResearch / TransNeXt

[CVPR 2024] Code release for TransNeXt model
Apache License 2.0

Error when using non-standard image sizes (e.g., 138x138) with TransNeXt Tiny model #12

Open Coolog opened 1 month ago

Coolog commented 1 month ago

I am currently working with your TransNeXt model, specifically the tiny variant. When I use an image size that is not 224 (e.g., 138), I encounter an error in the attention_native.py file. The specific error occurs at the following line:

q_norm_scaled = (q_norm + self.query_embedding) * F.softplus(self.temperature) * self.seq_length_scale

The error message is: RuntimeError: The size of tensor a (256) must match the size of tensor b (289) at non-singleton dimension 2
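
For reference, a minimal way to reproduce this (my own sketch; I am assuming the classification model is built via the repository's transnext_tiny entry point and that it forwards img_size to the TransNeXt constructor):

import torch
from transnext import transnext_tiny  # import path depends on your checkout

# Build the tiny variant at the non-standard resolution and run a dummy forward pass.
model = transnext_tiny(img_size=138).eval()
x = torch.randn(1, 3, 138, 138)
with torch.no_grad():
    out = model(x)  # raises the RuntimeError above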

It seems that the issue is related to the input image size not being compatible with the model parameters. I have attempted to adjust self.seq_length_scale, the get_seqlen_and_mask function, and the line in transnext.py that specifies img_size // (2 ** (i + 2)), but these attempts have not resolved the issue.

Could you please provide guidance on how to modify the model to accept input sizes such as 138x138?

Thank you for your assistance.

DaiShiResearch commented 1 month ago

Thank you for your question. Currently, the code assumes the input resolution dimensions are multiples of 32, which is a common practice for preprocessing input image sizes in training code for tasks such as detection and segmentation. To avoid modifying the code, you can resize your image to the nearest multiple of 32, such as 128x128, which should work fine.
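
To make the multiple-of-32 constraint concrete, here is a small worked example (a sketch of my own, reusing the conv output-size formula from the code below and assuming the stage-1 patch embedding downsamples by 4, as in the released 224x224 models): for 128 the conv-derived feature sizes match img_size // (2 ** (i + 2)) at every stage, whereas for 138 they do not.

def conv_out(size, kernel, stride, padding):
    return (size + 2 * padding - kernel) // stride + 1

def stage_sizes(img_size, patch_size=4):
    # Stage 1: (2 * patch_size - 1) kernel with stride patch_size; stages 2-4: 3x3 kernel with stride 2.
    sizes, size = [], img_size
    for kernel, stride in [(patch_size * 2 - 1, patch_size), (3, 2), (3, 2), (3, 2)]:
        size = conv_out(size, kernel, stride, kernel // 2)
        sizes.append(size)
    return sizes

print(stage_sizes(128))                           # [32, 16, 8, 4]
print([128 // (2 ** (i + 2)) for i in range(4)])  # [32, 16, 8, 4]  -> match
print(stage_sizes(138))                           # [35, 18, 9, 5]
print([138 // (2 ** (i + 2)) for i in range(4)])  # [34, 17, 8, 4]  -> mismatch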

If you want to use arbitrary input resolutions, you will need to modify the code to pre-calculate the actual feature sizes and pool sizes for each layer. Below is a modified version of the model initialization that accommodates different input sizes:

class TransNeXt(nn.Module):
    '''
    The parameter "img size" is primarily utilized for generating relative spatial coordinates,
    which are used to compute continuous relative positional biases. As this TransNeXt implementation does not support multi-scale inputs,
    it is recommended to set the "img size" parameter to a value that is exactly the same as the resolution of the inference images.
    It is not advisable to set the "img size" parameter to a value exceeding 800x800.
    The "pretrain size" refers to the "img size" used during the initial pre-training phase,
    which is used to scale the relative spatial coordinates for better extrapolation by the MLP.
    For models trained on ImageNet-1K at a resolution of 224x224,
    as well as downstream task models fine-tuned based on these pre-trained weights,
    the "pretrain size" parameter should be set to 224x224.
    '''

    def __init__(self, img_size=224, pretrain_size=None, window_size=[3, 3, 3, None],
                 patch_size=16, in_chans=3, num_classes=1000, embed_dims=[64, 128, 256, 512],
                 num_heads=[1, 2, 4, 8], mlp_ratios=[4, 4, 4, 4], qkv_bias=False, drop_rate=0.,
                 attn_drop_rate=0., drop_path_rate=0., norm_layer=nn.LayerNorm,
                 depths=[3, 4, 6, 3], sr_ratios=[8, 4, 2, 1], num_stages=4, fixed_pool_size=None):
        super().__init__()
        self.num_classes = num_classes
        self.depths = depths
        self.num_stages = num_stages
        pretrain_size = pretrain_size or img_size

        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]  # stochastic depth decay rule
        cur = 0

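        # Standard Conv2d output-size formula: floor((n + 2 * padding - kernel) / stride) + 1.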
        def calculate_feature_size(input_size, kernel_size, stride, padding):
            return (input_size + 2 * padding - kernel_size) // stride + 1

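        # Per-stage (kernel_size, stride) of the overlapping patch embeddings; padding is kernel_size // 2 (see below).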
        layers = [
            {'kernel_size': patch_size * 2 - 1, 'stride': patch_size},
            {'kernel_size': 3, 'stride': 2},
            {'kernel_size': 3, 'stride': 2},
            {'kernel_size': 3, 'stride': 2}
        ]

        feature_sizes = []

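        # Propagate the input resolution through the four downsampling layers to get each stage's actual feature size.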
        input_size = img_size
        for layer in layers:
            kernel_size = layer['kernel_size']
            stride = layer['stride']
            padding = kernel_size // 2
            output_size = calculate_feature_size(input_size, kernel_size, stride, padding)
            feature_sizes.append(output_size)
            input_size = output_size

        for i in range(num_stages):
            # Generate relative positional coordinate table and index for each stage to compute continuous relative positional bias.
            relative_pos_index, relative_coords_table = get_relative_position_cpb(
                query_size=to_2tuple(feature_sizes[i]),
                key_size=to_2tuple(feature_sizes[-1]) if (
                        fixed_pool_size is None or sr_ratios[i] == 1) else to_2tuple(fixed_pool_size),
                pretrain_size=to_2tuple(pretrain_size // (2 ** (i + 2))))

            self.register_buffer(f"relative_pos_index{i + 1}", relative_pos_index, persistent=False)
            self.register_buffer(f"relative_coords_table{i + 1}", relative_coords_table, persistent=False)

            patch_embed = OverlapPatchEmbed(patch_size=patch_size * 2 - 1 if i == 0 else 3,
                                            stride=patch_size if i == 0 else 2,
                                            in_chans=in_chans if i == 0 else embed_dims[i - 1],
                                            embed_dim=embed_dims[i])

            block = nn.ModuleList([Block(
                dim=embed_dims[i], input_resolution=to_2tuple(feature_sizes[i]), window_size=window_size[i],
                num_heads=num_heads[i], mlp_ratio=mlp_ratios[i], qkv_bias=qkv_bias,
                drop=drop_rate, attn_drop=attn_drop_rate, drop_path=dpr[cur + j], norm_layer=norm_layer,
                sr_ratio=sr_ratios[i], fixed_pool_size=fixed_pool_size or feature_sizes[-1])
                for j in range(depths[i])])
            norm = norm_layer(embed_dims[i])
            cur += depths[i]

            setattr(self, f"patch_embed{i + 1}", patch_embed)
            setattr(self, f"block{i + 1}", block)
            setattr(self, f"norm{i + 1}", norm)

        # classification head
        self.head = nn.Linear(embed_dims[3], num_classes) if num_classes > 0 else nn.Identity()

        for n, m in self.named_modules():
            self._init_weights(m, n)

This should allow the model to handle various input sizes by calculating the appropriate feature sizes dynamically.

Of course, your input resolution should not be too small; the feature map at stage 4 must be larger than 1x1.
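
A quick usage sketch for the modified class (my assumptions: patch_size=4 so the four stages downsample by 4/8/16/32 as in the released models, the default embed_dims/depths from the signature above rather than the tiny configuration, and an otherwise unchanged forward pass, Block, and OverlapPatchEmbed from the repository):

import torch

# Instantiate at the target resolution; pretrain_size stays at 224 to match the pre-trained weights.
model = TransNeXt(img_size=138, pretrain_size=224, patch_size=4).eval()
x = torch.randn(1, 3, 138, 138)
with torch.no_grad():
    logits = model(x)
print(logits.shape)  # expected: torch.Size([1, 1000])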