Closed Whj-cv closed 2 years ago
您好,非常感谢您的优秀工作。最近在学习ViG的代码,我有一些疑惑,这里为什么做池化处理呢? def forward(self, x, relative_pos=None): B, C, H, W = x.shape y = None if self.r > 1: y = F.avg_pool2d(x, self.r, self.r) y = y.reshape(B, C, -1, 1).contiguous() x = x.reshape(B, C, -1, 1).contiguous() edge_index = self.dilated_knn_graph(x, y, relative_pos) x = super(DyGraphConv2d, self).forward(x, edge_index, y) return x.reshape(B, -1, H, W).contiguous()
谢谢关注。这里是为了减少算节点之间距离的时候的计算量。
您好,非常感谢您的优秀工作。最近在学习ViG的代码,我有一些疑惑,这里为什么做池化处理呢? def forward(self, x, relative_pos=None): B, C, H, W = x.shape y = None if self.r > 1: y = F.avg_pool2d(x, self.r, self.r) y = y.reshape(B, C, -1, 1).contiguous()
x = x.reshape(B, C, -1, 1).contiguous() edge_index = self.dilated_knn_graph(x, y, relative_pos) x = super(DyGraphConv2d, self).forward(x, edge_index, y) return x.reshape(B, -1, H, W).contiguous()