hqucms / weaver-core

Streamlined neural network training.
MIT License
44 stars, 54 forks

Improve --data-split-num (minor tweak of #19) #20

Closed: hqucms closed this pull request 2 months ago

hqucms commented 2 months ago

Minor tweak of https://github.com/hqucms/weaver-core/pull/19.

When the input dataset is in Parquet format, the entire file must first be read into memory, and only then are the events from start_entry:stop_entry selected. As a result, even --fetch-step 0.01 still requires loading all datasets into memory.

A simple way to read Parquet datasets efficiently, without affecting the existing functionality, is to use --fetch-step 1 --data-split-num 100 (recently added to weaver) instead of --fetch-step 0.01. This PR therefore improves --data-split-num by handling the split separately for each input group and by allowing it to split more evenly.
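The memory argument above can be made concrete with a rough cost model. This is a simplified sketch, assuming equal-sized files and that every Parquet read loads the whole file; the function names are hypothetical and not part of weaver. Under these assumptions, --fetch-step 0.01 opens every file on every one of its 100 fetches, while --data-split-num only opens the files that overlap the current iteration's share of the data.

```python
import math
from fractions import Fraction


def reads_with_fetch_step(n_files: int, fetch_step: float) -> int:
    """Total file reads if each fetch reads every file in full (simplified assumption)."""
    n_fetches = math.ceil(round(1 / fetch_step, 9))
    return n_fetches * n_files


def reads_with_data_split(n_files: int, split_num: int) -> int:
    """Total file reads if iteration k only opens files overlapping its share.

    Assumes equal-sized files laid end to end, so iteration k covers
    [k * n_files / split_num, (k + 1) * n_files / split_num).
    """
    total = 0
    for k in range(split_num):
        lo = Fraction(k * n_files, split_num)
        hi = Fraction((k + 1) * n_files, split_num)
        total += math.ceil(hi) - math.floor(lo)  # number of files overlapping [lo, hi)
    return total
```

For 100 files, the first scheme performs 100 × 100 file reads, while the second reads each file once; in the N=5, split_num=3 case, files on a split boundary are read twice, giving 7 reads in total.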

For example, suppose the input datasets are separated into groups (process_1: filelist_1, process_2: filelist_2, …). With this PR:

  • --data-split-num splits the data separately within each group, so that the groups are mixed evenly in every iteration.
  • Within each group, --data-split-num handles the case where N(files) is not divisible by the split number. For example, with N=5, split_num=3, and load_range=(0, 0.1), the total range width is 5 × 0.1 = 0.5, so each iteration receives 0.5/3 ≈ 0.1667:
    iter1: load file 0, 1; load_ranges=[(0, 0.1), (0, 0.0667)]
    iter2: load file 1, 2, 3; load_ranges=[(0.0667, 0.1), (0, 0.1), (0, 0.0333)]
    iter3: load file 3, 4; load_ranges=[(0.0333, 0.1), (0, 0.1)]
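A minimal sketch of the per-group splitting described above, assuming that each file's selected range is laid end to end with the others and the concatenation is cut into split_num pieces of equal width. split_group and split_groups are hypothetical names, not weaver's implementation; Fraction is used only so that the boundaries in the N=5, split_num=3 example come out exact, whereas a real loader would more likely use floats.

```python
from fractions import Fraction
from typing import Dict, List, Tuple

Range = Tuple[Fraction, Fraction]
Split = List[Tuple[str, Range]]


def split_group(files: List[str], load_range, split_num: int) -> List[Split]:
    """Split one group's files into split_num chunks of equal total range width.

    Returned sub-ranges are expressed in each file's own load_range coordinates.
    """
    start, stop = map(Fraction, load_range)
    width = stop - start                      # range width contributed by each file
    chunk = width * len(files) / split_num    # range width per iteration
    splits: List[Split] = []
    for k in range(split_num):
        lo, hi = k * chunk, (k + 1) * chunk   # chunk k on the concatenated axis
        pieces: Split = []
        for i, name in enumerate(files):
            f_lo, f_hi = i * width, (i + 1) * width  # file i on the same axis
            a, b = max(lo, f_lo), min(hi, f_hi)
            if a < b:  # file i overlaps chunk k
                pieces.append((name, (start + a - f_lo, start + b - f_lo)))
        splits.append(pieces)
    return splits


def split_groups(groups: Dict[str, List[str]], load_range, split_num: int) -> List[Split]:
    """Split every group independently, then merge iteration k of all groups."""
    per_group = [split_group(files, load_range, split_num) for files in groups.values()]
    return [[piece for g in per_group for piece in g[k]] for k in range(split_num)]


splits = split_group([f"file{i}" for i in range(5)], (Fraction(0), Fraction(1, 10)), 3)
for k, pieces in enumerate(splits, 1):
    print(f"iter{k}:", [(n, float(a), float(b)) for n, (a, b) in pieces])
```

Running the usage lines reproduces the three iterations listed above (file0/file1, then file1/file2/file3, then file3/file4), with boundaries at 1/15 ≈ 0.0667 and 1/30 ≈ 0.0333.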

The initial implementation of --data-split-num (done inside the function _load_next) has been reverted.

colizz commented 2 months ago

Thanks a lot! A minor optimization has been implemented to improve performance by reordering the two for loops. Previously, n_div_d_sep_array was computed in the inner loop, which could be inefficient when the number of splits is large. ref: https://github.com/colizz/weaver-core/commit/7c8d9230630f72fb640e041ad59c0bb5c103305b
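The general idea of moving per-file work out of the inner loop can be illustrated with a hypothetical splitter: the outer loop runs over files, each file's boundaries are computed once, and the inner loop visits only the split indices whose chunk overlaps that file. The total work is then roughly O(N(files) + split_num) instead of O(N(files) × split_num). This is a sketch of the technique under the same equal-width assumption, not the code in colizz@7c8d923; it does not reproduce n_div_d_sep_array.

```python
import math
from fractions import Fraction
from typing import List, Tuple

Range = Tuple[Fraction, Fraction]


def split_group_by_file(files: List[str], load_range, split_num: int) -> List[List[Tuple[str, Range]]]:
    """Hypothetical splitter with the file loop outermost.

    Per-file boundaries are computed once per file, and only the splits that
    overlap the file are visited in the inner loop.
    """
    start, stop = map(Fraction, load_range)
    width = stop - start                      # assumed > 0
    chunk = width * len(files) / split_num    # range width per split
    splits: List[List[Tuple[str, Range]]] = [[] for _ in range(split_num)]
    for i, name in enumerate(files):
        f_lo, f_hi = i * width, (i + 1) * width  # computed once per file
        k_first = math.floor(f_lo / chunk)
        k_last = min(math.ceil(f_hi / chunk), split_num) - 1
        for k in range(k_first, k_last + 1):     # only overlapping splits
            a, b = max(k * chunk, f_lo), min((k + 1) * chunk, f_hi)
            if a < b:
                splits[k].append((name, (start + a - f_lo, start + b - f_lo)))
    return splits
```

Because files are visited in order, each split still lists its files in the same order as a split-major loop would produce.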

colizz commented 2 months ago

Here is a sanity check of this functionality (from the debug output):

Iter 0:
  - ./datasets/JetClassII/Pythia/Res34P_0350.parquet with load_range=(0.0, 0.86)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.0, 0.2)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.0, 0.28)
Iter 1:
  - ./datasets/JetClassII/Pythia/Res34P_0350.parquet with load_range=(0.86, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0605.parquet with load_range=(0.0, 0.72)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.2, 0.4)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.28, 0.56)
Iter 2:
  - ./datasets/JetClassII/Pythia/Res34P_0605.parquet with load_range=(0.72, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0055.parquet with load_range=(0.0, 0.58)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.4, 0.6)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.56, 0.84)
Iter 3:
  - ./datasets/JetClassII/Pythia/Res34P_0055.parquet with load_range=(0.58, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0845.parquet with load_range=(0.0, 0.44)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.6, 0.8)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.84, 1.0)
  - ./datasets/JetClassII/Pythia/QCD_0075.parquet with load_range=(0.0, 0.12)
Iter 4:
  - ./datasets/JetClassII/Pythia/Res34P_0845.parquet with load_range=(0.44, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0750.parquet with load_range=(0.0, 0.3)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.8, 1.0)
  - ./datasets/JetClassII/Pythia/QCD_0075.parquet with load_range=(0.12, 0.4)
Iter 5:
  - ./datasets/JetClassII/Pythia/Res34P_0750.parquet with load_range=(0.3, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0810.parquet with load_range=(0.0, 0.16)
  - ./datasets/JetClassII/Pythia/Res2P_0070.parquet with load_range=(0.0, 0.2)
  - ./datasets/JetClassII/Pythia/QCD_0075.parquet with load_range=(0.4, 0.68)
Iter 6:
  - ./datasets/JetClassII/Pythia/Res34P_0810.parquet with load_range=(0.16, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0025.parquet with load_range=(0.0, 0.02)
  - ./datasets/JetClassII/Pythia/Res2P_0070.parquet with load_range=(0.2, 0.4)
  - ./datasets/JetClassII/Pythia/QCD_0075.parquet with load_range=(0.68, 0.96)
  ......
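The invariants visible in this debug output can be checked mechanically. The sketch below parses only the Iter 0 to Iter 3 excerpt copied from above and assumes the process group is the file-name prefix before the trailing _NNNN.parquet. It sums each group's load_range widths per iteration and asserts that consecutive pieces of the same file are contiguous; parse and check_splits are hypothetical helpers, not part of weaver.

```python
import math
import os
import re
from collections import defaultdict

# Excerpt of the debug output above (Iter 0 to Iter 3).
LOG = """\
Iter 0:
  - ./datasets/JetClassII/Pythia/Res34P_0350.parquet with load_range=(0.0, 0.86)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.0, 0.2)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.0, 0.28)
Iter 1:
  - ./datasets/JetClassII/Pythia/Res34P_0350.parquet with load_range=(0.86, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0605.parquet with load_range=(0.0, 0.72)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.2, 0.4)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.28, 0.56)
Iter 2:
  - ./datasets/JetClassII/Pythia/Res34P_0605.parquet with load_range=(0.72, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0055.parquet with load_range=(0.0, 0.58)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.4, 0.6)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.56, 0.84)
Iter 3:
  - ./datasets/JetClassII/Pythia/Res34P_0055.parquet with load_range=(0.58, 1.0)
  - ./datasets/JetClassII/Pythia/Res34P_0845.parquet with load_range=(0.0, 0.44)
  - ./datasets/JetClassII/Pythia/Res2P_0065.parquet with load_range=(0.6, 0.8)
  - ./datasets/JetClassII/Pythia/QCD_0185.parquet with load_range=(0.84, 1.0)
  - ./datasets/JetClassII/Pythia/QCD_0075.parquet with load_range=(0.0, 0.12)
"""

ENTRY_RE = re.compile(r"- (\S+) with load_range=\(([\d.]+), ([\d.]+)\)")


def parse(log: str):
    """Split the log into iterations, each a list of (path, start, stop)."""
    iters = []
    for line in log.splitlines():
        m = ENTRY_RE.search(line)
        if line.startswith("Iter "):
            iters.append([])
        elif m:
            iters[-1].append((m.group(1), float(m.group(2)), float(m.group(3))))
    return iters


def check_splits(iters):
    """Return {group: [summed width per iteration]}; assert per-file contiguity."""
    widths = defaultdict(list)
    last_stop = {}  # path -> stop of that file's previous piece
    for entries in iters:
        per_group = defaultdict(float)
        for path, start, stop in entries:
            # Assumed naming: group is the prefix before "_NNNN.parquet".
            group = os.path.basename(path).rsplit("_", 1)[0]
            per_group[group] += stop - start
            if path in last_stop:
                assert math.isclose(last_stop[path], start, abs_tol=1e-9), path
            last_stop[path] = stop
        for group, width in per_group.items():
            widths[group].append(round(width, 6))
    return dict(widths)


widths = check_splits(parse(LOG))
```

In this excerpt each group receives the same share in every iteration (0.86 for Res34P, 0.2 for Res2P, 0.28 for QCD), and every file resumes exactly where its previous piece stopped.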
hqucms commented 2 months ago

Thanks a lot, @colizz! I cherry-picked the commit and pushed it to dev/custom_train_eval.