[Open] 60wanjinbing opened this issue 2 years ago
# Issue question (translated): "Hi author, could you explain what this code means?"
# The docstring and comments below walk through it step by step.
def get_reference_points(spatial_shapes, valid_ratios, device):
    """Build normalized 3-D reference points for every voxel of every feature level.

    For each level ``lvl`` with spatial shape ``(D_, H_, W_)``, this places one
    point at the centre of every voxel: coordinates ``0.5, 1.5, ..., N - 0.5``
    along each axis. It then flattens the grid in ``(d, h, w)`` row-major order
    and divides each coordinate by ``valid_ratio * size`` for that axis.

    The points of all levels are concatenated along dim 1. Finally, they are
    multiplied by every level's valid ratios, so each point is expressed once
    per level.

    Axis order: the loop reads sizes as ``(D, H, W)``, but the points are
    stacked as ``(d, x, y)`` (depth, width, height). The reason is that
    ``valid_ratios`` pairs axis 0 with depth, axis 1 with width (x) and axis 2
    with height (y). Stacking in that same order keeps the final broadcast
    multiply with ``valid_ratios`` aligned axis-for-axis.

    Args:
        spatial_shapes: Iterable of ``(D_, H_, W_)`` sizes, one entry per level.
        valid_ratios: Tensor indexed as ``valid_ratios[batch, level, axis]``.
            It appears to be shaped ``(B, L, 3)`` with the last axis ordered
            ``(d, w, h)``.
            NOTE(review): presumably these are the unpadded fractions of each
            level, as in Deformable DETR. Confirm against the code that builds
            them.
        device: Device on which the coordinate grids are created.

    Returns:
        Tensor of shape ``(B, sum(D_ * H_ * W_), L, 3)``. The last axis is
        ordered ``(d, x, y)``, i.e. ``(D, W, H)``.
    """
    reference_points_list = []
    for lvl, (D_, H_, W_) in enumerate(spatial_shapes):
        # Voxel-centre coordinates along each axis. With torch.meshgrid's
        # default "ij" indexing, output k varies along input k, so the
        # flattened grids enumerate voxels in (d, h, w) row-major order.
        # NOTE: newer PyTorch warns when `indexing` is omitted. Passing
        # indexing="ij" would silence the warning but needs torch >= 1.10.
        ref_d, ref_y, ref_x = torch.meshgrid(
            torch.linspace(0.5, D_ - 0.5, D_, dtype=torch.float32, device=device),
            torch.linspace(0.5, H_ - 0.5, H_, dtype=torch.float32, device=device),
            torch.linspace(0.5, W_ - 0.5, W_, dtype=torch.float32, device=device),
        )
        # Flatten to (1, D_*H_*W_) and normalize by this level's valid extent.
        # Dividing by the per-batch ratio slice broadcasts to (B, D_*H_*W_).
        # Ratio axis 0 <-> depth, axis 2 <-> height (y), axis 1 <-> width (x).
        ref_d = ref_d.reshape(-1)[None] / (valid_ratios[:, None, lvl, 0] * D_)
        ref_y = ref_y.reshape(-1)[None] / (valid_ratios[:, None, lvl, 2] * H_)
        ref_x = ref_x.reshape(-1)[None] / (valid_ratios[:, None, lvl, 1] * W_)
        # Stack as (d, x, y) == (D, W, H) so the last axis lines up with valid_ratios.
        ref = torch.stack((ref_d, ref_x, ref_y), -1)
        reference_points_list.append(ref)
    # (B, total_voxels, 3): the points of all levels, one level after another.
    reference_points = torch.cat(reference_points_list, 1)
    # (B, total_voxels, 1, 3) * (B, 1, L, 3) -> (B, total_voxels, L, 3):
    # every point is rescaled by each level's valid ratios.
    reference_points = reference_points[:, :, None] * valid_ratios[:, None]
    return reference_points
作者,你好,可以解释下这段代码的意思吗? def get_reference_points(spatial_shapes, valid_ratios, device): reference_pointslist = [] for lvl, (D, H, W) in enumerate(spatial_shapes):