FlagOpen / FlagScale

FlagScale is a large model toolkit based on open-source projects.

[Hetero] Support to set recompute for each stage and micro-batch #205

Closed: heavyrain-lzy closed this 1 week ago

heavyrain-lzy commented 2 weeks ago

Adds support for setting recompute options separately for each pipeline stage and each micro-batch. For example:

recompute_granularity_per_stage_micro_batch:
  - [1, 254, 0, 2, 0]
  - [1, 254, 1, 2, 1]
  - [1, 254, 1, 2, 1]
recompute_method_per_stage_micro_batch:
  - [1, 254, 0, 2, 0]
  - [1, 254, 0, 2, 0]
  - [1, 254, 1, 2, 1]
recompute_num_layers_per_stage_micro_batch:
  - [1, 254, 2, 2, 2]
  - [1, 254, 1, 2, 1]
  - [1, 254, 2, 2, 2]
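One way to read these lists is as a run-length encoding: the first item of each sub-list counts pipeline stages, and the remaining items are (number of micro-batches, config-flag) pairs. The sketch below expands such a list into one flag per (stage, micro-batch) and checks the two sum constraints. It is a minimal illustration, not FlagScale's parser; the function name is hypothetical, and the assumption that a sub-list whose first item is greater than 1 applies the same flags to that many consecutive stages is ours (the example above only uses 1).

```python
from typing import List


def expand_per_stage_micro_batch(
    spec: List[List[int]], pipeline_size: int, num_micro_batches: int
) -> List[List[int]]:
    """Hypothetical helper: expand a compact per-stage/per-micro-batch spec
    into a table with `pipeline_size` rows of `num_micro_batches` flags."""
    table: List[List[int]] = []
    for sub in spec:
        num_stages, pairs = sub[0], sub[1:]
        if len(pairs) % 2 != 0:
            raise ValueError(f"expected (count, flag) pairs after the stage count: {sub}")
        flags: List[int] = []
        # pairs[0::2] are micro-batch counts, pairs[1::2] the matching config-flags.
        for count, flag in zip(pairs[0::2], pairs[1::2]):
            flags.extend([flag] * count)
        if len(flags) != num_micro_batches:
            raise ValueError(
                f"micro-batch counts in {sub} sum to {len(flags)}, expected {num_micro_batches}"
            )
        # Assumption: the same flags apply to each of the `num_stages` stages.
        table.extend([list(flags) for _ in range(num_stages)])
    if len(table) != pipeline_size:
        raise ValueError(f"stage counts sum to {len(table)}, expected {pipeline_size}")
    return table


granularity = [
    [1, 254, 0, 2, 0],
    [1, 254, 1, 2, 1],
    [1, 254, 1, 2, 1],
]
table = expand_per_stage_micro_batch(granularity, pipeline_size=3, num_micro_batches=256)
print(len(table), len(table[0]))   # 3 256
print(table[1][0], table[1][255])  # 1 1
```

Expanding to a dense table keeps per-micro-batch lookups trivial (`table[stage][micro_batch]`) at the cost of storing one integer per micro-batch per stage.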

The first items of the sub-lists must sum to the pipeline-parallel size; in this example, 1 + 1 + 1 = 3. After the first item, each sub-list is a sequence of (number of micro-batches, config-flag) pairs, and the micro-batch counts must sum to the number of micro-batches; in this example, 254 + 2 = 256. For recompute_granularity_per_stage_micro_batch, the config-flag can be:

For recompute_method_per_stage_micro_batch, the config-flag can be:

For recompute_num_layers_per_stage_micro_batch, the config-flag can be any integer from 0 to nums_layers and sets the number of layers to recompute.
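Because the flags in recompute_num_layers_per_stage_micro_batch are layer counts, each one has to fit within the layers held by the stage it applies to. The sketch below reports out-of-range flags. It is illustrative only: the function name and the `layers_per_stage` input (the number of layers on each pipeline stage) are hypothetical, and it assumes, as above, that a sub-list's first item counts consecutive stages.

```python
from typing import List, Sequence, Tuple


def invalid_num_layers_flags(
    spec: List[List[int]], layers_per_stage: Sequence[int]
) -> List[Tuple[int, int]]:
    """Hypothetical helper: return (stage, flag) for every config-flag that is
    not in the range 0..layers_per_stage[stage]."""
    bad: List[Tuple[int, int]] = []
    stage = 0
    for sub in spec:
        num_stages = sub[0]
        if stage + num_stages > len(layers_per_stage):
            raise ValueError(f"{sub} covers more stages than layers_per_stage lists")
        # Config-flags sit at indices 2, 4, ...: after the stage count and
        # after each micro-batch count.
        flags = sorted(set(sub[2::2]))
        for s in range(stage, stage + num_stages):
            for flag in flags:
                if not 0 <= flag <= layers_per_stage[s]:
                    bad.append((s, flag))
        stage += num_stages
    return bad


num_layers_spec = [[1, 254, 2, 2, 2], [1, 254, 1, 2, 1], [1, 254, 2, 2, 2]]
print(invalid_num_layers_flags(num_layers_spec, layers_per_stage=[8, 8, 8]))  # []
print(invalid_num_layers_flags(num_layers_spec, layers_per_stage=[8, 8, 1]))  # [(2, 2)]
```

Returning the offending (stage, flag) pairs rather than raising on the first one lets a config check report every problem at once.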