JayYip / m3tl

BERT for Multitask Learning
https://jayyip.github.io/m3tl/
Apache License 2.0
545 stars, 125 forks

Performance issue in m3tl/read_write_tfrecord.py #115

Closed: DLPerf closed this issue 1 year ago.

DLPerf commented 1 year ago

Hello! Our static bug checker has found a performance issue in m3tl/read_write_tfrecord.py: reshape_tensors_in_dataset is called repeatedly in a for loop, yet it defines and calls a tf.function-decorated function, _reshape_tensor, inside its own body.

As a result, every call to reshape_tensors_in_dataset inside the loop creates a new tf.function object for _reshape_tensor. Each new object has an empty trace cache, so it traces a new graph every time, which can trigger TensorFlow's tf.function retracing warning.

Here is the TensorFlow documentation that supports this: https://tensorflow.google.cn/guide/function#tracing
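The effect can be observed directly with a small experiment. The minimal sketch below relies on the documented behavior that Python side effects inside a tf.function run only while the function is being traced, so appending to a plain Python list counts traces. All names here (module_level_square, outer_reusing_function, outer_redefining_function, nested_square) are hypothetical and not part of m3tl, and the calls are eager rather than inside a tf.data pipeline.

```python
import tensorflow as tf

# Python side effects in a tf.function body run only during tracing,
# so the length of each list equals the number of traces.
module_level_traces = []
nested_traces = []


@tf.function
def module_level_square(x):
    module_level_traces.append("traced")
    return x * x


def outer_reusing_function(x):
    # Reuses one tf.function object, so its trace cache survives across calls.
    return module_level_square(x)


def outer_redefining_function(x):
    # Builds a brand-new tf.function (with an empty trace cache) on every call.
    @tf.function
    def nested_square(y):
        nested_traces.append("traced")
        return y * y

    return nested_square(x)


# Call both variants five times with the same input signature
# (a float32 scalar tensor).
for _ in range(5):
    outer_reusing_function(tf.constant(2.0))
    outer_redefining_function(tf.constant(2.0))

print(len(module_level_traces), len(nested_traces))
```

With the same input signature on every call, the module-level function should trace once, while the nested one traces on each of the five calls because each call wraps it in a fresh tf.function.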

In short, it is more efficient to use:

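```python
import tensorflow as tf

# Defined once: every call to outer() reuses the same traced graph.
@tf.function
def inner():
    pass

def outer():
    inner()
```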

than:

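```python
import tensorflow as tf

def outer():
    # Redefined on every call: each new tf.function traces a fresh graph.
    @tf.function
    def inner():
        pass
    inner()
```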
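Applied to the functions named in this report, the recommended pattern might look like the sketch below. The real bodies and signatures of reshape_tensors_in_dataset and _reshape_tensor are not shown in this thread, so both are hypothetical stand-ins: the sketch assumes a tf.data.Dataset whose elements are dicts of tensors and uses flattening as a placeholder transformation.

```python
import tensorflow as tf


# Hypothetical stand-in for m3tl's _reshape_tensor, defined once at module
# level so every call shares one tf.function object and its trace cache.
@tf.function
def _reshape_tensor(t):
    # Placeholder transformation (assumption): flatten each tensor to rank 1.
    return tf.reshape(t, [-1])


def reshape_tensors_in_dataset(dataset):
    # Hypothetical signature (assumption): elements are dicts of tensors.
    # tf.nest.map_structure applies _reshape_tensor to every tensor in the dict.
    return dataset.map(lambda features: tf.nest.map_structure(_reshape_tensor, features))


# Calling the outer function in a loop no longer re-creates _reshape_tensor.
datasets = [
    tf.data.Dataset.from_tensor_slices({"x": tf.ones([4, 2, 3])}),
    tf.data.Dataset.from_tensor_slices({"x": tf.zeros([4, 2, 3])}),
]
reshaped = [reshape_tensors_in_dataset(ds) for ds in datasets]
```

Hoisting _reshape_tensor only stops the loop from building a new tf.function object on each iteration; tf.data.Dataset.map still traces its own mapping function every time it is called.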

Looking forward to your reply.

JayYip commented 1 year ago

Hi,

Thanks for the update. Please create a PR to fix this issue.

On Fri, Feb 24, 2023, 2:50 PM DLPerf wrote:

Hello! Our static bug checker has found a performance issue in m3tl/read_write_tfrecord.py: reshape_tensors_in_dataset (https://github.com/JayYip/m3tl/blob/a948cc90017ec03b00a3496bab742e0ad8887952/m3tl/read_write_tfrecord.py#L529) is called repeatedly in a for loop, yet it defines and calls a tf.function-decorated function, _reshape_tensor (https://github.com/JayYip/m3tl/blob/a948cc90017ec03b00a3496bab742e0ad8887952/m3tl/read_write_tfrecord.py#L365), inside its own body (https://github.com/JayYip/m3tl/blob/a948cc90017ec03b00a3496bab742e0ad8887952/m3tl/read_write_tfrecord.py#L351).

As a result, every call to reshape_tensors_in_dataset inside the loop creates a new tf.function object for _reshape_tensor, which traces a new graph every time and can trigger TensorFlow's tf.function retracing warning.

Here is the TensorFlow documentation that supports this: https://tensorflow.google.cn/guide/function#tracing

In short, it is more efficient to use:

```python
@tf.function
def inner():
    pass

def outer():
    inner()
```

than:

```python
def outer():
    @tf.function
    def inner():
        pass
    inner()
```

Looking forward to your reply. By the way, I would be glad to create a PR to fix it if you are too busy.
