JF-D / Proteus


Fixed bs to comply with profiling #5

Open tareqmahmood opened 1 month ago

tareqmahmood commented 1 month ago

Confusion about n_macro_batch

Case 1

I was simulating pp = 2 (n_macro_batch = 1), mp = 1, dp = 2, and ran this command:

PYTHONPATH=/users/tareq/Proteus python megatron_gpt.py \
        -nlayer 12 \
        -hidden-size 768 \
        -nhead 12 \
        -seq-length 512 \
        -global-bs 16 \
        -n-macro-batch 1 \
        -cluster clusters/dgx1_v100_1ib/n1_g4.json \
        -ps pp \
        -pp-deg 2 \
        -mp-deg 1 \
        --reprofile \
        --profile-iters 10

I can see that you are profiling for ops like the output logging below:

profile: Embedding_0 ((8, 512), (40576, 768)) ((8, 512, 768),) cost: 0.157ms
profile: Embedding_2 ((8, 512), (512, 768)) ((8, 512, 768),) cost: 0.155ms
profile: Add_4 ((8, 512, 768), (8, 512, 768)) ((8, 512, 768),) cost: 0.095ms
profile: Dropout_6 ((8, 512, 768),) ((8, 512, 768),) cost: 0.084ms
profile: Permute_8 ((8, 512, 768),) ((512, 8, 768),) cost: 0.012ms
...

From the looks of it, profiling operations are happening for batch size = 8 (global-bs 16 / dp 2). This is understandable.

Case 2

However, when I simulate with n_macro_batch = 2, running this command:

PYTHONPATH=/users/tareq/Proteus python megatron_gpt.py \
        -nlayer 12 \
        -hidden-size 768 \
        -nhead 12 \
        -seq-length 512 \
        -global-bs 16 \
        -n-macro-batch 2 \           # <----- single change here
        -cluster clusters/dgx1_v100_1ib/n1_g4.json \
        -ps pp \
        -pp-deg 2 \
        -mp-deg 1 \
        --reprofile \
        --profile-iters 10

It is still profiling for batch size = 8. I am pasting the output log below:

profile: Embedding_0 ((8, 512), (40576, 768)) ((8, 512, 768),) cost: 0.143ms
profile: Embedding_2 ((8, 512), (512, 768)) ((8, 512, 768),) cost: 0.141ms
profile: Add_4 ((8, 512, 768), (8, 512, 768)) ((8, 512, 768),) cost: 0.095ms
profile: Dropout_6 ((8, 512, 768),) ((8, 512, 768),) cost: 0.084ms
profile: Permute_8 ((8, 512, 768),) ((512, 8, 768),) cost: 0.012ms
...

Question

Shouldn't the profiling happen for batch size = 4 in the latter case? This is also causing Proteus to estimate a runtime $2\times$ that of Megatron-LM.
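For reference, a small sketch of the expected arithmetic and of why profiling the wrong size would roughly double the estimate. The function name is illustrative (not a Proteus internal), and the linear cost-vs-batch-size scaling in the comment is only a rough assumption:

```python
def micro_batch_size(global_bs: int, dp_deg: int, n_macro_batch: int) -> int:
    """Samples per micro-batch on each data-parallel rank."""
    return global_bs // dp_deg // n_macro_batch

# Case 1: 16 / 2 / 1 = 8 -> matches the profiled shapes above
print(micro_batch_size(16, 2, 1))  # 8
# Case 2: 16 / 2 / 2 = 4 -> expected, but Proteus still profiles 8
print(micro_batch_size(16, 2, 2))  # 4

# If per-op cost scales roughly linearly with batch size, profiling every
# micro-batch at 8 instead of 4 doubles each per-op estimate while the
# number of micro-batches stays the same -> ~2x total runtime overestimate.
print(8 / 4)  # 2.0
```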

Possible Fix

Updating bs to account for n_macro_batch brought the simulated runtimes quite close to Megatron-LM's.
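A minimal sketch of the kind of one-line change the fix implies; the variable names are illustrative, not Proteus's actual code:

```python
# Values from the Case 2 command: -global-bs 16, dp = 2, -n-macro-batch 2
global_bs, dp_deg, n_macro_batch = 16, 2, 2

# Before (observed behavior): the profiler uses the full per-rank batch
bs_before = global_bs // dp_deg                   # -> 8

# After (proposed fix): also divide by the number of macro-batches,
# so each op is profiled at the true micro-batch size
bs_after = global_bs // dp_deg // n_macro_batch   # -> 4

print(bs_before, bs_after)  # 8 4
```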