BR-IDL / PaddleViT

:robot: PaddleViT: State-of-the-art Visual Transformer and MLP Models for PaddlePaddle 2.0+
https://github.com/BR-IDL/PaddleViT
Apache License 2.0

I want paddle to provide an nn.SwinT API whose inputs and outputs match nn.Conv2D #130

Open tensorfly-gpu opened 2 years ago

tensorfly-gpu commented 2 years ago

I would like paddle to add an nn.SwinT interface whose inputs and outputs are exactly the same as convolution. It could then fully replace the 2D convolution layers in any convolution-based model: because the input and output shapes match convolution exactly, it is very convenient, and even more so in models that mix convolution and attention (currently an active research area). I very much hope PaddlePaddle can add this interface as nn.SwinT and optimize it. I have noticed that although ViT is maturing in CV, most people are still unfamiliar with it and know CNNs much better, so giving SwinT the same input/output interface as a CNN would greatly simplify everyday model writing and training.

Another consideration: we could use SwinT in a model's few most important layers and convolution in the rest, balancing efficiency and accuracy, but at present there is no convenient interface for programming this way. Likewise, ViT models built for classification tend to leave people lost when transferred to tasks whose outputs are images, whereas CNNs are very mature for segmentation and detection; this is another important reason. A minimal sketch of the intended hybrid usage is shown right below.
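Here is a minimal sketch of that hybrid usage (illustrative only; HybridBlock, the channel counts, resolution and head count are my assumptions, and SwinT refers to the layer defined at the end of this post):

import paddle
import paddle.nn as nn

class HybridBlock(nn.Layer):
    """Hypothetical hybrid block: convolution stages around one SwinT stage."""
    def __init__(self):
        super().__init__()
        # ordinary convolution stage
        self.conv = nn.Conv2D(3, 48, kernel_size=3, padding=1)
        # attention stage with the same input/output contract as Conv2D
        self.swin = SwinT(in_channels=48, input_resolution=(56, 56),
                          num_heads=8, window_size=7)
        self.head = nn.Conv2D(48, 10, kernel_size=1)

    def forward(self, x):        # x: [B, 3, 56, 56]
        x = self.conv(x)         # [B, 48, 56, 56]
        x = self.swin(x)         # [B, 48, 56, 56], shape unchanged
        return self.head(x)      # [B, 10, 56, 56]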

Starting from SwinV1 in PaddleViT, I changed the attention to cosine attention and changed pre-normalization to post-normalization, implementing two of the SwinV2 changes. I have not modified the relative position bias, since the main goal is to use this as a layer that is as convenient as a CNN.
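For reference, the two attention formulations written out (my transcription of the formulas from the Swin papers, not part of the original code; B is the relative position bias, d the head dimension, tau the learnable per-head temperature):

\mathrm{Attention}(Q, K, V) = \mathrm{SoftMax}\!\left( QK^{\top} / \sqrt{d} + B \right) V \quad \text{(Swin V1, scaled dot product)}

\mathrm{Sim}(q_i, k_j) = \cos(q_i, k_j) / \tau + B_{ij} \quad \text{(Swin V2, scaled cosine)}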

The code is below. I have high hopes for this interface; since my coding skill is limited, if the official paddle developers are interested in it, please feel free to improve the code. Many thanks!

import numpy as np
import paddle
import paddle.nn as nn

class DropPath(nn.Layer):
    """DropPath (stochastic depth).

    Principle: as the name suggests, DropPath randomly removes branches from
    the multi-branch structure of a deep network.
    Purpose: usually added to a network as a regularizer; it makes training
    harder, and in NAS problems especially, a drop prob set too high may even
    keep the model from converging.
    """
    def __init__(self, drop_prob=None):
        super().__init__()
        self.drop_prob = drop_prob

    def drop_path(self, inputs):
        """drop path op
        Args:
            inputs: tensor with arbitrary shape
        Returns:
            output: output tensor after drop path
        """
        # if prob is unset/0 or eval mode, return original input
        if self.drop_prob is None or self.drop_prob == 0. or not self.training:
            return inputs
        keep_prob = 1 - self.drop_prob
        keep_prob = paddle.to_tensor(keep_prob, dtype='float32')
        # one mask entry per sample: shape=(N, 1, 1, ...) broadcasts over the rest
        shape = (inputs.shape[0], ) + (1, ) * (inputs.ndim - 1)
        random_tensor = keep_prob + paddle.rand(shape, dtype=inputs.dtype)
        random_tensor = random_tensor.floor()  # binary mask
        # divide by keep_prob so the expected output matches the input
        output = inputs.divide(keep_prob) * random_tensor
        return output

    def forward(self, inputs):
        return self.drop_path(inputs)
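A quick behavioral check (illustrative, not part of the layer itself): in eval mode DropPath is the identity, while in train mode whole samples are either zeroed or rescaled by 1/keep_prob so the expected output is unchanged.

drop = DropPath(drop_prob=0.5)
x = paddle.ones([4, 3, 8, 8])
drop.eval()
assert paddle.allclose(drop(x), x)  # identity at inference
drop.train()
y = drop(x)  # each sample is now either all zeros or scaled by 2.0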

class Identity(nn.Layer):
    """Identity layer.

    The output of this layer equals its input, unchanged. Use this layer to
    avoid `if` conditions in some forward methods.
    """
    def __init__(self):
        super(Identity, self).__init__()

    def forward(self, x):
        return x

class PatchEmbedding(nn.Layer):
    """Patch Embedding.

    Applies patch embedding to input images; the embedding is implemented
    with a Conv2D op.

    Attributes:
        image_size: int, input image size, default: 224
        patch_size: int, size of a patch, default: 4
        in_channels: int, input image channels, default: 3
        embed_dim: int, embedding dimension, default: 96
    """
    def __init__(self, image_size=224, patch_size=4, in_channels=3, embed_dim=96):
        super().__init__()
        image_size = (image_size, image_size)  # TODO: add to_2tuple
        patch_size = (patch_size, patch_size)
        patches_resolution = [image_size[0]//patch_size[0], image_size[1]//patch_size[1]]
        self.image_size = image_size
        self.patch_size = patch_size
        self.patches_resolution = patches_resolution
        self.num_patches = patches_resolution[0] * patches_resolution[1]
        self.in_channels = in_channels
        self.embed_dim = embed_dim
        self.patch_embed = nn.Conv2D(in_channels=in_channels,
                                     out_channels=embed_dim,
                                     kernel_size=patch_size,
                                     stride=patch_size)

        w_attr, b_attr = self._init_weights_layernorm()
        self.norm = nn.LayerNorm(embed_dim,
                                 weight_attr=w_attr,
                                 bias_attr=b_attr)

    def _init_weights_layernorm(self):
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def forward(self, x):
        x = self.patch_embed(x)  # [batch, embed_dim, h, w] where (h, w) = patches_resolution
        x = x.flatten(start_axis=2, stop_axis=-1)  # [batch, embed_dim, h*w], h*w = num_patches
        x = x.transpose([0, 2, 1])  # [batch, h*w, embed_dim]
        x = self.norm(x)  # [batch, num_patches, embed_dim]
        return x
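A quick shape check (illustrative): a 224x224 RGB image becomes 56*56 = 3136 patch tokens of dimension 96.

embed = PatchEmbedding(image_size=224, patch_size=4, in_channels=3, embed_dim=96)
img = paddle.randn([2, 3, 224, 224])
print(embed(img).shape)  # [2, 3136, 96]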

class PatchMerging(nn.Layer):
    """Patch Merging.

    Merges multiple patches into one patch and rescales the dim.
    Specifically, each group of adjacent 2x2 patches (dim C) is merged into
    one patch; the concatenated dim 4C is rescaled to 2C.

    Attributes:
        input_resolution: tuple of ints, the size of the input
        dim: dimension of a single patch
        reduction: nn.Linear, maps dim 4C to 2C
        norm: nn.LayerNorm, applied before the linear layer
    """
    def __init__(self, input_resolution, dim):
        super(PatchMerging, self).__init__()
        self.input_resolution = input_resolution
        self.dim = dim
        w_attr_1, b_attr_1 = self._init_weights()
        self.reduction = nn.Linear(4 * dim,
                                   2 * dim,
                                   weight_attr=w_attr_1,
                                   bias_attr=False)

        w_attr_2, b_attr_2 = self._init_weights_layernorm()
        self.norm = nn.LayerNorm(4 * dim,
                                 weight_attr=w_attr_2,
                                 bias_attr=b_attr_2)

    def _init_weights_layernorm(self):
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def _init_weights(self):
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def forward(self, x):
        h, w = self.input_resolution
        b, _, c = x.shape
        x = x.reshape([b, h, w, c])

        x0 = x[:, 0::2, 0::2, :]  # [B, H/2, W/2, C]
        x1 = x[:, 1::2, 0::2, :]  # [B, H/2, W/2, C]
        x2 = x[:, 0::2, 1::2, :]  # [B, H/2, W/2, C]
        x3 = x[:, 1::2, 1::2, :]  # [B, H/2, W/2, C]
        x = paddle.concat([x0, x1, x2, x3], -1)  # [B, H/2, W/2, 4*C]
        x = x.reshape([b, -1, 4*c])  # [B, H/2*W/2, 4*C]

        x = self.norm(x)
        x = self.reduction(x)  # [B, H/2*W/2, 2*C]

        return x
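A quick shape check (illustrative): merging 2x2 neighbours halves the resolution and doubles the channel dim, e.g. 56x56 tokens of dim 96 become 28x28 tokens of dim 192.

merge = PatchMerging(input_resolution=(56, 56), dim=96)
tokens = paddle.randn([2, 56 * 56, 96])
print(merge(tokens).shape)  # [2, 784, 192]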

class Mlp(nn.Layer):
    """MLP module.

    Implemented with nn.Linear; the activation is GELU and dropout is applied.
    Ops: fc1 -> act -> dropout -> fc2 -> dropout

    Attributes:
        fc1: nn.Linear
        fc2: nn.Linear
        act: GELU
        dropout: dropout layer, applied after fc1 and after fc2
    """
    def __init__(self, in_features, hidden_features, dropout):
        super(Mlp, self).__init__()
        w_attr_1, b_attr_1 = self._init_weights()
        self.fc1 = nn.Linear(in_features,
                             hidden_features,
                             weight_attr=w_attr_1,
                             bias_attr=b_attr_1)

        w_attr_2, b_attr_2 = self._init_weights()
        self.fc2 = nn.Linear(hidden_features,
                             in_features,
                             weight_attr=w_attr_2,
                             bias_attr=b_attr_2)
        self.act = nn.GELU()
        self.dropout = nn.Dropout(dropout)

    def _init_weights(self):
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def forward(self, x):
        x = self.fc1(x)
        x = self.act(x)
        x = self.dropout(x)
        x = self.fc2(x)
        x = self.dropout(x)
        return x

def windows_partition(x, window_size):
    """Partition a feature map into window_size x window_size windows.
    Args:
        x: Tensor, shape=[b, h, w, c]
        window_size: int, window size
    Returns:
        x: Tensor, shape=[num_windows*b, window_size, window_size, c]
    """
    B, H, W, C = x.shape
    x = x.reshape([B, H//window_size, window_size, W//window_size, window_size, C])  # [bs, num_window_h, window_size, num_window_w, window_size, C]
    x = x.transpose([0, 1, 3, 2, 4, 5])  # [bs, num_window_h, num_window_w, window_size, window_size, C]
    x = x.reshape([-1, window_size, window_size, C])  # [bs*num_windows, window_size, window_size, C]
    return x

def windows_reverse(windows, window_size, H, W):
    """Window reverse (inverse of windows_partition).
    Args:
        windows: (n_windows * B, window_size, window_size, C)
        window_size: (int) window size
        H: (int) height of image
        W: (int) width of image
    Returns:
        x: (B, H, W, C)
    """
    B = int(windows.shape[0] / (H * W / window_size / window_size))
    x = windows.reshape([B, H // window_size, W // window_size, window_size, window_size, -1])  # [bs, num_window_h, num_window_w, window_size, window_size, C]
    x = x.transpose([0, 1, 3, 2, 4, 5])  # [bs, num_window_h, window_size, num_window_w, window_size, C]
    x = x.reshape([B, H, W, -1])  # [bs, H, W, C]
    return x
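A round-trip check (illustrative): windows_reverse undoes windows_partition, so partitioning a 56x56 map into 7x7 windows and reversing recovers the input exactly.

x = paddle.randn([2, 56, 56, 96])
w = windows_partition(x, window_size=7)  # [2*64, 7, 7, 96]
assert paddle.allclose(windows_reverse(w, 7, 56, 56), x)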

class WindowAttention(nn.Layer):
    """Window-based multi-head attention, with relative position bias.

    Both shifted windows and non-shifted windows are supported.

    Attributes:
        dim: int, input dimension (channels)
        window_size: tuple of ints, height and width of the window
        num_heads: int, number of attention heads
        qkv_bias: bool, if True, enable learnable bias for q, k, v, default: True
        qk_scale: float, overrides the default qk scale head_dim**-0.5 if set, default: None
        attention_dropout: float, dropout of attention
        dropout: float, dropout for output
    """
    def __init__(self,
                 dim,
                 window_size,
                 num_heads,
                 qkv_bias=True,
                 qk_scale=None,
                 attention_dropout=0.,
                 dropout=0.):
        super(WindowAttention, self).__init__()
        self.window_size = window_size
        self.num_heads = num_heads
        self.dim = dim
        self.dim_head = dim // num_heads
        self.scale = qk_scale or self.dim_head ** -0.5

        self.relative_position_bias_table = paddle.create_parameter(
            shape=[(2 * window_size[0] - 1) * (2 * window_size[1] - 1), num_heads],
            dtype='float32',
            default_initializer=paddle.nn.initializer.TruncatedNormal(std=.02))

        # relative position index for each token inside a window
        coords_h = paddle.arange(self.window_size[0])
        coords_w = paddle.arange(self.window_size[1])
        coords = paddle.stack(paddle.meshgrid([coords_h, coords_w]))  # [2, window_h, window_w]
        coords_flatten = paddle.flatten(coords, 1)  # [2, window_h * window_w]
        # [2, window_h * window_w, window_h * window_w]
        relative_coords = coords_flatten.unsqueeze(2) - coords_flatten.unsqueeze(1)
        # [window_h * window_w, window_h * window_w, 2]
        relative_coords = relative_coords.transpose([1, 2, 0])
        relative_coords[:, :, 0] += self.window_size[0] - 1
        relative_coords[:, :, 1] += self.window_size[1] - 1
        relative_coords[:, :, 0] *= 2 * self.window_size[1] - 1
        # [window_h * window_w, window_h * window_w]
        relative_position_index = relative_coords.sum(-1)
        self.register_buffer("relative_position_index", relative_position_index)

        w_attr_1, b_attr_1 = self._init_weights()
        self.qkv = nn.Linear(dim,
                             dim * 3,
                             weight_attr=w_attr_1,
                             bias_attr=b_attr_1 if qkv_bias else False)

        self.attn_dropout = nn.Dropout(attention_dropout)

        w_attr_2, b_attr_2 = self._init_weights()
        self.proj = nn.Linear(dim,
                              dim,
                              weight_attr=w_attr_2,
                              bias_attr=b_attr_2)
        self.proj_dropout = nn.Dropout(dropout)
        self.softmax = nn.Softmax(axis=-1)

        # Swin V2: learnable per-head temperature for scaled cosine attention
        self.tau = paddle.create_parameter(
            shape=[num_heads, window_size[0] * window_size[1], window_size[0] * window_size[1]],
            dtype='float32',
            default_initializer=paddle.nn.initializer.Constant(1))

    def _init_weights(self):
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.TruncatedNormal(std=.02))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def transpose_multihead(self, x):
        new_shape = x.shape[:-1] + [self.num_heads, self.dim_head]
        x = x.reshape(new_shape)
        x = x.transpose([0, 2, 1, 3])
        return x

    def get_relative_pos_bias_from_pos_index(self):
        # relative_position_bias_table is a ParamBase object
        # https://github.com/PaddlePaddle/Paddle/blob/067f558c59b34dd6d8626aad73e9943cf7f5960f/python/paddle/fluid/framework.py#L5727
        table = self.relative_position_bias_table  # N x num_heads
        # index is a tensor
        index = self.relative_position_index.reshape([-1])  # window_h*window_w * window_h*window_w
        # NOTE: paddle does NOT support indexing a Tensor by a Tensor
        relative_position_bias = paddle.index_select(x=table, index=index)
        return relative_position_bias

    def forward(self, x, mask=None):
        qkv = self.qkv(x).chunk(3, axis=-1)  # {list:3}
        q, k, v = map(self.transpose_multihead, qkv)

        # Swin V1 used scaled dot-product attention here:
        #   q = q * self.scale
        #   attn = paddle.matmul(q, k, transpose_y=True)
        # Swin V2 replaces it with scaled cosine attention:
        qk = paddle.matmul(q, k, transpose_y=True)
        q2 = paddle.multiply(q, q).sum(-1).sqrt().unsqueeze(3)  # per-token L2 norm of q
        k2 = paddle.multiply(k, k).sum(-1).sqrt().unsqueeze(3)  # per-token L2 norm of k
        attn = qk / paddle.clip(paddle.matmul(q2, k2, transpose_y=True), min=1e-6)
        attn = attn / paddle.clip(self.tau.unsqueeze(0), min=0.01)

        relative_position_bias = self.get_relative_pos_bias_from_pos_index()
        relative_position_bias = relative_position_bias.reshape(
            [self.window_size[0] * self.window_size[1],
             self.window_size[0] * self.window_size[1],
             -1])
        # [num_heads, window_h*window_w, window_h*window_w]
        relative_position_bias = relative_position_bias.transpose([2, 0, 1])
        attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            nW = mask.shape[0]
            attn = attn.reshape(
                [x.shape[0] // nW, nW, self.num_heads, x.shape[1], x.shape[1]])
            attn += mask.unsqueeze(1).unsqueeze(0)
            attn = attn.reshape([-1, self.num_heads, x.shape[1], x.shape[1]])
        attn = self.softmax(attn)

        attn = self.attn_dropout(attn)

        z = paddle.matmul(attn, v)
        z = z.transpose([0, 2, 1, 3])
        new_shape = z.shape[:-2] + [self.dim]
        z = z.reshape(new_shape)
        z = self.proj(z)
        z = self.proj_dropout(z)

        return z
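A quick shape check (illustrative): WindowAttention maps [num_windows*B, window_size*window_size, dim] to the same shape.

attn_layer = WindowAttention(dim=96, window_size=(7, 7), num_heads=4)
win_tokens = paddle.randn([8, 49, 96])  # 8 windows of 7*7 tokens, dim 96
print(attn_layer(win_tokens).shape)  # [8, 49, 96]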

class SwinTransformerBlock(nn.Layer):
    """Swin transformer block.

    Contains window multi-head self attention, droppath, mlp, norm and residual.

    Attributes:
        dim: int, input dimension (channels)
        input_resolution: tuple, input resolution
        num_heads: int, number of attention heads
        window_size: int, window size, default: 7
        shift_size: int, shift size for SW-MSA, default: 0
        mlp_ratio: float, ratio of mlp hidden dim to input embedding dim, default: 4.
        qkv_bias: bool, if True, enable learnable bias for q, k, v, default: True
        qk_scale: float, overrides the default qk scale head_dim**-0.5 if set, default: None
        dropout: float, dropout for output, default: 0.
        attention_dropout: float, dropout of attention, default: 0.
        droppath: float, drop path rate, default: 0.
    """
    def __init__(self, dim, input_resolution, num_heads, window_size=7, shift_size=0,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0.,
                 attention_dropout=0., droppath=0.):
        super(SwinTransformerBlock, self).__init__()
        self.dim = dim
        self.input_resolution = input_resolution
        self.num_heads = num_heads
        self.window_size = window_size
        self.shift_size = shift_size
        self.mlp_ratio = mlp_ratio
        if min(self.input_resolution) <= self.window_size:
            self.shift_size = 0
            self.window_size = min(self.input_resolution)

        w_attr_1, b_attr_1 = self._init_weights_layernorm()
        self.norm1 = nn.LayerNorm(dim,
                                  weight_attr=w_attr_1,
                                  bias_attr=b_attr_1)

        self.attn = WindowAttention(dim,
                                    window_size=(self.window_size, self.window_size),
                                    num_heads=num_heads,
                                    qkv_bias=qkv_bias,
                                    qk_scale=qk_scale,
                                    attention_dropout=attention_dropout,
                                    dropout=dropout)
        self.drop_path = DropPath(droppath) if droppath > 0. else None

        w_attr_2, b_attr_2 = self._init_weights_layernorm()
        self.norm2 = nn.LayerNorm(dim,
                                  weight_attr=w_attr_2,
                                  bias_attr=b_attr_2)

        self.mlp = Mlp(in_features=dim,
                       hidden_features=int(dim * mlp_ratio),
                       dropout=dropout)

        if self.shift_size > 0:
            H, W = self.input_resolution
            img_mask = paddle.zeros((1, H, W, 1))
            h_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            w_slices = (slice(0, -self.window_size),
                        slice(-self.window_size, -self.shift_size),
                        slice(-self.shift_size, None))
            cnt = 0
            for h in h_slices:
                for w in w_slices:
                    img_mask[:, h, w, :] = cnt
                    cnt += 1

            mask_windows = windows_partition(img_mask, self.window_size)
            mask_windows = mask_windows.reshape((-1, self.window_size * self.window_size))
            attn_mask = mask_windows.unsqueeze(1) - mask_windows.unsqueeze(2)
            # Whether the mask is really necessary is debatable: it is what makes
            # this code so much more complex, and sometimes we actually do want to
            # relate opposite image edges. If setting the -100 to 0 still let the
            # network work, Swin would need far less code.
            attn_mask = paddle.where(attn_mask != 0,
                                     paddle.ones_like(attn_mask) * float(-100.0),
                                     attn_mask)
            attn_mask = paddle.where(attn_mask == 0,
                                     paddle.zeros_like(attn_mask),
                                     attn_mask)
        else:
            attn_mask = None

        self.register_buffer("attn_mask", attn_mask)

    def _init_weights_layernorm(self):
        weight_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(1))
        bias_attr = paddle.ParamAttr(initializer=paddle.nn.initializer.Constant(0))
        return weight_attr, bias_attr

    def forward(self, x):
        H, W = self.input_resolution
        B, L, C = x.shape
        h = x

        # Swin V2 post-normalization: the original pre-norm
        #   x = self.norm1(x)
        # is moved to after the attention below.

        new_shape = [B, H, W, C]
        x = x.reshape(new_shape)  # [bs, H, W, C]

        if self.shift_size > 0:
            shifted_x = paddle.roll(x,
                                    shifts=(-self.shift_size, -self.shift_size),
                                    axis=(1, 2))  # [bs, H, W, C]
        else:
            shifted_x = x

        x_windows = windows_partition(shifted_x, self.window_size)  # [bs*num_windows, 7, 7, C]
        x_windows = x_windows.reshape([-1, self.window_size * self.window_size, C])  # [bs*num_windows, 7*7, C]

        attn_windows = self.attn(x_windows, mask=self.attn_mask)  # [bs*num_windows, 7*7, C]
        attn_windows = attn_windows.reshape([-1, self.window_size, self.window_size, C])  # [bs*num_windows, 7, 7, C]

        shifted_x = windows_reverse(attn_windows, self.window_size, H, W)  # [bs, H, W, C]

        # reverse cyclic shift
        if self.shift_size > 0:
            x = paddle.roll(shifted_x,
                            shifts=(self.shift_size, self.shift_size),
                            axis=(1, 2))
        else:
            x = shifted_x

        x = x.reshape([B, H * W, C])  # [bs, H*W, C]
        x = self.norm1(x)  # [bs, H*W, C], post-norm (moved here for Swin V2)

        if self.drop_path is not None:
            x = h + self.drop_path(x)
        else:
            x = h + x
        h = x  # [bs, H*W, C]

        # Swin V2 post-normalization: the original pre-MLP
        #   x = self.norm2(x)
        # is moved to after the MLP below.
        x = self.mlp(x)  # [bs, H*W, C]
        x = self.norm2(x)  # post-norm (moved here for Swin V2)

        if self.drop_path is not None:
            x = h + self.drop_path(x)
        else:
            x = h + x
        return x
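A quick sanity check (illustrative; sizes are assumptions): a shifted block keeps the token shape.

block = SwinTransformerBlock(dim=96, input_resolution=(56, 56),
                             num_heads=4, window_size=7, shift_size=3)
tokens = paddle.randn([2, 56 * 56, 96])
print(block(tokens).shape)  # [2, 3136, 96]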

class SwinT(nn.Layer):
    """A layer whose input and output shapes are equal to nn.Conv2D.

    Use this module to replace Conv2D with SwinT in any scene.

    Attributes:
        in_channels: the channels of the input
        input_resolution: the only difference from a CNN layer; the forward
            pass needs the resolution to determine its computation
        num_heads, window_size, mlp_ratio, qkv_bias, qk_scale, dropout,
            attention_dropout, droppath: attention parameters
        downsample: like CNN pooling, default: False

    There is no out_channels argument because the transformer itself is a
    strong feature extractor, and downsampling uses patch merging, which
    already doubles the dim. Also, since the input and output match a CNN
    exactly, channels can be widened simply by putting a convolution in
    front. The resolution argument makes it slightly different from Conv2D,
    but we record important information like the feature-map size when
    building a network anyway, so this is not a big problem. What matters is
    that it can fully replace the 2D convolution layers in any
    convolution-based model, because the input and output shapes match
    convolution exactly; this is even more convenient in models mixing
    convolution and attention. I very much hope PaddlePaddle can add this
    interface as nn.SwinT and optimize it.
    """
    def __init__(self, in_channels, input_resolution, num_heads, window_size,
                 mlp_ratio=4., qkv_bias=True, qk_scale=None, dropout=0.,
                 attention_dropout=0., droppath=0., downsample=False):
        super().__init__()
        self.dim = in_channels
        self.input_resolution = input_resolution

        # two blocks: one with regular windows (W-MSA), one with shifted windows (SW-MSA)
        self.blocks = nn.LayerList()
        for i in range(2):
            self.blocks.append(
                SwinTransformerBlock(
                    dim=in_channels, input_resolution=input_resolution,
                    num_heads=num_heads, window_size=window_size,
                    shift_size=0 if (i % 2 == 0) else window_size // 2,
                    mlp_ratio=mlp_ratio,
                    qkv_bias=qkv_bias, qk_scale=qk_scale,
                    dropout=dropout, attention_dropout=attention_dropout,
                    droppath=droppath[i] if isinstance(droppath, list) else droppath))

        if downsample:
            self.downsample = PatchMerging(input_resolution, dim=in_channels)
        else:
            self.downsample = None

    def forward(self, x):
        B, C, H, W = x.shape
        x = x.reshape([B, C, H * W])
        x = x.transpose((0, 2, 1))  # [B, H*W, C]

        for block in self.blocks:
            x = block(x)
        if self.downsample is not None:
            x = self.downsample(x)  # [B, H/2 * W/2, C*2]
            x = x.transpose((0, 2, 1))  # [B, C*2, H/2 * W/2]
            x = x.reshape([B, C * 2, H // 2, W // 2])
        else:
            x = x.transpose((0, 2, 1))
            x = x.reshape([B, C, H, W])
        return x

The test code is as follows:

tmp = paddle.to_tensor(np.random.rand(2, 48, 224, 224), dtype='float32')
print(tmp.shape)
sts = SwinT(in_channels=48, input_resolution=(224, 224), num_heads=8,
            window_size=8, qkv_bias=False, qk_scale=None, dropout=0.1,
            attention_dropout=0.1, droppath=0.1, downsample=True)
out = sts(tmp)
print(out.shape)

The output is:

[2, 48, 224, 224]
[2, 96, 112, 112]

xperzy commented 2 years ago

Thanks for the suggestion! We are evaluating this feature!

tensorfly-gpu commented 2 years ago

Thank you for your reply. I have high hopes for this interface and have already used it to build and validate a SwinResnet model. Project: SwinT-让Swin-Transformer的使用变得和CNN一样方便快捷! - 飞桨AI Studio (baidu.com). Previously I thought writing Swin-Transformer code with paddle was beyond me, but after taking your course, although I still cannot write it from scratch, I can at least locate and fix the places where the program goes wrong. Thank you very much!
