volcengine / verl

veRL: Volcano Engine Reinforcement Learning for LLM
https://verl.readthedocs.io/en/latest/index.html
Apache License 2.0

Why is the `megatron_v4.patch` needed? #14

Open hxdtest opened 6 days ago

hxdtest commented 6 days ago

https://github.com/volcengine/verl/blob/main/patches/megatron_v4.patch

For example:

-    tensor_shape = [seq_length, micro_batch_size, config.hidden_size]
+    tensor_shape = [seq_length, micro_batch_size, hidden_size]

What is the difference between `hidden_size` and `config.hidden_size`?

Why do you need `next_forward_k` and `backward_k`?

- Also, what is the purpose of the change in case 3?

Many thanks !

PeterSH6 commented 5 days ago

Hi @hxdtest, the `megatron_v4.patch` is necessary for veRL for two main reasons:

  1. In veRL, we don't initialize Megatron-LM with `initialize_megatron`, which would also initialize the global args. We only build the necessary process groups using `mpu.initialize_model_parallel`. Therefore, we have to remove the usages of `get_args()`. Case 4 is where we delete `get_args()`; `overlap_param_gather` is set to `False` by default.
  2. We fix the vpp hanging problem that occurs when applying remove-padding techniques in model training. Case 2 fixes this.

For case 1, `config.hidden_size` is equal to `hidden_size`. The `False` in case 3 could be removed, since the default value is already `False` and there seems to be no way to change it in Megatron-LM v0.4.
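For reference, here is a minimal sketch of what building only the process groups (instead of calling `initialize_megatron`, which also populates the global args) could look like. The function name, parallel sizes, and environment handling below are illustrative assumptions, not veRL's actual code:

```python
# Minimal sketch: build only Megatron's parallel process groups without
# initialize_megatron(), so the global args (used by get_args()) are never
# populated. Argument names follow megatron.core conventions; the exact
# call sites in veRL may differ.
import os
import torch
from megatron.core import parallel_state as mpu


def init_parallel_groups(tp_size=2, pp_size=2, vpp_size=None):
    # torch.distributed must be initialized before creating model-parallel groups.
    torch.distributed.init_process_group(backend="nccl")
    torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

    # Build tensor/pipeline/virtual-pipeline parallel groups only.
    mpu.initialize_model_parallel(
        tensor_model_parallel_size=tp_size,
        pipeline_model_parallel_size=pp_size,
        virtual_pipeline_model_parallel_size=vpp_size,
    )
    # Any code path that still calls get_args() would fail here, since the
    # global args were never built -- which is why the patch removes those usages.
```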

hxdtest commented 3 days ago

Many thanks for your reply.

hxdtest commented 3 days ago

@PeterSH6
Have you tested veRL with models larger than 300B? For example, have you tested Llama 3 405B PPO training on veRL?