hhk7734 / tensorflow-yolov4

YOLOv4 Implemented in Tensorflow 2.
MIT License
136 stars, 75 forks

Enable mixed precision #54

Closed: DannyGJdeJong closed this issue 3 years ago.

DannyGJdeJong commented 3 years ago

Is it possible to use TensorFlow's mixed precision for a performance boost on newer NVIDIA graphics cards?

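    # Experimental API; TensorFlow 2.4+ uses tf.keras.mixed_precision.set_global_policy
    from tensorflow.keras.mixed_precision import experimental as mixed_precision

    # Compute in float16 while keeping variables in float32
    policy = mixed_precision.Policy('mixed_float16')
    mixed_precision.set_policy(policy)

    # yolo is a YOLOv4 object from the yolov4 package, configured earlier
    yolo.make_model()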

Right now this raises a TypeError, because float16 and float32 tensors end up as inputs to the same op in the existing model:

TypeError: Input 'y' of 'AddV2' Op has type float32 that does not match type float16 of argument 'x'.

Am I doing something wrong in enabling it this way or is it just not possible with the current implementation?
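The error means AddV2 received one float16 input and one float32 input. Assuming the offending lines add a float32 constant (for example grid offsets) to a float16 layer output, casting before the addition makes the dtypes agree. The sketch below casts the activation up to float32, which also keeps the decoded values out of float16 rounding; add_offsets is a hypothetical helper, not code from yolov4.

```python
import numpy as np
import tensorflow as tf


def add_offsets(features, offsets):
    """Add float32 offsets to features that may be float16 (hypothetical helper)."""
    # Under mixed_float16, layer outputs are float16 while constants built
    # with NumPy/float32 stay float32; AddV2 requires both to match.
    features = tf.cast(features, tf.float32)
    return features + offsets


features = tf.ones((2, 3), dtype=tf.float16)          # stand-in for a layer output
offsets = tf.constant(np.arange(3, dtype=np.float32))  # stand-in for grid offsets
result = add_offsets(features, offsets)
print(result.dtype)  # <dtype: 'float32'>
```

Casting the constant down with tf.cast(offsets, features.dtype) would also remove the error, at the cost of doing the addition in float16.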

hhk7734 commented 3 years ago

https://github.com/hhk7734/tensorflow-yolov4/blob/be7b9305fcef1d1245948c26612c31f6d3212c5e/py_src/yolov4/model/yolov4.py#L98

https://github.com/hhk7734/tensorflow-yolov4/blob/be7b9305fcef1d1245948c26612c31f6d3212c5e/py_src/yolov4/model/yolov4.py#L128

Remove the code at the lines linked above. I plan to modify the head part (#47), but I'm busy these days, so if you can't wait for the update, you will have to write your own head.
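If you write your own head, one common mixed precision pattern is to run the head as a float32 layer: Keras casts floating-point inputs to a layer's compute dtype before call(), so float16 features arrive as float32 and any float32 constants inside the head match. A minimal sketch; Float32Head and its "+ 0.5" are hypothetical placeholders for real box decoding, not the yolov4 head.

```python
import tensorflow as tf


class Float32Head(tf.keras.layers.Layer):
    """Hypothetical head layer that always computes in float32."""

    def __init__(self, **kwargs):
        # dtype="float32" overrides a mixed_float16 policy for this layer only,
        # so Keras autocasts float16 inputs to float32 before call() runs.
        super().__init__(dtype="float32", **kwargs)

    def call(self, inputs):
        # Placeholder for decoding; float32 constants now match the inputs.
        return inputs + tf.constant(0.5, dtype=tf.float32)


# A backbone stand-in that computes in float16 without touching the global policy.
backbone = tf.keras.layers.Dense(4, dtype="mixed_float16")
head = Float32Head()

x = tf.ones((1, 8))
features = backbone(x)    # float16 output
outputs = head(features)  # float32 output
print(features.dtype, outputs.dtype)
```

Keeping the final layers in float32 is also what TensorFlow's mixed precision guide recommends for numerically sensitive outputs.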

hhk7734 commented 3 years ago

As of yolov4 v3.0.0, you can enable mixed precision.
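On TensorFlow 2.4 or later, the experimental Policy/set_policy pair is replaced by set_global_policy. A minimal sketch, assuming the model is built after the policy is set (a global policy only applies to layers created afterwards); the Dense layer is a stand-in for the YOLOv4 model, not part of yolov4.

```python
import tensorflow as tf

# Layers created after this call compute in float16 and keep float32 variables.
tf.keras.mixed_precision.set_global_policy("mixed_float16")

policy = tf.keras.mixed_precision.global_policy()
print(policy.compute_dtype, policy.variable_dtype)  # float16 float32

# Stand-in for building the model (e.g. calling make_model()) under the policy.
layer = tf.keras.layers.Dense(2)
y = layer(tf.ones((1, 3)))
print(y.dtype, layer.kernel.dtype)

# Restore the default so later code in the same process is unaffected.
tf.keras.mixed_precision.set_global_policy("float32")
```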

yolo.predict and yolo.inference both use yolo_diou_nms.

https://github.com/hhk7734/tensorflow-yolov4/blob/211afddf827d558fd7cfcef2cee21beda1c34ce4/py_src/yolov4/tf/__init__.py#L134-L136

However, it only accepts np.float32 arrays, so if you want to enable mixed precision, modify the code as shown below.

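    # yolo_diou_nms only accepts float32 arrays, but a mixed precision
    # model may return float16 candidates, so upcast them first.
    if candidates.dtype != np.float32:
        candidates = candidates.astype(np.float32)
    pred_bboxes = self.yolo_diou_nms(
        candidates=candidates, beta_nms=self.config.yolo_0.beta_nms
    )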
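A standalone NumPy sketch of the same guard, outside the class. The check has to look at the array's dtype: "candidates is not np.float32" compares the array object with the float32 type itself, so it is always true and would copy even float32 input. to_float32 is a hypothetical helper name, not part of yolov4.

```python
import numpy as np


def to_float32(candidates):
    """Return candidates as a float32 array (hypothetical helper)."""
    candidates = np.asarray(candidates)
    if candidates.dtype != np.float32:
        # e.g. float16 output from a mixed precision model: upcast (copies).
        candidates = candidates.astype(np.float32)
    return candidates


half = np.ones((2, 6), dtype=np.float16)
single = np.ones((2, 6), dtype=np.float32)

print(half is not np.float32)         # True: the identity check never matches
print(to_float32(half).dtype)         # float32
print(to_float32(single) is single)   # True: already float32, returned as-is
```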