alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0

[BUG] All-reduce incorrectly skipped #273

Closed ZYHowell closed 2 years ago

ZYHowell commented 2 years ago

Currently we skip an all-reduce if it was rewritten for gradient accumulation (call these grad-acc all-reduces). However, such an all-reduce can later be merged with an all-reduce that is not for grad-acc, and skipping the merged one produces incorrect outputs. We should identify grad-acc all-reduces and only allow them to merge with other grad-acc all-reduces.
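To make the intended merge policy concrete, here is a minimal Python sketch. The names (AllReduce, can_merge, merge_all_reduces) are hypothetical and only illustrate the idea; the actual fix operates on XLA HLO inside tensorflow-alpa.

    # Sketch: tag each all-reduce with whether it was rewritten for
    # gradient accumulation, and only combine all-reduces of the same kind,
    # so a skippable grad-acc all-reduce is never fused with one that must run.
    from dataclasses import dataclass

    @dataclass
    class AllReduce:
        operands: list
        is_grad_acc: bool  # True if rewritten for gradient accumulation

    def can_merge(a: AllReduce, b: AllReduce) -> bool:
        # Merging is allowed only when both ops are grad-acc or both are not.
        return a.is_grad_acc == b.is_grad_acc

    def merge_all_reduces(ops: list) -> list:
        merged = []
        for op in ops:
            if merged and can_merge(merged[-1], op):
                # Combine into the previous all-reduce of the same kind.
                merged[-1].operands += op.operands
            else:
                merged.append(AllReduce(list(op.operands), op.is_grad_acc))
        return merged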

A reproducible test case is:

class SkipAllReduceTest(PipelineBasicTest):

    def test_2_layer_bert(self):
        self.run_n_layer_bert(
            n_layers=2,
            batch_size=4,
            seq_len=4,
            hidden_size=4,
            num_heads=1,
            pipeline_stage_mode="manual_gpipe",
            forward_stage_layer_ids=[[0], [1]],
            overwrite_global_config_dict=dict(
                sub_physical_mesh_shapes=[(1, 2)] * 2,
                sub_logical_mesh_shapes=[(1, 2), (2, 1)],
                submesh_autosharding_global_configs=[
                    dict(force_batch_dim_to_mesh_dim=0)
                ] * 2,
                allow_all_gather=True,
                use_scatter_gather=False))
merrymercy commented 2 years ago

Fixed by https://github.com/alpa-projects/tensorflow-alpa/pull/129