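Hello! I've found a performance issue in your project: batch() should be called before map(). With this order, the mapped function runs once per batch instead of once per element, which reduces per-call overhead and can make the input pipeline more efficient. The TensorFlow documentation on vectorized mapping recommends this ordering.
A detailed description is listed below: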
/train_scannet_grid.py: train_data.batch(FLAGS.batch_size, drop_remainder=True) should be called before train_data.map(map_func=map_func, num_parallel_calls=dataset.num_threads).
/train_scannet_grid.py: val_data.batch(FLAGS.batch_size, drop_remainder=True) should be called before val_data.map(map_func=map_func, num_parallel_calls=dataset.num_threads).
/test_scannet_grid.py: val_data.batch(FLAGS.batch_size, drop_remainder=True) should be called before val_data.map(map_func=map_func, num_parallel_calls=dataset.num_threads).
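A minimal sketch of the proposed reordering, under stated assumptions: TensorFlow 2.4+ (for tf.data.AUTOTUNE), a toy dataset, and a hypothetical elementwise map_func_per_element. BATCH_SIZE stands in for FLAGS.batch_size and tf.data.AUTOTUNE stands in for dataset.num_threads; none of this is the project's actual map_func.

```python
import tensorflow as tf

BATCH_SIZE = 4  # hypothetical stand-in for FLAGS.batch_size

# Hypothetical per-element transform. It is purely elementwise, so it
# works the same whether or not the input has a leading batch dimension.
def map_func_per_element(sample):
    return tf.cast(sample, tf.float32) * 2.0

# Toy dataset: 10 samples, each of shape (x, y, z) = (2, 3, 5).
samples = tf.reshape(tf.range(10 * 2 * 3 * 5), (10, 2, 3, 5))
train_data = tf.data.Dataset.from_tensor_slices(samples)

# Current order: map() is called once per sample, then samples are batched.
before = (train_data
          .map(map_func_per_element, num_parallel_calls=tf.data.AUTOTUNE)
          .batch(BATCH_SIZE, drop_remainder=True))

# Proposed order: batch() first, so map() is called once per batch.
after = (train_data
         .batch(BATCH_SIZE, drop_remainder=True)
         .map(map_func_per_element, num_parallel_calls=tf.data.AUTOTUNE))
```

Both pipelines yield the same two batches of shape (4, 2, 3, 5) here; only the number of map_func invocations changes (10 per epoch versus 2).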
Besides, you need to check whether the function passed to map() (e.g., map_func in val_data.map(map_func=map_func, num_parallel_calls=dataset.num_threads)) is affected, so that the changed code still works properly. For example, if map_func takes input of shape (x, y, z) before the fix, it will receive input of shape (batch_size, x, y, z) after it.
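As a hedged illustration of the shape change, here is a sketch with a hypothetical map_func that is not purely elementwise: it subtracts each sample's mean over all of its axes. Once batch() comes first, the reduction has to skip axis 0, otherwise the mean would mix different samples. Function names and shapes are assumptions for illustration only.

```python
import tensorflow as tf

# Hypothetical map_func written for one sample of shape (x, y, z):
# subtract the sample's own mean, computed over all of its axes.
def map_func_single(sample):
    sample = tf.cast(sample, tf.float32)
    return sample - tf.reduce_mean(sample)

# Batched counterpart for input of shape (batch_size, x, y, z):
# reduce only over axes 1..3 so each sample keeps its own mean.
def map_func_batched(batch):
    batch = tf.cast(batch, tf.float32)
    return batch - tf.reduce_mean(batch, axis=[1, 2, 3], keepdims=True)

# Toy dataset: 8 samples of shape (2, 3, 4).
samples = tf.reshape(tf.range(8 * 2 * 3 * 4), (8, 2, 3, 4))
ds = tf.data.Dataset.from_tensor_slices(samples)

# Old order with the per-sample function vs. new order with the batched one.
per_element = ds.map(map_func_single).batch(4, drop_remainder=True)
vectorized = ds.batch(4, drop_remainder=True).map(map_func_batched)
```

Passing map_func_single unchanged to the vectorized pipeline would still run, but it would subtract one mean for the whole batch, which is why each map function should be reviewed rather than just reordered.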
Looking forward to your reply. By the way, I would be glad to open a PR to fix it if you are too busy.