bubbliiiing / yolox-tf2

This is the source code for yolox-tf2; it can be used to train your own models.
Apache License 2.0
72 stars 24 forks

Hi, I'd like to ask you a question about modifying the network. #17

Closed 1ngram433 closed 2 years ago

1ngram433 commented 2 years ago

This is a bit long, so please bear with me. Thanks!!

    # Fragment from inside get_losses(); assumes tf is TensorFlow and K is the Keras backend.
    def loop_body(b, num_fg, loss_iou, loss_obj, loss_cls, loss_dis):
        num_gt  = tf.cast(nlabel[b], tf.int32)
        gt_bboxes_per_image     = labels[b][:num_gt, :4]        # ground-truth boxes for this image
        gt_dis                  = labels[b][:num_gt,  4]        # ground-truth distances for this image (main addition)
        gt_classes              = labels[b][:num_gt,  5]        # ground-truth classes for this image
        bboxes_preds_per_image  = bbox_preds[b]                 # predicted boxes for this image
        obj_preds_per_image     = obj_preds[b]                  # predicted objectness for this image
        dis_preds_per_image     = force_preds[b]                # predicted distances (named force_preds in the outer scope)
        cls_preds_per_image     = cls_preds[b]                  # predicted classes for this image

        def f1():
            # No ground truth in this image: return empty targets.
            num_fg_img    = tf.cast(tf.constant(0), K.dtype(outputs))
            cls_target    = tf.cast(tf.zeros((0, num_classes)), K.dtype(outputs))
            reg_target    = tf.cast(tf.zeros((0, 4)), K.dtype(outputs))
            obj_target    = tf.cast(tf.zeros((total_num_anchors, 1)), K.dtype(outputs))
            dis_target    = tf.cast(tf.zeros((0, 1)), K.dtype(outputs))           # main addition
            fg_mask       = tf.cast(tf.zeros(total_num_anchors), tf.bool)
            return num_fg_img, cls_target, reg_target, obj_target, dis_target, fg_mask

        def f2():
            gt_matched_classes, fg_mask, pred_ious_this_matching, matched_gt_inds, num_fg_img = get_assignments(
                gt_bboxes_per_image, gt_classes, bboxes_preds_per_image, obj_preds_per_image, cls_preds_per_image,
                x_shifts, y_shifts, expanded_strides, num_classes, num_gt, total_num_anchors
            )
            reg_target  = tf.cast(tf.gather_nd(gt_bboxes_per_image, tf.reshape(matched_gt_inds, [-1, 1])), K.dtype(outputs))
            cls_target  = tf.cast(tf.one_hot(tf.cast(gt_matched_classes, tf.int32), num_classes) * tf.expand_dims(pred_ious_this_matching, -1), K.dtype(outputs))
            obj_target  = tf.cast(tf.expand_dims(fg_mask, -1), K.dtype(outputs))
            gt_matched_dis = tf.gather_nd(gt_dis, tf.reshape(matched_gt_inds, [-1, 1]))    # main addition
            dis_target  = tf.cast(tf.expand_dims(tf.cast(gt_matched_dis, tf.float32), -1), K.dtype(outputs))
            # This is the distance information I added; changing tf.float32 to tf.int32 makes it run.
            return num_fg_img, cls_target, reg_target, obj_target, dis_target, fg_mask

        # Both branches return the same six values, in the same order.
        num_fg_img, cls_target, reg_target, obj_target, dis_target, fg_mask = tf.cond(tf.equal(num_gt, 0), f1, f2)
        num_fg      += num_fg_img
        loss_iou    += K.sum(1 - box_ciou(reg_target, tf.boolean_mask(bboxes_preds_per_image, fg_mask)))
        loss_obj    += K.sum(K.binary_crossentropy(obj_target, obj_preds_per_image, from_logits=True))
        loss_cls    += K.sum(K.binary_crossentropy(cls_target, tf.boolean_mask(cls_preds_per_image, fg_mask), from_logits=True))
        loss_dis    += K.sum(tf.losses.mean_squared_error(dis_target, tf.boolean_mask(dis_preds_per_image, fg_mask)))               # main addition
        return b + 1, num_fg, loss_iou, loss_obj, loss_cls, loss_dis

    _, num_fg, loss_iou, loss_obj, loss_cls, loss_dis = tf.while_loop(lambda b,*args: b < tf.cast(tf.shape(outputs)[0], tf.int32), loop_body, [0, num_fg, loss_iou, loss_obj, loss_cls, loss_dis])

    num_fg      = tf.cast(tf.maximum(num_fg, 1), K.dtype(outputs))
    reg_weight  = 5.0
    loss        = reg_weight * loss_iou + loss_obj + loss_cls + loss_dis
    return loss / num_fg

I want to add distance prediction to your source code, as shown above. However, tf.cond in the get_losses function of yolo_training raises the following error:

TypeError: true_fn and false_fn arguments to tf.cond must have the same number, type, and overall structure of return values.

    true_fn output: [None, None, None, None, None, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like_1:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like_2:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like_3:0' shape=(None, None) dtype=float32>, None, <tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001FD7DA1A7F0>]
    false_fn output: [None, None, None, None, None, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity_1:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity_2:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity_3:0' shape=(None, None) dtype=float32>, None, <tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001FD7DA1A880>]

    Error details:
    The two structures don't have the same nested structure.

    First structure: type=list str=[None, None, None, None, None, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like_1:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like_2:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/zeros_like_3:0' shape=(None, None) dtype=float32>, None, <tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001FD7DA1A7F0>]

    Second structure: type=list str=[None, None, None, None, None, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity_1:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity_2:0' shape=(None, None) dtype=float32>, <tf.Tensor 'gradient_tape/model_1/yolo_loss/while/gradients/model_1/yolo_loss/while/cond_grad/Identity_3:0' shape=(None, None) dtype=float32>, None, <tensorflow.python.framework.indexed_slices.IndexedSlices object at 0x000001FD7DA1A880>]

    More specifically: Incompatible CompositeTensor TypeSpecs: type=IndexedSlicesSpec str=IndexedSlicesSpec(TensorShape([None]), tf.float32, tf.int64, tf.int32, TensorShape([None])) vs. type=IndexedSlicesSpec str=IndexedSlicesSpec(TensorShape([None]), tf.float32, tf.int64, tf.int64, TensorShape([None]))
    Entire first structure:
    [., ., ., ., ., ., ., ., ., ., .]
    Entire second structure:
    [., ., ., ., ., ., ., ., ., ., .]
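
The first line of the error states the rule that tf.cond enforces: both branch functions must return the same number of values, with matching dtypes and nesting. A minimal sketch of that rule, using simplified stand-in targets rather than the real loss code (make_targets, empty_targets, and matched_targets are hypothetical names, not from the repository):

```python
import tensorflow as tf


@tf.function
def make_targets(gt_dis):
    """Return (reg_target, dis_target) with the same structure from both branches."""
    num_gt = tf.shape(gt_dis)[0]

    def empty_targets():
        # No ground truth: empty tensors, but still two float32 outputs.
        reg_target = tf.zeros((0, 4), tf.float32)
        dis_target = tf.zeros((0, 1), tf.float32)
        return reg_target, dis_target

    def matched_targets():
        # Placeholder box targets; the real code gathers them from the labels.
        reg_target = tf.ones(tf.stack([num_gt, 4]), tf.float32)
        dis_target = tf.expand_dims(gt_dis, -1)
        return reg_target, dis_target

    return tf.cond(tf.equal(num_gt, 0), empty_targets, matched_targets)


reg_empty, dis_empty = make_targets(tf.constant([], dtype=tf.float32))
reg_full, dis_full = make_targets(tf.constant([9.87, 9.87]))
print(dis_empty.shape, dis_full.shape)
```

The row counts may differ between branches; what has to agree is the number of outputs, their order, and their dtypes.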

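However, if I change tf.float32 to tf.int32 in f2(), the code runs (though my distance values can't really be integers 😂). So I'd like to ask what is going on here, and how I can keep the values as float. Thanks for your guidance!!!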

dis_target  = tf.cast(tf.expand_dims(tf.cast(gt_matched_dis, tf.int32), -1), K.dtype(outputs))      # this is the distance information I added
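
One possible reading of the log, not confirmed anywhere in this thread: the tensor names contain cond_grad, and the "More specifically" line reports two IndexedSlicesSpec values that differ only in one integer dtype (tf.int32 vs. tf.int64). That suggests the mismatch appears while TensorFlow differentiates through tf.cond, not in the forward return values. A cast to tf.int32 is not differentiable, which may be why that change hides the error. Assuming the distance target is a pure label that needs no gradient, one option that keeps it as float is to wrap it in tf.stop_gradient. The sketch below is untested against this model, and matched_distance_target is a hypothetical helper name:

```python
import tensorflow as tf


def matched_distance_target(gt_dis, matched_gt_inds, dtype=tf.float32):
    """Gather per-anchor distance labels and block gradients through them."""
    gt_matched_dis = tf.gather_nd(gt_dis, tf.reshape(matched_gt_inds, [-1, 1]))
    dis_target = tf.expand_dims(tf.cast(gt_matched_dis, dtype), -1)
    # Assumption: labels are constants for the loss, so no gradient is needed here.
    return tf.stop_gradient(dis_target)


gt_dis = tf.constant([9.87, 3.5])
matched_gt_inds = tf.constant([0, 0, 1], dtype=tf.int64)

with tf.GradientTape() as tape:
    tape.watch(gt_dis)
    dis_target = matched_distance_target(gt_dis, matched_gt_inds)
    loss = tf.reduce_sum(dis_target)

# With stop_gradient in place there is no gradient path back to gt_dis.
grad = tape.gradient(loss, gt_dis)
print(dis_target.numpy().ravel(), grad)
```

The target keeps its float values, and the gather no longer contributes to the backward pass.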

1ngram433 commented 2 years ago

Could you please take a look at my question? 😁

bubbliiiing commented 2 years ago

Use tf.Print(gt_matched_dis, [gt_matched_dis]) to see what gt_matched_dis is.
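
In TensorFlow 2 the lowercase tf.print takes the place of the TF1 tf.Print op. A small sketch of inspecting a tensor from inside graph code follows; inspect is a hypothetical helper name, and printing the dtype and shape alongside the values is an added suggestion, not something requested in the thread:

```python
import tensorflow as tf


@tf.function
def inspect(gt_matched_dis):
    """Print values, dtype, and runtime shape, then pass the tensor through."""
    # summarize=3 shows the first and last three entries of each dimension.
    tf.print("gt_matched_dis:", gt_matched_dis, summarize=3)
    tf.print("dtype:", gt_matched_dis.dtype.name, "shape:", tf.shape(gt_matched_dis))
    return gt_matched_dis


values = inspect(tf.fill([8], 9.87))
```

By default tf.print writes to standard error, so the output may appear in a different stream from ordinary print calls.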

1ngram433 commented 2 years ago

After using tf.print(gt_matched_dis, [gt_matched_dis]), gt_matched_dis looks like this: [9.87 9.87 9.87 ... 9.87 9.87 9.87] [[9.87 9.87 9.87 ... 9.87 9.87 9.87]]. It is just one of my distance labels.

And dis_target looks like this. With tf.int32 it runs; with tf.float32 it raises the error 😂 [[9] [9] [9] ... [9] [9] [9]] [[[9] [9] [9] ... [9] [9] [9]]

bubbliiiing commented 2 years ago

I don't think it is related to this. This part comes after the gather; it never even goes into the gather. (image)

bubbliiiing commented 2 years ago

Could this be where the problem is? (image)

1ngram433 commented 2 years ago

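Ahhh, I'm not sure either. But I have switched to PyTorch and it runs now; TensorFlow has so many problems 😂. Thank you for your patient answers and for the open-source code!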

bubbliiiing commented 2 years ago

Yes... that is why PyTorch's market share keeps growing... TensorFlow is just too hard to use.