[Open] iphanazi opened this issue 6 months ago
请问作者，v0.0.2 版本中的以下 `training_step` 实现：

```python
def training_step(self, batch):
    (sweep_imgs, mats, _, _, gt_boxes, gt_labels, depth_labels) = batch
    if torch.cuda.is_available():
        for key, value in mats.items():
            mats[key] = value.cuda()
        sweep_imgs = sweep_imgs.cuda()
        gt_boxes = [gt_box.cuda() for gt_box in gt_boxes]
        gt_labels = [gt_label.cuda() for gt_label in gt_labels]
    preds, depth_preds = self(sweep_imgs, mats)
    if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
        targets = self.model.module.get_targets(gt_boxes, gt_labels)
        detection_loss = self.model.module.loss(targets, preds)
    else:
        targets = self.model.get_targets(gt_boxes, gt_labels)
        detection_loss = self.model.loss(targets, preds)

    if len(depth_labels.shape) == 5:
        # only key-frame will calculate depth loss
        depth_labels = depth_labels[:, 0, ...]
    depth_loss = self.get_depth_loss(depth_labels.cuda(), depth_preds)
    self.log('detection_loss', detection_loss)
    self.log('depth_loss', depth_loss)
    return detection_loss + depth_loss
```
为什么被修改了？
请问作者，v0.0.2 版本中的以下 `training_step` 代码片段：

```python
def training_step(self, batch):
    (sweep_imgs, mats, _, _, gt_boxes, gt_labels, depth_labels) = batch
    if torch.cuda.is_available():
        for key, value in mats.items():
            mats[key] = value.cuda()
        sweep_imgs = sweep_imgs.cuda()
        gt_boxes = [gt_box.cuda() for gt_box in gt_boxes]
        gt_labels = [gt_label.cuda() for gt_label in gt_labels]
    preds, depth_preds = self(sweep_imgs, mats)
    if isinstance(self.model, torch.nn.parallel.DistributedDataParallel):
        targets = self.model.module.get_targets(gt_boxes, gt_labels)
        detection_loss = self.model.module.loss(targets, preds)
    else:
        targets = self.model.get_targets(gt_boxes, gt_labels)
        detection_loss = self.model.loss(targets, preds)
```

为什么被修改了？