Open lartpang opened 2 months ago
本文提出了一种 CNN 和 ViT 的混合架构,即 FasterViT。这样的混合架构可以快速生成高质量 token,然后基于 Transformer 块来进一步处理这些 token。其重点在于结合架构组合和高效的注意力模块的设计,从而优化 ViT 模型的计算效率,提高图像的吞吐率,加强对于高分辨率图像的适应能力。
关于模型设计,作者写了一段具有参考价值的经验性分析,并在这些见解的指导下设计了所提架构,所有阶段均可从加速计算硬件受益。
我们专注于 GPU 等主流现成硬件上的计算机视觉任务的最高吞吐量这在并行计算方面表现出色。在这种情况下,计算涉及一组流式多处理器(SM),以 CUDA 和 Tensor cores 作为计算单元。它需要频繁的数据传输来计算,并可能受到数据移动带宽的影响。因此,受计算限制的操作是计算受限的,而那些受内存传输限制的操作是内存受限的。需要两者之间仔细平衡才能最大化吞吐量。 在分层视觉模型中,中间表示的空间维度随着推理过程而缩小。 初始网络层大多具有较大的空间维度和较少的通道(例如 112x112x64),这使它们是内存受限的。这使得这里更适合计算密集型操作,例如密集卷积而不是会增加额外传输成本的深度/稀疏卷积。也不可用矩阵操作形式表示的操作,例如非线性操作,池化,批归一化。这类操作也是内存受限的,应减少使用。 相反,后面层的操作往往是计算受限的,且运算成本高昂。例如,分层 CNN 具有尺寸为 14x14 的特征图和高维卷积核。这为更具表达性的操作留下了空间,例如层归一化、SE 模块,或注意力,并对吞吐量影响相当小。
我们专注于 GPU 等主流现成硬件上的计算机视觉任务的最高吞吐量这在并行计算方面表现出色。在这种情况下,计算涉及一组流式多处理器(SM),以 CUDA 和 Tensor cores 作为计算单元。它需要频繁的数据传输来计算,并可能受到数据移动带宽的影响。因此,受计算限制的操作是计算受限的,而那些受内存传输限制的操作是内存受限的。需要两者之间仔细平衡才能最大化吞吐量。
在分层视觉模型中,中间表示的空间维度随着推理过程而缩小。
x = x + BN(Conv3x3(GELU(BN(Conv3x3(x)))))
这个模块使用的是粗细 token 组合的方式进行的设计。核心即为每个局部窗口学习专用的 carrier tokens参与到窗口内部和跨窗口的信息交互,基于 carrier token 对窗口之间的交互模式进行建模。由于其通过组合有固定窗口的注局部意力和随区域数量增加而线性增加的窗口 carrier tokens,分层注意力的计算复杂度几乎随输入图像分辨率线性增长。因此,它是捕获高分辨率特征的远距离关系的高效且有效的方法。
整体模块细节图如图 4 所示。在原始的局部注意力之上额外引入了与每个窗口相对应的 carrier token(CT)。这些 CT 用于总结对应的局部窗口。具体流程如下:
卷积位置编码->平均池化
这里是在 ImageNet1K 上初始训练 256 大小(I)300 个 epoch,后面测试了使用不同的尺寸和窗口大小(W)进行微调的结果。
FasterViT: Fast Vision Transformers with Hierarchical Attention
本文提出了一种 CNN 和 ViT 的混合架构,即 FasterViT。这样的混合架构可以快速生成高质量 token,然后基于 Transformer 块来进一步处理这些 token。其重点在于结合架构组合和高效的注意力模块的设计,从而优化 ViT 模型的计算效率,提高图像的吞吐率,加强对于高分辨率图像的适应能力。
模型设计
关于模型设计,作者写了一段具有参考价值的经验性分析,并在这些见解的指导下设计了所提架构,所有阶段均可从加速计算硬件受益。
x = x + BN(Conv3x3(GELU(BN(Conv3x3(x)))))
。分层注意力 HAT
这个模块使用的是粗细 token 组合的方式进行的设计。核心即为每个局部窗口学习专用的 carrier tokens参与到窗口内部和跨窗口的信息交互,基于 carrier token 对窗口之间的交互模式进行建模。由于其通过组合有固定窗口的注局部意力和随区域数量增加而线性增加的窗口 carrier tokens,分层注意力的计算复杂度几乎随输入图像分辨率线性增长。因此,它是捕获高分辨率特征的远距离关系的高效且有效的方法。
整体模块细节图如图 4 所示。在原始的局部注意力之上额外引入了与每个窗口相对应的 carrier token(CT)。这些 CT 用于总结对应的局部窗口。具体流程如下:
卷积位置编码->平均池化
的组合缩小特征空间维度。每个窗口对应生成 $L$ 个 CT($L << k$,k 是窗口的边长),共计 $n^2$ 个窗口,所以一共有 $n^2L$ 个 CT。实验性能
这里是在 ImageNet1K 上初始训练 256 大小(I)300 个 epoch,后面测试了使用不同的尺寸和窗口大小(W)进行微调的结果。