Hello! I've found a performance issue in train.py: `train_ds = train_ds.batch(batch_size)` (here) should be called before `train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())` (here), which would make your program more efficient. The TensorFlow documentation on vectorizing mapping supports this.
To apply the fix, swap the order of `train_ds = train_ds.batch(batch_size)` and `train_ds = train_ds.map(_map_fn_train, num_parallel_calls=multiprocessing.cpu_count())` in train.py. You also need to check whether the function `_map_fn_train` (here) passed to `train_ds.map()` is affected, so that the changed code still works properly. For example, if `_map_fn_train` expected input of shape (x, y, z) before the fix, it will receive input of shape (batch_size, x, y, z) afterwards.
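To illustrate, here is a minimal sketch of the reordering. The `_map_fn_train` below is a hypothetical stand-in (the real one in train.py may do more work); the point is that after the swap it receives whole batches, so it must be written with vectorized (batched) ops:

```python
import multiprocessing
import tensorflow as tf

# Hypothetical stand-in for _map_fn_train; the real function may differ.
# tf.cast and division are element-wise, so this works both per-example
# and per-batch without changes.
def _map_fn_train(x):
    return tf.cast(x, tf.float32) / 255.0

batch_size = 4
data = tf.zeros([8, 32, 32, 3], dtype=tf.uint8)  # toy image-like data

# Current order: map runs once per example of shape (32, 32, 3).
train_ds = tf.data.Dataset.from_tensor_slices(data)
train_ds = train_ds.map(_map_fn_train,
                        num_parallel_calls=multiprocessing.cpu_count())
train_ds = train_ds.batch(batch_size)

# Proposed order: map runs once per batch of shape (batch_size, 32, 32, 3),
# amortizing the per-call overhead of the mapped function.
train_ds = tf.data.Dataset.from_tensor_slices(data)
train_ds = train_ds.batch(batch_size)
train_ds = train_ds.map(_map_fn_train,
                        num_parallel_calls=multiprocessing.cpu_count())
```

Note that this only helps when `_map_fn_train` handles a leading batch dimension, which is why the shape check above matters.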
Looking forward to your reply. By the way, I'd be glad to create a PR to fix this if you are too busy.