YtongXie / CoTr

[MICCAI2021] CoTr: Efficiently Bridging CNN and Transformer for 3D Medical Image Segmentation
GNU General Public License v3.0
286 stars 45 forks

Explanation of reference points #1

Open 60wanjinbing opened 2 years ago

60wanjinbing commented 2 years ago

Hello, could you explain what this code does?

def get_reference_points(spatial_shapes, valid_ratios, device):
    reference_points_list = []
    for lvl, (D_, H_, W_) in enumerate(spatial_shapes):

        ref_d, ref_y, ref_x = torch.meshgrid(torch.linspace(0.5, D_ - 0.5, D_, dtype=torch.float32, device=device),
                                             torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
                                             torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device))

        ref_d = ref_d.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * D_)
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 2] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * W_)

        ref = torch.stack((ref_d, ref_x, ref_y), -1)   # D W H
        reference_points_list.append(ref)
    reference_points = torch.cat(reference_points_list, 1)
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points
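
For anyone reading along, here is a rough plain-Python sketch of what the loop body appears to compute for a single sample and a single feature level. It is not the CoTr implementation: the helper names `normalized_centers` and `reference_points_single_level` are hypothetical, the `ratios` tuple is assumed to follow the (d, x, y) order suggested by the `valid_ratios` indexing above, and the flattening order assumes the default "ij" behaviour of `torch.meshgrid`.

```python
def normalized_centers(size, valid_ratio=1.0):
    """Cell centers 0.5, 1.5, ..., size - 0.5, divided by valid_ratio * size.

    Mirrors torch.linspace(0.5, size - 0.5, size) / (valid_ratio * size).
    """
    return [(i + 0.5) / (valid_ratio * size) for i in range(size)]


def reference_points_single_level(D, H, W, ratios=(1.0, 1.0, 1.0)):
    """Return one (d, x, y) tuple per voxel of a D x H x W feature map.

    `ratios` is assumed to be ordered (d, x, y), matching the indices
    0, 1, 2 used on the last axis of `valid_ratios` in the quoted code.
    """
    r_d, r_x, r_y = ratios
    ds = normalized_centers(D, r_d)
    ys = normalized_centers(H, r_y)
    xs = normalized_centers(W, r_x)
    # Depth varies slowest and width fastest, as in a flattened "ij" meshgrid.
    # Each point is stored as (d, x, y), the same order as the torch.stack call.
    return [(d, x, y) for d in ds for y in ys for x in xs]


points = reference_points_single_level(2, 2, 2)
print(len(points))   # one reference point per voxel
print(points[0])     # center of the first voxel
print(points[1])     # next voxel along the width axis
```

With all ratios equal to 1, each coordinate is simply `(i + 0.5) / size`, so every reference point lies strictly inside the unit cube. In the quoted function, the last multiplication by `valid_ratios[:, None]` then appears to rescale each point once per feature level, which would give a tensor of shape (batch, total voxels, levels, 3).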