AI-Hypercomputer / maxtext

A simple, performant and scalable Jax LLM!
Apache License 2.0

PGLE doesn't work for Tensor Parallelism #1005

Open wang2yn84 opened 3 weeks ago

wang2yn84 commented 3 weeks ago

We observed good compute/communication overlap with FSDP + PGLE (screenshot: Bq7PCuqyJbygSuL). Toggling PGLE on and off makes a big difference here.

However, with TP + PGLE (screenshot: 7nGeZQwG5Un84P3), there is no performance improvement. Computation and communication do not overlap at all; both are completely exposed.
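As a rough mental model of what "exposed" costs, consider a single region of work. This is a simplification that ignores contention between overlapped kernels. PGLE supplies measured latencies to XLA's latency hiding scheduler, and the scheduler can only recover the difference between these two cases:

```latex
T_{\text{exposed}} \approx T_{\text{compute}} + T_{\text{comm}}, \qquad
T_{\text{overlapped}} \approx \max\left(T_{\text{compute}},\, T_{\text{comm}}\right)
```

The best-case saving per region is therefore T_exposed - T_overlapped = min(T_compute, T_comm). If that saving is near zero under TP, as the profile above suggests, the scheduler is not finding independent compute to place alongside the collectives.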

Here is the command (run it on the lance-405b-clean branch):

python3 MaxText/train.py MaxText/configs/models/gpu/llama3.1_405b.yml hardware=gpu run_name=maxtext-llama3.1-405b steps=10 max_target_length=4096 model_name=llama3.1-405b enable_checkpointing=false attention=cudnn_flash_te dataset_type=synthetic async_checkpointing=false base_output_directory=gs://lancewang-dev-supercomputer-testing/maxtext_gpu logits_dot_in_fp32=false use_iota_embed=true ici_tensor_parallelism=8 dcn_fsdp_parallelism=32 dcn_pipeline_parallelism=1 per_device_batch_size=1 num_layers_per_pipeline_stage=16 weight_dtype=bfloat16 remat_policy=save_qkv_proj profiler=xplane skip_first_n_steps_for_profiler=5 base_num_decoder_layers=126
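For scale, the sketch below computes the device count and tokens per step implied by the overrides above. It is a minimal sketch built on three assumptions: 8 GPUs per host, any ICI/DCN parallelism axis not set on the command line resolves to 1, and the global batch equals per_device_batch_size times the device count. The helper names are hypothetical and are not part of MaxText.

```python
import math

# Scale-relevant overrides copied from the command above.
COMMAND_OVERRIDES = (
    "ici_tensor_parallelism=8 dcn_fsdp_parallelism=32 "
    "dcn_pipeline_parallelism=1 per_device_batch_size=1 max_target_length=4096"
)

GPUS_PER_NODE = 8  # assumption: 8 GPUs per host (hypothetical constant)


def parse_overrides(text: str) -> dict:
    """Split MaxText-style key=value overrides into a {key: value} dict."""
    return dict(token.split("=", 1) for token in text.split())


def axis_product(overrides: dict, prefix: str) -> int:
    """Multiply every <prefix>*_parallelism value; unset axes count as 1."""
    return math.prod(
        int(value)
        for key, value in overrides.items()
        if key.startswith(prefix) and key.endswith("_parallelism")
    )


def run_scale(overrides: dict) -> dict:
    """Derive node/device counts and tokens per step (hypothetical helper)."""
    ici = axis_product(overrides, "ici_")  # devices within one node
    dcn = axis_product(overrides, "dcn_")  # number of nodes
    assert ici == GPUS_PER_NODE, "ICI axes should cover exactly one node"
    devices = ici * dcn
    global_batch = int(overrides["per_device_batch_size"]) * devices
    tokens = global_batch * int(overrides["max_target_length"])
    return {"nodes": dcn, "devices": devices,
            "global_batch": global_batch, "tokens_per_step": tokens}


if __name__ == "__main__":
    print(run_scale(parse_overrides(COMMAND_OVERRIDES)))
```

Under these assumptions the run uses 32 nodes and 256 GPUs, with a global batch of 256 sequences, which is 1,048,576 tokens per step at max_target_length=4096. Note that MaxText also allows -1 (auto-fill) for some axes. The sketch does not handle that case.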

Here are the XLA flags: --xla_gpu_graph_level=0 --xla_gpu_enable_latency_hiding_scheduler=true --xla_gpu_enable_triton_gemm=false --xla_gpu_enable_highest_priority_async_stream=true --xla_gpu_all_reduce_combine_threshold_bytes=536870912 --xla_gpu_all_gather_combine_threshold_bytes=536870912 --xla_gpu_reduce_scatter_combine_threshold_bytes=536870912 --xla_gpu_enable_pipelined_all_gather=true --xla_gpu_enable_pipelined_reduce_scatter=true --xla_gpu_enable_pipelined_all_reduce=true --xla_gpu_enable_while_loop_double_buffering=true --xla_disable_hlo_passes=rematerialization --xla_gpu_enable_pgle_accuracy_checker=false --xla_gpu_enable_triton_softmax_fusion=false --xla_gpu_enable_all_gather_combine_by_dim=false --xla_gpu_enable_reduce_scatter_combine_by_dim=false

Here are the environment variables: NCCL_SHIMNET_GUEST_CONFIG_CHECKER_CONFIG_FILE=/usr/local/nvidia/lib64/a3plus_guest_config.textproto NCCL_FASTRAK_PLUGIN_ACCEPT_TIMEOUT_MS=600000 JAX_ENABLE_PGLE=true JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY=true JAX_DEBUG_LOG_MODULES=compiler
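If you would rather launch from Python than from a shell, you can apply the same flags and variables through os.environ. This is a minimal sketch that assumes jax has not been imported yet. XLA reads XLA_FLAGS when the backend initializes, and JAX reads its JAX_* variables when its config loads, so setting them later may have no effect. Only a subset of the flags and variables above is shown, and build_env is a hypothetical helper.

```python
import os

# Subset of the XLA flags from this issue; extend with the full list as needed.
XLA_FLAGS = [
    "--xla_gpu_enable_latency_hiding_scheduler=true",
    "--xla_gpu_all_reduce_combine_threshold_bytes=536870912",
    "--xla_gpu_all_gather_combine_threshold_bytes=536870912",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=536870912",
    "--xla_gpu_enable_pgle_accuracy_checker=false",
]

# Subset of the environment variables from this issue.
JAX_ENV = {
    "JAX_ENABLE_PGLE": "true",
    "JAX_REMOVE_CUSTOM_PARTITIONING_PTR_FROM_CACHE_KEY": "true",
    "JAX_DEBUG_LOG_MODULES": "compiler",
}


def build_env(xla_flags, extra_env, base=None):
    """Return a copy of `base` (default: os.environ) with the XLA flags
    appended to any existing XLA_FLAGS and the extra variables set."""
    env = dict(os.environ if base is None else base)
    existing = env.get("XLA_FLAGS", "").split()
    env["XLA_FLAGS"] = " ".join(existing + list(xla_flags))
    env.update(extra_env)
    return env


if __name__ == "__main__":
    os.environ.update(build_env(XLA_FLAGS, JAX_ENV))
    # Import jax only after the environment is configured, e.g.:
    # import jax
```

The helper appends to an existing XLA_FLAGS value rather than overwriting it. This way, flags already exported by a launcher script are kept.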

We used the image built on Oct 22nd.

reedwm commented 3 weeks ago

@Tixxx do you know what the issue is? I'm still trying to reproduce it myself.

Tixxx commented 2 weeks ago

I cannot access the screenshots above; the links return "page not found". As a preliminary guess, the combiner thresholds might introduce more data dependencies. We usually tune them down when a combined collective has many data dependencies.
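One way to test the suggestion to tune the combiner thresholds down is a sweep. Start the three *_combine_threshold_bytes flags at the issue's value of 536870912 bytes (512 MiB), lower them step by step, and profile each run. This is a minimal sketch: the sweep sizes are arbitrary examples, not recommended values, and with_combine_threshold is a hypothetical helper.

```python
# The three combine-threshold flags used in this issue.
COMBINE_FLAGS = (
    "--xla_gpu_all_reduce_combine_threshold_bytes",
    "--xla_gpu_all_gather_combine_threshold_bytes",
    "--xla_gpu_reduce_scatter_combine_threshold_bytes",
)

# Shortened copy of the issue's XLA flags (other flags omitted for brevity).
ORIGINAL_FLAGS = (
    "--xla_gpu_enable_latency_hiding_scheduler=true "
    "--xla_gpu_all_reduce_combine_threshold_bytes=536870912 "
    "--xla_gpu_all_gather_combine_threshold_bytes=536870912 "
    "--xla_gpu_reduce_scatter_combine_threshold_bytes=536870912"
)

MIB = 1024 * 1024


def with_combine_threshold(flags: str, threshold_bytes: int) -> str:
    """Return `flags` with every combine-threshold flag set to threshold_bytes.

    Only flags already present are rewritten; all other flags are kept
    unchanged and in their original order.
    """
    out = []
    for flag in flags.split():
        name = flag.split("=", 1)[0]
        if name in COMBINE_FLAGS:
            out.append(f"{name}={threshold_bytes}")
        else:
            out.append(flag)
    return " ".join(out)


if __name__ == "__main__":
    assert 536870912 == 512 * MIB  # the issue's setting is exactly 512 MiB
    for mib in (512, 128, 32, 8):  # arbitrary example sweep points
        print(f"{mib} MiB:", with_combine_threshold(ORIGINAL_FLAGS, mib * MIB))
```

Each printed string can be exported as XLA_FLAGS for a separate profiled run. You can then compare how much communication stays exposed at each threshold.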

I tried reproducing this with your command on MaxText main, but the YAML config file doesn't exist there. Could you share a smaller model that can easily be reproduced on a single node? Thanks

wang2yn84 commented 2 weeks ago

Syncing in the chat.

(Email reply to TJ Xu's comment of Wed, Nov 6, 2024 at 4:46 PM.)


-- Cheers, Lance