Megvii-BaseDetection / BEVDepth

Official code for BEVDepth.
MIT License

Question about a version change #190

Open iphanazi opened 6 months ago

iphanazi commented 6 months ago

A question for the authors: why was the following `training_step` from v0.0.2 changed?

    def training_step(self, batch):
        (sweep_imgs, mats, _, _, gt_boxes, gt_labels, depth_labels) = batch
        if torch.cuda.is_available():
            for key, value in mats.items():
                mats[key] = value.cuda()
            sweep_imgs = sweep_imgs.cuda()
            gt_boxes = [gt_box.cuda() for gt_box in gt_boxes]
            gt_labels = [gt_label.cuda() for gt_label in gt_labels]
        preds, depth_preds = self(sweep_imgs, mats)
        if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
            targets = self.model.module.get_targets(gt_boxes, gt_labels)
            detection_loss = self.model.module.loss(targets, preds)
        else:
            targets = self.model.get_targets(gt_boxes, gt_labels)
            detection_loss = self.model.loss(targets, preds)

        if len(depth_labels.shape) == 5:
            # only key-frame will calculate depth loss
            depth_labels = depth_labels[:, 0, ...]
        depth_loss = self.get_depth_loss(depth_labels.cuda(), depth_preds)
        self.log('detection_loss', detection_loss)
        self.log('depth_loss', depth_loss)
        return detection_loss + depth_loss
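For readers following the snippet, the key-frame step (`depth_labels[:, 0, ...]`) can be illustrated in isolation. This is only a minimal sketch: it uses NumPy arrays as a stand-in for torch tensors (the indexing behaves the same way), the helper name `select_key_frame` is hypothetical, and the axis layout (batch, sweeps, cameras, height, width) and the example shapes are assumptions, not something stated in the issue.

```python
import numpy as np


def select_key_frame(depth_labels):
    """Drop the sweep axis by keeping only index 0 when the input is 5-D.

    Assumed layout (not confirmed by the issue):
    (batch, sweeps, cameras, height, width), with the key frame at sweep 0.
    """
    if depth_labels.ndim == 5:
        # Same indexing as in the quoted training_step.
        depth_labels = depth_labels[:, 0, ...]
    return depth_labels


# Hypothetical shapes, chosen only for illustration.
multi_sweep = np.zeros((2, 3, 6, 16, 44))   # 3 sweeps per sample
single_frame = np.zeros((2, 6, 16, 44))     # no sweep axis

print(select_key_frame(multi_sweep).shape)   # (2, 6, 16, 44)
print(select_key_frame(single_frame).shape)  # (2, 6, 16, 44)
```

Under these assumptions, a 4-D input passes through unchanged, so the depth loss always sees one set of labels per camera.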