alpa-projects / alpa

Training and serving large-scale neural networks with auto parallelization.
https://alpa.ai
Apache License 2.0
3.08k stars, 357 forks

PipeshardParallel + GPT2 example fails with compile error and segmentation fault #863

Closed: jaywonchung closed this issue 1 year ago.

jaywonchung commented 1 year ago

Please describe the bug

I'm trying to use PipeshardParallel for the GPT2 example in examples/gpt2 (20debbe5f0ed4047d82ae615cb2c07b059498032) with Alpa v0.2.2 inside a Docker container. I'm on an RHEL node with four NVIDIA A40 GPUs.

Please describe the expected behavior

System information and environment

To Reproduce

Steps to reproduce the behavior:

  1. Build the Docker image with docker/coreweave/run_alpa_infiniband.Dockerfile. All following commands are run inside the container.
  2. git clone --recursive https://github.com/alpa-projects/alpa.git
  3. cd alpa/examples/gpt2
  4. Edit run_clm_flax.py so that it uses PipeshardParallel instead of Zero2Parallel:
     ```python
     method = alpa.PipeshardParallel(
         devices=None,
         num_micro_batches=training_args.num_micro_batches,  # 16 in this case
         default_auto_sharding_option=None,
         pipeline_schedule="1f1b",
         layer_option=None,
         stage_option="auto",
         stage_input_shardings=None,
     )
     ```
  5. pip install transformers datasets (transformers 4.25.1, datasets 2.8.0)
  6. export LD_LIBRARY_PATH=$LD_LIBRARY_PATH:$CONDA_PREFIX/lib/
  7. pip install tensorflow
  8. mkdir norwegian-gpt2 && python train_tokenizer.py && python create_config.py
  9. Run the training script:
     ```shell
     python3 run_clm_flax.py \
         --output_dir="./norwegian-gpt2" \
         --model_type="gpt2" \
         --config_name="./norwegian-gpt2" \
         --tokenizer_name="./norwegian-gpt2" \
         --dataset_name="oscar" \
         --dataset_config_name="unshuffled_deduplicated_no" \
         --do_train \
         --block_size="512" \
         --per_device_train_batch_size="32" \
         --num_micro_batches="16" \
         --dtype="float16" \
         --learning_rate="1e-3" --warmup_steps="1000" \
         --adam_beta1="0.9" --adam_beta2="0.98" --weight_decay="0.01" \
         --overwrite_output_dir \
         --num_train_epochs="20" \
         --logging_steps="20" \
         --save_steps="2500" \
         --eval_steps="2500"
     ```
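The batch figures that the training log reports (global batch 128, 307180 optimization steps) can be reproduced from these arguments, and they also account for the tensor shape in the failed check. The sketch below assumes that run_clm_flax.py multiplies the per-device batch size by the number of GPUs (four here) and counts steps per epoch with integer division; fused_qkv_pad is a hypothetical helper, and the hidden size of 768 assumes the GPT-2 small configuration. Under those assumptions the micro-batch size is 128 / 16 = 8, which together with --block_size=512 matches the leading [8,512,...] dimensions of %pad.38. A low-side padding of 1536 on a 768-wide operand is also what placing the third slice (plausibly the value projection) into a fused 2304-wide QKV tensor would look like.

```python
# Sketch: reproduce the batch figures from the training log.
# Assumptions (not taken from Alpa or run_clm_flax.py): global batch =
# per-device batch * number of GPUs, integer division for steps per epoch,
# and a GPT-2 small hidden size of 768.


def batch_figures(num_examples, num_epochs, per_device_batch,
                  num_devices, num_micro_batches):
    """Return (global_batch, micro_batch, total_steps)."""
    global_batch = per_device_batch * num_devices
    if global_batch % num_micro_batches:
        raise ValueError("global batch must be divisible by num_micro_batches")
    micro_batch = global_batch // num_micro_batches
    steps_per_epoch = num_examples // global_batch  # partial last batch dropped
    return global_batch, micro_batch, steps_per_epoch * num_epochs


def fused_qkv_pad(hidden_size, slice_index):
    """Hypothetical helper: width of a fused QKV tensor and the (low, high)
    padding that places one hidden_size-wide slice at position slice_index."""
    fused_width = 3 * hidden_size
    low = slice_index * hidden_size
    high = fused_width - low - hidden_size
    return fused_width, low, high


if __name__ == "__main__":
    print(batch_figures(1966029, 20, 32, 4, 16))  # (128, 8, 307180)
    # %pad.38: f16[8,512,768] -> f16[8,512,2304], last-dim padding 1536_0
    print(fused_qkv_pad(768, 2))  # (2304, 1536, 0)
```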
Full error output (tqdm disabled)

```
INFO:__main__:***** Running training *****
INFO:__main__: Num examples = 1966029
INFO:__main__: Num Epochs = 20
INFO:__main__: Batch size per device (w. accumulation) = 32
INFO:__main__: Global train batch size (w. parallel & distributed) = 128
INFO:__main__: Total optimization steps = 307180
Initial compilation. This might take some minutes...
-------------------- Automatic stage clustering --------------------
submesh_choices: ((1, 1), (1, 2), (1, 4))
- Profiling for submesh 2 (1, 4):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
(CompileWorker pid=73427) 2023-01-19 21:55:05.286251: F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding.cc:1465] Check failed: strategies->is_tuple || !strategies->leaf_vector.empty() %pad.38 = f16[8,512,2304]{2,1,0} pad(f16[8,512,768]{2,1,0} %reshape.1367, f16[] %constant.1168), padding=0_0x0_0x1536_0, metadata={op_name="parallelize(stage_0_1)/jit(main)/jit(merged)/jit(stage_0_1_compute2)/transpose(jvp(FlaxGPT2LMHeadModule))/transformer/h/11/attn/pad[padding_config=((0, 0, 0), (0, 0, 0), (1536, 0, 0))]" source_file="/opt/conda/envs/alpa/lib/python3.8/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py" source_line=211} does not have any valid strategies.
(CompileWorker pid=73427) *** SIGABRT received at time=1674165305 on cpu 46 ***
(CompileWorker pid=73427) PC: @ 0x7f698dbd000b (unknown) raise
(CompileWorker pid=73427) @ 0x7f698deed420 537164224 (unknown)
(CompileWorker pid=73427) @ 0x7f42c9fb207d 10592 xla::spmd::BuildStrategyAndCost()
(CompileWorker pid=73427) @ 0x7f42cb6ce3b4 2368 xla::spmd::AutoSharding::Run()
(CompileWorker pid=73427) @ 0x7f42cdf7f371 816 xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=73427) @ 0x7f42cdf7ffc5 448 xla::HloPassPipeline::Run()
(CompileWorker pid=73427) @ 0x7f42ca52cf24 80 xla::HloPassInterface::Run()
(CompileWorker pid=73427) @ 0x7f42ca536391 4128 xla::spmd::RunAutoShardingPass()
(CompileWorker pid=73427) @ 0x7f42ca52215a 160 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=73427) @ 0x7f42ca392730 576 pybind11::cpp_function::dispatcher()
(CompileWorker pid=73427) @ 0x4e1172 (unknown) PyCFunction_Call
(CompileWorker pid=73427) @ 0x71a560 (unknown) (unknown)
(CompileWorker pid=73427) [2023-01-19 21:55:05,324 E 73427 73427] logging.cc:361: *** SIGABRT received at time=1674165305 on cpu 46 ***
(CompileWorker pid=73427) [2023-01-19 21:55:05,324 E 73427 73427] logging.cc:361: PC: @ 0x7f698dbd000b (unknown) raise
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f698deed420 537164224 (unknown)
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42c9fb207d 10592 xla::spmd::BuildStrategyAndCost()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42cb6ce3b4 2368 xla::spmd::AutoSharding::Run()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42cdf7f371 816 xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42cdf7ffc5 448 xla::HloPassPipeline::Run()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca52cf24 80 xla::HloPassInterface::Run()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca536391 4128 xla::spmd::RunAutoShardingPass()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca52215a 160 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x7f42ca392730 576 pybind11::cpp_function::dispatcher()
(CompileWorker pid=73427) [2023-01-19 21:55:05,325 E 73427 73427] logging.cc:361: @ 0x4e1172 (unknown) PyCFunction_Call
(CompileWorker pid=73427) [2023-01-19 21:55:05,326 E 73427 73427] logging.cc:361: @ 0x71a560 (unknown) (unknown)
(CompileWorker pid=73427) Fatal Python error: Aborted
(CompileWorker pid=73427)
(CompileWorker pid=73427) Stack (most recent call first):
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 344 in run_auto_sharding_pass
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/pipeline_parallel/stage_profiling.py", line 161 in compile_stage_for_profiling
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 466 in _resume_span
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/function_manager.py", line 674 in actor_method_executor
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/worker.py", line 763 in main_loop
(CompileWorker pid=73427) File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/workers/default_worker.py", line 231 in <module>
2023-01-19 21:55:09,285 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffbde8ec2d3c9b71076befcba108000000 Worker ID: db6d86ae01244973e089c907ce3105bc69cd346479e7764b48feb453 Node ID: 207d08a537af2d27e0cc709647ab66e89e18b2763be4c8b5028126a3 Worker IP address: REDACTED_IP_ADDRESS Worker port: 10300 Worker PID: 73427 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
WARNING:alpa.pipeline_parallel.stage_profiling:A Compile worker died unexpectedly: The actor died unexpectedly before finishing this task. class_name: CompileWorker actor_id: bde8ec2d3c9b71076befcba108000000 pid: 73427 namespace: alpa_default_space ip: REDACTED_IP_ADDRESS The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
- Profile all stages
cost[0, 1, 0]=0.036, max_n_succ_stage=4096, Mem: avail=39.475GB, peak=2.165GB, intermediate=0.000GB, init=0.348GB, as_config=((4, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.082, max_n_succ_stage=4096, Mem: avail=39.475GB, peak=2.491GB, intermediate=0.000GB, init=0.348GB, as_config=((1, 4), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 3]=0.033, max_n_succ_stage=4096, Mem: avail=39.475GB, peak=1.991GB, intermediate=0.000GB, init=0.348GB, as_config=((4, 1), {})
Profiling for submesh 2 (1, 4) takes 44.25 seconds
Profiled costs are: [[[ inf inf inf inf]
  [0.03560379 inf 0.08236101 0.03326474]]
 [[ inf inf inf inf]
  [ inf inf inf inf]]]
Profiled max_n_succ_stages are: [[[ -1 -1 -1 -1]
  [4096 -1 4096 4096]]
 [[ -1 -1 -1 -1]
  [ -1 -1 -1 -1]]]
--------------------------------------------------
- Profiling for submesh 1 (1, 2):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 0]=0.024, max_n_succ_stage=27, Mem: avail=39.475GB, peak=1.626GB, intermediate=1.295GB, init=0.494GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 1]=0.029, max_n_succ_stage=23, Mem: avail=39.475GB, peak=1.708GB, intermediate=1.502GB, init=0.494GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 0, 2]=0.023, max_n_succ_stage=27, Mem: avail=39.475GB, peak=1.520GB, intermediate=1.295GB, init=0.494GB, as_config=((2, 1), {})
cost[0, 1, 0]=0.056, max_n_succ_stage=10, Mem: avail=39.475GB, peak=3.690GB, intermediate=3.061GB, init=0.695GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.066, max_n_succ_stage=9, Mem: avail=39.475GB, peak=3.931GB, intermediate=3.460GB, init=0.696GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 0]=0.032, max_n_succ_stage=19, Mem: avail=39.475GB, peak=2.239GB, intermediate=1.766GB, init=0.489GB, as_config=((2, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 2]=0.054, max_n_succ_stage=10, Mem: avail=39.475GB, peak=3.527GB, intermediate=3.025GB, init=0.695GB, as_config=((2, 1), {})
cost[1, 1, 1]=0.037, max_n_succ_stage=17, Mem: avail=39.475GB, peak=2.306GB, intermediate=1.958GB, init=0.490GB, as_config=((1, 2), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 2]=0.031, max_n_succ_stage=20, Mem: avail=39.475GB, peak=2.081GB, intermediate=1.730GB, init=0.489GB, as_config=((2, 1), {})
Profiling for submesh 1 (1, 2) takes 51.14 seconds
Profiled costs are: [[[0.02386636 0.02937219 0.0229995 inf]
  [0.05573034 0.06621422 0.05409217 inf]]
 [[ inf inf inf inf]
  [0.03158776 0.03738411 0.03124457 inf]]]
Profiled max_n_succ_stages are: [[[27 23 27 -1]
  [10 9 10 -1]]
 [[-1 -1 -1 -1]
  [19 17 20 -1]]]
--------------------------------------------------
- Profiling for submesh 0 (1, 1):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
cost[0, 0, 1]=0.040, max_n_succ_stage=13, Mem: avail=39.475GB, peak=2.900GB, intermediate=2.511GB, init=0.987GB, as_config=((1, 1), {})
cost[0, 0, 0]=0.040, max_n_succ_stage=13, Mem: avail=39.475GB, peak=2.900GB, intermediate=2.511GB, init=0.987GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 0]=0.056, max_n_succ_stage=9, Mem: avail=39.475GB, peak=4.118GB, intermediate=3.381GB, init=0.979GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[1, 1, 1]=0.056, max_n_succ_stage=9, Mem: avail=39.475GB, peak=4.118GB, intermediate=3.381GB, init=0.979GB, as_config=((1, 1), {})
cost[0, 1, 0]=0.095, max_n_succ_stage=4, Mem: avail=39.475GB, peak=6.835GB, intermediate=5.892GB, init=1.391GB, as_config=((1, 1), {'force_batch_dim_to_mesh_dim': 0})
cost[0, 1, 1]=0.095, max_n_succ_stage=4, Mem: avail=39.475GB, peak=6.835GB, intermediate=5.892GB, init=1.391GB, as_config=((1, 1), {})
Profiling for submesh 0 (1, 1) takes 27.48 seconds
Profiled costs are: [[[0.03979082 0.03967738 inf inf]
  [0.09511036 0.0951425 inf inf]]
 [[ inf inf inf inf]
  [0.05555977 0.05556859 inf inf]]]
Profiled max_n_succ_stages are: [[[13 13 -1 -1]
  [ 4 4 -1 -1]]
 [[-1 -1 -1 -1]
  [ 9 9 -1 -1]]]
--------------------------------------------------
Compute cost saved to: compute-cost-2023-01-19-21-57-01.npy
----------------------------------------------------------------------
Result forward_stage_layer_ids: [[0], [1]]
Result mesh_shapes: [(1, 2), (1, 2)]
Result logical_mesh_shapes: [(2, 1), (2, 1)]
Result autosharding_option_dicts: [{}, {}]
2023-01-19 21:57:17,350 ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=79095, ip=REDACTED_IP_ADDRESS, repr=)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/device_mesh.py", line 411, in create_and_set_cross_mesh_communicators
    comms = g.get_nccl_collective_communicator(devices, "xla")
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 478, in get_nccl_collective_communicator
    return self._get_nccl_collective_communicator(key, devices, lib)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 455, in _get_nccl_collective_communicator
    comms = xla_extension.nccl_create_communicators_no_stream(
AttributeError: module 'jaxlib.xla_extension' has no attribute 'nccl_create_communicators_no_stream'
2023-01-19 21:57:17,528 ERROR worker.py:400 -- Unhandled error (suppress with 'RAY_IGNORE_UNHANDLED_ERRORS=1'): ray::MeshHostWorker.create_and_set_cross_mesh_communicators() (pid=79094, ip=REDACTED_IP_ADDRESS, repr=)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/device_mesh.py", line 411, in create_and_set_cross_mesh_communicators
    comms = g.get_nccl_collective_communicator(devices, "xla")
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 478, in get_nccl_collective_communicator
    return self._get_nccl_collective_communicator(key, devices, lib)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/collective/collective_group/nccl_collective_group.py", line 455, in _get_nccl_collective_communicator
    comms = xla_extension.nccl_create_communicators_no_stream(
AttributeError: module 'jaxlib.xla_extension' has no attribute 'nccl_create_communicators_no_stream'
(MeshHostWorker pid=79094) [1674165445.753032] [REDACTED_HOST_NAME:79094:1] debug.c:1289 UCX WARN ucs_debug_disable_signal: signal 8 was not set in ucs
(MeshHostWorker pid=79094) [1674165445.753032] [REDACTED_HOST_NAME:79094:0] spinlock.c:29 UCX WARN ucs_recursive_spinlock_destroy() failed: busy
(MeshHostWorker pid=79095) [REDACTED_HOST_NAME:79095:0:79370] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x8)
(MeshHostWorker pid=79095) [REDACTED_HOST_NAME:79095:1:79368] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
(MeshHostWorker pid=79094) [REDACTED_HOST_NAME:79094:1:79400] Caught signal 11 (Segmentation fault: address not mapped to object at address 0x8)
(MeshHostWorker pid=79094) [REDACTED_HOST_NAME:79094:0:79394] Caught signal 11 (Segmentation fault: address not mapped to object at address (nil))
(MeshHostWorker pid=79095) [1674165445.755438] [REDACTED_HOST_NAME:79095:0] debug.c:1289 UCX WARN ucs_debug_disable_signal: signal 11 was not set in ucs
(MeshHostWorker pid=79095) [1674165445.755440] [REDACTED_HOST_NAME:79095:1] spinlock.c:29 UCX WARN ucs_recursive_spinlock_destroy() failed: busy
(MeshHostWorker pid=79095) ==== backtrace (tid: 79368) ====
(MeshHostWorker pid=79095) 0 0x0000000000014420 __funlockfile() ???:0
(MeshHostWorker pid=79095) 1 0x00000000021c6ca8 xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream() :0
(MeshHostWorker pid=79095) 2 0x000000000215d53e xla::gpu::(anonymous namespace)::ExecuteThunks() gpu_executable.cc:0
(MeshHostWorker pid=79095) 3 0x000000000215ed40 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime() :0
(MeshHostWorker pid=79095) 4 0x00000000021635f8 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl() :0
(MeshHostWorker pid=79095) 5 0x000000000216425f xla::gpu::GpuExecutable::ExecuteAsyncOnStream() :0
(MeshHostWorker pid=79095) 6 0x000000000454c306 xla::Executable::ExecuteAsyncOnStreamWrapper() :0
(MeshHostWorker pid=79095) 7 0x00000000013183d0 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79095) 8 0x0000000001318b40 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79095) 9 0x00000000012df9ea xla::PjRtStreamExecutorExecutable::EnqueueExecution() :0
(MeshHostWorker pid=79095) 10 0x00000000012e0e21 xla::PjRtStreamExecutorExecutable::ExecuteHelper() :0
(MeshHostWorker pid=79095) 11 0x00000000012e3249 std::_Function_handler > const>, xla::ExecuteOptions const&, std::optional, std::allocator > > >&)::{lambda()#2}>::_M_invoke() pjrt_stream_executor_client.cc:0
(MeshHostWorker pid=79095) 12 0x00000000012ef468 xla::WorkerThread::WorkLoop() :0
(MeshHostWorker pid=79095) 13 0x00000000056a7005 tsl::(anonymous namespace)::PThread::ThreadFn() env.cc:0
(MeshHostWorker pid=79095) 14 0x0000000000008609 start_thread() ???:0
(MeshHostWorker pid=79095) 15 0x000000000011f133 clone() ???:0
(MeshHostWorker pid=79095) =================================
(MeshHostWorker pid=79095) *** SIGSEGV received at time=1674165446 on cpu 110 ***
(MeshHostWorker pid=79095) PC: @ 0x7ef35fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79095) @ 0x7f1a218b6420 3728 (unknown)
(MeshHostWorker pid=79095) @ 0x7ef35fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79095) @ 0x7ef35fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79095) @ 0x7ef35fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79095) @ 0x7ef35fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79095) @ 0x7ef361f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79095) @ 0x7ef35ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) @ 0x7ef35ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) @ 0x7ef35ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79095) @ 0x7ef35ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79095) @ 0x7ef35ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79095) @ 0x7ef35ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79095) @ 0x7ef3630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79095) @ 0x7f1a218aa609 (unknown) start_thread
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: *** SIGSEGV received at time=1674165446 on cpu 110 ***
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: PC: @ 0x7ef35fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7f1a218b6420 3728 (unknown)
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef361f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef35ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7ef3630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79095) [2023-01-19 21:57:26,478 E 79095 79368] logging.cc:361: @ 0x7f1a218aa609 (unknown) start_thread
(MeshHostWorker pid=79095) Fatal Python error: Segmentation fault
(MeshHostWorker pid=79095)
(MeshHostWorker pid=79094) ==== backtrace (tid: 79394) ====
(MeshHostWorker pid=79094) 0 0x0000000000014420 __funlockfile() ???:0
(MeshHostWorker pid=79094) 1 0x00000000021c6ca8 xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream() :0
(MeshHostWorker pid=79094) 2 0x000000000215d53e xla::gpu::(anonymous namespace)::ExecuteThunks() gpu_executable.cc:0
(MeshHostWorker pid=79094) 3 0x000000000215ed40 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime() :0
(MeshHostWorker pid=79094) 4 0x00000000021635f8 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl() :0
(MeshHostWorker pid=79094) 5 0x000000000216425f xla::gpu::GpuExecutable::ExecuteAsyncOnStream() :0
(MeshHostWorker pid=79094) 6 0x000000000454c306 xla::Executable::ExecuteAsyncOnStreamWrapper() :0
(MeshHostWorker pid=79094) 7 0x00000000013183d0 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79094) 8 0x0000000001318b40 xla::LocalExecutable::RunAsync() :0
(MeshHostWorker pid=79094) 9 0x00000000012df9ea xla::PjRtStreamExecutorExecutable::EnqueueExecution() :0
(MeshHostWorker pid=79094) 10 0x00000000012e0e21 xla::PjRtStreamExecutorExecutable::ExecuteHelper() :0
(MeshHostWorker pid=79094) 11 0x00000000012e3249 std::_Function_handler > const>, xla::ExecuteOptions const&, std::optional, std::allocator > > >&)::{lambda()#2}>::_M_invoke() pjrt_stream_executor_client.cc:0
(MeshHostWorker pid=79094) 12 0x00000000012ef468 xla::WorkerThread::WorkLoop() :0
(MeshHostWorker pid=79094) 13 0x00000000056a7005 tsl::(anonymous namespace)::PThread::ThreadFn() env.cc:0
(MeshHostWorker pid=79094) 14 0x0000000000008609 start_thread() ???:0
(MeshHostWorker pid=79094) 15 0x000000000011f133 clone() ???:0
(MeshHostWorker pid=79094) =================================
(MeshHostWorker pid=79094) *** SIGSEGV received at time=1674165446 on cpu 93 ***
(MeshHostWorker pid=79094) PC: @ 0x7fc15fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79094) @ 0x7fe81f685420 3728 (unknown)
(MeshHostWorker pid=79094) @ 0x7fc15fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79094) @ 0x7fc15fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79094) @ 0x7fc15fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79094) @ 0x7fc15fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79094) @ 0x7fc161f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79094) @ 0x7fc15ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) @ 0x7fc15ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) @ 0x7fc15ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79094) @ 0x7fc15ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79094) @ 0x7fc15ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79094) @ 0x7fc15ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79094) @ 0x7fc1630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79094) @ 0x7fe81f679609 (unknown) start_thread
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: *** SIGSEGV received at time=1674165446 on cpu 93 ***
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: PC: @ 0x7fc15fbfeca8 (unknown) xla::gpu::CrossMeshNcclAllReduceThunk::ExecuteOnStream()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fe81f685420 3728 (unknown)
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb9553e 800 xla::gpu::(anonymous namespace)::ExecuteThunks()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb96d40 112 xla::gpu::GpuExecutable::ExecuteThunksOrXlaRuntime()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb9b5f8 2784 xla::gpu::GpuExecutable::ExecuteAsyncOnStreamImpl()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15fb9c25f 128 xla::gpu::GpuExecutable::ExecuteAsyncOnStream()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc161f84306 1376 xla::Executable::ExecuteAsyncOnStreamWrapper()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed503d0 2432 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed50b40 256 xla::LocalExecutable::RunAsync()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed179ea 2720 xla::PjRtStreamExecutorExecutable::EnqueueExecution()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed18e21 5360 xla::PjRtStreamExecutorExecutable::ExecuteHelper()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed1b249 240 std::_Function_handler<>::_M_invoke()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc15ed27468 208 xla::WorkerThread::WorkLoop()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fc1630df005 80 tsl::(anonymous namespace)::PThread::ThreadFn()
(MeshHostWorker pid=79094) [2023-01-19 21:57:26,540 E 79094 79394] logging.cc:361: @ 0x7fe81f679609 (unknown) start_thread
(MeshHostWorker pid=79094) Fatal Python error: Segmentation fault
(MeshHostWorker pid=79094)
2023-01-19 21:57:34,541 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffff84fa01cc7ea38596ba9e9dbe08000000 Worker ID: bf4eb1f8861d60f9d12849d496d767d75ae70600d62ca47b9d4101bd Node ID: 207d08a537af2d27e0cc709647ab66e89e18b2763be4c8b5028126a3 Worker IP address: REDACTED_IP_ADDRESS Worker port: 10343 Worker PID: 79095 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
Traceback (most recent call last):
  File "run_clm_flax.py", line 902, in <module>
    main()
  File "run_clm_flax.py", line 788, in main
    executable.sync()
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/pipeline_parallel/pipeshard_executable.py", line 401, in sync
    self.mesh_group.sync_workers()
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/device_mesh.py", line 2019, in sync_workers
    ray.get([w.sync.remote() for w in all_workers])
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/client_mode_hook.py", line 105, in wrapper
    return func(*args, **kwargs)
  File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/worker.py", line 2291, in get
    raise value
ray.exceptions.RayActorError: The actor died unexpectedly before finishing this task. class_name: MeshHostWorker actor_id: 84fa01cc7ea38596ba9e9dbe08000000 pid: 79095 namespace: alpa_default_space ip: REDACTED_IP_ADDRESS The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
2023-01-19 21:57:34,838 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker. RayTask ID: ffffffffffffffffa1ddb45e20cd16868664ad0208000000 Worker ID: 80c37f8712f00515006d806e5728697d4733b2dcde67767e582de8ad Node ID: 207d08a537af2d27e0cc709647ab66e89e18b2763be4c8b5028126a3 Worker IP address: REDACTED_IP_ADDRESS Worker port: 10342 Worker PID: 79094 Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
```
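Each "Profiled costs are:" array in the stage-clustering output appears to be indexed as cost[start_layer, end_layer, config_index], since its finite entries line up with the individual cost[i, j, k]=... lines. For the (1, 4) submesh only configurations 0, 2 and 3 of the two-layer stage are reported, so the inf at cost[0, 1, 1] is plausibly the configuration whose compile worker aborted on the failed check. The sketch below flags such holes; find_gaps is a hypothetical helper and the array is transcribed by hand from the log. Note that inf also marks configuration slots for which no configuration was printed at all, as in the (1, 1) and (1, 2) arrays, so every flagged entry still has to be checked against the printed cost lines.

```python
import math

INF = math.inf

# "Profiled costs are:" for submesh (1, 4), transcribed from the log above.
# Assumed layout: costs[start_layer][end_layer][config_index].
COSTS_1X4 = [
    [[INF, INF, INF, INF], [0.03560379, INF, 0.08236101, 0.03326474]],
    [[INF, INF, INF, INF], [INF, INF, INF, INF]],
]


def find_gaps(costs):
    """Map (start, end) -> config indices whose cost is inf even though at
    least one other config for the same stage has a finite cost."""
    gaps = {}
    for start, row in enumerate(costs):
        for end, configs in enumerate(row):
            if end < start:
                continue  # a stage cannot end before it starts
            finite = [k for k, c in enumerate(configs) if not math.isinf(c)]
            missing = [k for k, c in enumerate(configs) if math.isinf(c)]
            if finite and missing:
                gaps[(start, end)] = missing
    return gaps


if __name__ == "__main__":
    print(find_gaps(COSTS_1X4))  # {(0, 1): [1]}
```

Running find_gaps on the (1, 1) array instead flags configs 2 and 3 for every stage; only two configurations were printed for that submesh, so those entries are unused slots rather than failures.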

As a side note, it would be great if there were a single Dockerfile that builds and runs the Alpa HEAD commit.
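The AttributeError raised in create_and_set_cross_mesh_communicators shows that the installed Alpa Python package calls xla_extension.nccl_create_communicators_no_stream, while the jaxlib in the container does not provide it, which suggests the two were built from different versions. A cheap way to check for this inside the container, before waiting for the initial compilation, is to probe for the attribute directly. probe_attribute below is a hypothetical helper built only on importlib and hasattr; it is not part of Alpa.

```python
import importlib


def probe_attribute(module_name, attr):
    """Return True/False depending on whether module_name exposes attr,
    or None if the module cannot be imported at all."""
    try:
        module = importlib.import_module(module_name)
    except ImportError:
        return None
    return hasattr(module, attr)


if __name__ == "__main__":
    # False would reproduce the AttributeError from the log;
    # None means jaxlib is not importable in this environment.
    print(probe_attribute("jaxlib.xla_extension",
                          "nccl_create_communicators_no_stream"))
```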

jaywonchung commented 1 year ago

I tried the HEAD commit (20debbe5f0ed4047d82ae615cb2c07b059498032), and the AttributeError and segmentation fault are now gone. Only the identical "Check failed" error remains.

I also notice that the compilation result (autosharding_option_dicts) is different.

Full error output (tqdm disabled) ``` INFO:__main__:***** Running training ***** INFO:__main__: Num examples = 1966029 INFO:__main__: Num Epochs = 20 INFO:__main__: Batch size per device (w. accumulation) = 32 INFO:__main__: Global train batch size (w. parallel & distributed) = 128 INFO:__main__: Total optimization steps = 307180 Initial compilation. This might take some minutes... -------------------- Automatic stage clustering -------------------- submesh_choices: ((1, 1), (1, 2), (1, 4)) - Profiling for submesh 2 (1, 4): - Generate all stage infos (Jaxpr -> HLO) - Compile all stages (CompileWorker pid=4290) 2023-01-20 00:20:20.658649: F external/org_tensorflow/tensorflow/compiler/xla/service/spmd/auto_sharding.cc:1465] Check failed: strategies->is_tuple || !strategies->leaf_vector.empty() %pad.38 = f16[8,512,2304]{2,1,0} pad(f16[8,512,768]{2,1,0} %reshape.1367, f16[] %constant.1168), padding=0_0x0_0x1536_0, metadata={op_name="parallelize(stage_0_1)/jit(main)/jit(stage_0_1_acc_grad_1)/jit(stage_0_1_acc_grad_10)/transpose(jvp(FlaxGPT2LMHeadModule))/transformer/h/11/attn/pad[padding_config=((0, 0, 0), (0, 0, 0), (1536, 0, 0))]" source_file="/opt/conda/envs/alpa/lib/python3.8/site-packages/transformers/models/gpt2/modeling_flax_gpt2.py" source_line=211} does not have any valid strategies. 
(CompileWorker pid=4290) *** SIGABRT received at time=1674174020 on cpu 32 ***
(CompileWorker pid=4290) PC: @ 0x7f8765c1000b (unknown) raise
(CompileWorker pid=4290) @ 0x7f8765f2d420 567026464 (unknown)
(CompileWorker pid=4290) @ 0x7f60a85eb76f 10592 xla::spmd::BuildStrategyAndCost()
(CompileWorker pid=4290) @ 0x7f60a9897864 2368 xla::spmd::AutoSharding::Run()
(CompileWorker pid=4290) @ 0x7f60ac5ecd71 816 xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=4290) @ 0x7f60ac5ed9c5 448 xla::HloPassPipeline::Run()
(CompileWorker pid=4290) @ 0x7f60a8ba01c4 80 xla::HloPassInterface::Run()
(CompileWorker pid=4290) @ 0x7f60a8ba9631 4128 xla::spmd::RunAutoShardingPass()
(CompileWorker pid=4290) @ 0x7f60a8b953fa 160 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=4290) @ 0x7f60a8a050a0 576 pybind11::cpp_function::dispatcher()
(CompileWorker pid=4290) @ 0x4e1172 (unknown) PyCFunction_Call
(CompileWorker pid=4290) @ 0x71a560 (unknown) (unknown)
(CompileWorker pid=4290) [2023-01-20 00:20:20,695 E 4290 4290] logging.cc:361: *** SIGABRT received at time=1674174020 on cpu 32 ***
(CompileWorker pid=4290) [2023-01-20 00:20:20,695 E 4290 4290] logging.cc:361: PC: @ 0x7f8765c1000b (unknown) raise
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f8765f2d420 567026464 (unknown)
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60a85eb76f 10592 xla::spmd::BuildStrategyAndCost()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60a9897864 2368 xla::spmd::AutoSharding::Run()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60ac5ecd71 816 xla::HloPassPipeline::RunPassesInternal<>()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60ac5ed9c5 448 xla::HloPassPipeline::Run()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60a8ba01c4 80 xla::HloPassInterface::Run()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60a8ba9631 4128 xla::spmd::RunAutoShardingPass()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60a8b953fa 160 pybind11::cpp_function::initialize<>()::{lambda()#3}::_FUN()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x7f60a8a050a0 576 pybind11::cpp_function::dispatcher()
(CompileWorker pid=4290) [2023-01-20 00:20:20,697 E 4290 4290] logging.cc:361: @ 0x4e1172 (unknown) PyCFunction_Call
(CompileWorker pid=4290) [2023-01-20 00:20:20,698 E 4290 4290] logging.cc:361: @ 0x71a560 (unknown) (unknown)
(CompileWorker pid=4290) Fatal Python error: Aborted
(CompileWorker pid=4290)
(CompileWorker pid=4290) Stack (most recent call first):
(CompileWorker pid=4290)   File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/shard_parallel/auto_sharding.py", line 345 in run_auto_sharding_pass
(CompileWorker pid=4290)   File "/opt/conda/envs/alpa/lib/python3.8/site-packages/alpa/pipeline_parallel/stage_profiling.py", line 229 in compile_stage_for_profiling
(CompileWorker pid=4290)   File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/util/tracing/tracing_helper.py", line 466 in _resume_span
(CompileWorker pid=4290)   File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/function_manager.py", line 674 in actor_method_executor
(CompileWorker pid=4290)   File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/worker.py", line 763 in main_loop
(CompileWorker pid=4290)   File "/opt/conda/envs/alpa/lib/python3.8/site-packages/ray/_private/workers/default_worker.py", line 231 in <module>
2023-01-20 00:20:23,387 WARNING worker.py:1839 -- A worker died or was killed while executing a task by an unexpected system error. To troubleshoot the problem, check the logs for the dead worker.
    RayTask ID: ffffffffffffffff4567cdf9fd75979f23ab248001000000
    Worker ID: 1340ed70025dc7ca71f2969ef0c825339049e7195f6ac4afdb251711
    Node ID: 6ef4e266a8c2fccb5753580de83884f61a7c3f527254bd49fc0cf764
    Worker IP address: REDACTED_IP_ADDRESS
    Worker port: 10005
    Worker PID: 4290
    Worker exit type: SYSTEM_ERROR
    Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
WARNING:alpa.pipeline_parallel.stage_profiling:A Compile worker died unexpectedly: The actor died unexpectedly before finishing this task.
    class_name: CompileWorker
    actor_id: 4567cdf9fd75979f23ab248001000000
    pid: 4290
    namespace: alpa_default_space
    ip: REDACTED_IP_ADDRESS
The actor is dead because its worker process has died. Worker exit type: SYSTEM_ERROR Worker exit detail: Worker unexpectedly exits with a connection error code 2. End of file. There are some potential root causes. (1) The process is killed by SIGKILL by OOM killer due to high memory usage. (2) ray stop --force is called. (3) The worker is crashed unexpectedly due to SIGSEGV or other unexpected errors.
- Profile all stages
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
result[(0, 1, 2, 0), 0] = ModuleProfileResult(compute_cost=0.015, peak_memory=1.878 GB, invar_size=0.232 GB, outvar_size=1.646 GB, temp_buffer_size=0.000 GB, available_memory=39.475 GB)
result[(0, 1, 2, 0), 1] = ModuleProfileResult(compute_cost=0.023, peak_memory=2.127 GB, invar_size=1.895 GB, outvar_size=0.250 GB, temp_buffer_size=0.232 GB, available_memory=39.475 GB)
result[(0, 1, 2, 2), 0] = ModuleProfileResult(compute_cost=0.036, peak_memory=2.399 GB, invar_size=0.058 GB, outvar_size=2.243 GB, temp_buffer_size=0.097 GB, available_memory=39.475 GB)
result[(0, 1, 2, 2), 1] = ModuleProfileResult(compute_cost=0.049, peak_memory=2.519 GB, invar_size=2.319 GB, outvar_size=0.076 GB, temp_buffer_size=0.199 GB, available_memory=39.475 GB)
result[(0, 1, 2, 3), 0] = ModuleProfileResult(compute_cost=0.015, peak_memory=1.864 GB, invar_size=0.177 GB, outvar_size=1.592 GB, temp_buffer_size=0.096 GB, available_memory=39.475 GB)
result[(0, 1, 2, 3), 1] = ModuleProfileResult(compute_cost=0.021, peak_memory=2.068 GB, invar_size=1.787 GB, outvar_size=0.195 GB, temp_buffer_size=0.281 GB, available_memory=39.475 GB)
Profiling for submesh 2 (1, 4) takes 44.56 seconds
--------------------------------------------------
- Profiling for submesh 1 (1, 2):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
result[(0, 0, 1, 0), 0] = ModuleProfileResult(compute_cost=0.010, peak_memory=1.495 GB, invar_size=0.153 GB, outvar_size=1.298 GB, temp_buffer_size=0.044 GB, available_memory=39.475 GB)
result[(0, 0, 1, 0), 1] = ModuleProfileResult(compute_cost=0.014, peak_memory=1.603 GB, invar_size=1.451 GB, outvar_size=0.153 GB, temp_buffer_size=0.153 GB, available_memory=39.475 GB)
result[(0, 0, 1, 1), 0] = ModuleProfileResult(compute_cost=0.014, peak_memory=1.593 GB, invar_size=0.076 GB, outvar_size=1.508 GB, temp_buffer_size=0.009 GB, available_memory=39.475 GB)
result[(0, 0, 1, 2), 0] = ModuleProfileResult(compute_cost=0.011, peak_memory=1.458 GB, invar_size=0.116 GB, outvar_size=1.298 GB, temp_buffer_size=0.044 GB, available_memory=39.475 GB)
result[(0, 0, 1, 1), 1] = ModuleProfileResult(compute_cost=0.017, peak_memory=1.810 GB, invar_size=1.584 GB, outvar_size=0.076 GB, temp_buffer_size=0.226 GB, available_memory=39.475 GB)
result[(0, 0, 1, 2), 1] = ModuleProfileResult(compute_cost=0.014, peak_memory=1.508 GB, invar_size=1.417 GB, outvar_size=0.116 GB, temp_buffer_size=0.091 GB, available_memory=39.475 GB)
result[(0, 1, 1, 0), 0] = ModuleProfileResult(compute_cost=0.025, peak_memory=3.390 GB, invar_size=0.232 GB, outvar_size=3.061 GB, temp_buffer_size=0.097 GB, available_memory=39.475 GB)
result[(0, 1, 1, 0), 1] = ModuleProfileResult(compute_cost=0.033, peak_memory=3.629 GB, invar_size=3.329 GB, outvar_size=0.268 GB, temp_buffer_size=0.300 GB, available_memory=39.475 GB)
result[(0, 1, 1, 1), 0] = ModuleProfileResult(compute_cost=0.032, peak_memory=3.579 GB, invar_size=0.116 GB, outvar_size=3.460 GB, temp_buffer_size=0.003 GB, available_memory=39.475 GB)
result[(0, 1, 1, 1), 1] = ModuleProfileResult(compute_cost=0.039, peak_memory=4.031 GB, invar_size=3.612 GB, outvar_size=0.152 GB, temp_buffer_size=0.419 GB, available_memory=39.475 GB)
result[(0, 1, 1, 2), 0] = ModuleProfileResult(compute_cost=0.026, peak_memory=3.317 GB, invar_size=0.195 GB, outvar_size=3.025 GB, temp_buffer_size=0.097 GB, available_memory=39.475 GB)
result[(0, 1, 1, 2), 1] = ModuleProfileResult(compute_cost=0.033, peak_memory=4.002 GB, invar_size=3.256 GB, outvar_size=0.231 GB, temp_buffer_size=0.745 GB, available_memory=39.475 GB)
result[(1, 1, 1, 0), 0] = ModuleProfileResult(compute_cost=0.015, peak_memory=2.165 GB, invar_size=0.154 GB, outvar_size=1.766 GB, temp_buffer_size=0.244 GB, available_memory=39.475 GB)
result[(1, 1, 1, 0), 1] = ModuleProfileResult(compute_cost=0.018, peak_memory=2.128 GB, invar_size=1.917 GB, outvar_size=0.154 GB, temp_buffer_size=0.208 GB, available_memory=39.475 GB)
result[(1, 1, 1, 1), 0] = ModuleProfileResult(compute_cost=0.018, peak_memory=2.240 GB, invar_size=0.081 GB, outvar_size=1.958 GB, temp_buffer_size=0.201 GB, available_memory=39.475 GB)
result[(1, 1, 1, 2), 0] = ModuleProfileResult(compute_cost=0.015, peak_memory=2.093 GB, invar_size=0.118 GB, outvar_size=1.730 GB, temp_buffer_size=0.244 GB, available_memory=39.475 GB)
result[(1, 1, 1, 1), 1] = ModuleProfileResult(compute_cost=0.022, peak_memory=2.283 GB, invar_size=2.033 GB, outvar_size=0.081 GB, temp_buffer_size=0.244 GB, available_memory=39.475 GB)
result[(1, 1, 1, 2), 1] = ModuleProfileResult(compute_cost=0.018, peak_memory=2.046 GB, invar_size=1.845 GB, outvar_size=0.118 GB, temp_buffer_size=0.198 GB, available_memory=39.475 GB)
Profiling for submesh 1 (1, 2) takes 53.44 seconds
--------------------------------------------------
- Profiling for submesh 0 (1, 1):
- Generate all stage infos (Jaxpr -> HLO)
- Compile all stages
- Profile all stages
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
result[(0, 0, 0, 0), 1] = ModuleProfileResult(compute_cost=0.022, peak_memory=2.980 GB, invar_size=2.670 GB, outvar_size=0.153 GB, temp_buffer_size=0.311 GB, available_memory=39.475 GB)
result[(0, 0, 0, 1), 0] = ModuleProfileResult(compute_cost=0.018, peak_memory=2.758 GB, invar_size=0.153 GB, outvar_size=2.517 GB, temp_buffer_size=0.088 GB, available_memory=39.475 GB)
result[(0, 0, 0, 0), 0] = ModuleProfileResult(compute_cost=0.018, peak_memory=2.758 GB, invar_size=0.153 GB, outvar_size=2.517 GB, temp_buffer_size=0.088 GB, available_memory=39.475 GB)
result[(0, 0, 0, 1), 1] = ModuleProfileResult(compute_cost=0.022, peak_memory=2.980 GB, invar_size=2.670 GB, outvar_size=0.153 GB, temp_buffer_size=0.311 GB, available_memory=39.475 GB)
result[(0, 1, 0, 0), 0] = ModuleProfileResult(compute_cost=0.044, peak_memory=6.318 GB, invar_size=0.232 GB, outvar_size=5.892 GB, temp_buffer_size=0.193 GB, available_memory=39.475 GB)
result[(0, 1, 0, 0), 1] = ModuleProfileResult(compute_cost=0.053, peak_memory=6.808 GB, invar_size=6.196 GB, outvar_size=0.304 GB, temp_buffer_size=0.612 GB, available_memory=39.475 GB)
result[(1, 1, 0, 0), 0] = ModuleProfileResult(compute_cost=0.026, peak_memory=4.027 GB, invar_size=0.157 GB, outvar_size=3.381 GB, temp_buffer_size=0.489 GB, available_memory=39.475 GB)
result[(0, 1, 0, 1), 0] = ModuleProfileResult(compute_cost=0.044, peak_memory=6.318 GB, invar_size=0.232 GB, outvar_size=5.892 GB, temp_buffer_size=0.193 GB, available_memory=39.475 GB)
result[(0, 1, 0, 1), 1] = ModuleProfileResult(compute_cost=0.052, peak_memory=6.808 GB, invar_size=6.196 GB, outvar_size=0.304 GB, temp_buffer_size=0.612 GB, available_memory=39.475 GB)
result[(1, 1, 0, 0), 1] = ModuleProfileResult(compute_cost=0.030, peak_memory=3.960 GB, invar_size=3.532 GB, outvar_size=0.157 GB, temp_buffer_size=0.422 GB, available_memory=39.475 GB)
result[(1, 1, 0, 1), 0] = ModuleProfileResult(compute_cost=0.026, peak_memory=4.027 GB, invar_size=0.157 GB, outvar_size=3.381 GB, temp_buffer_size=0.489 GB, available_memory=39.475 GB)
result[(1, 1, 0, 1), 1] = ModuleProfileResult(compute_cost=0.030, peak_memory=3.960 GB, invar_size=3.532 GB, outvar_size=0.157 GB, temp_buffer_size=0.422 GB, available_memory=39.475 GB)
Profiling for submesh 0 (1, 1) takes 29.90 seconds
--------------------------------------------------
Profile result saved to: profile-results-2023-01-20-00-22-22.npy
----------------------------------------------------------------------
Result forward_stage_layer_ids: [[0], [1]]
Result mesh_shapes: [(1, 2), (1, 2)]
Result logical_mesh_shapes: [(2, 1), (2, 1)]
Result autosharding_option_dicts: [{'force_batch_dim_to_mesh_dim': 0}, {'force_batch_dim_to_mesh_dim': 0}]
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
(raylet) bash: /opt/conda/envs/alpa/lib/libtinfo.so.6: no version information available (required by bash)
Initial compilation completed. Time elapsed: 157.79 s
Step... 20 | Loss: 9.8741, Learning Rate: 0.00002, Throughput: 8134.91 token/s, 1.62 TFLOP/s
```
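The numbers in the log header are consistent with each other and with the throughput on the last line. A quick arithmetic check, assuming the script computes steps per epoch as `num_examples // global_batch` (dropping the last partial batch) and that the reported throughput counts `global_batch * block_size` tokens per step:

```python
num_examples = 1_966_029
num_epochs = 20
per_device_batch = 32
num_devices = 4          # four A40 GPUs
block_size = 512
throughput = 8134.91     # token/s reported at step 20

global_batch = per_device_batch * num_devices        # 128, as in the log header
steps_per_epoch = num_examples // global_batch       # 15359 (77 leftover examples)
total_steps = steps_per_epoch * num_epochs           # 307180, as in the log header
tokens_per_step = global_batch * block_size          # 65536
seconds_per_step = tokens_per_step / throughput      # roughly 8.06 s per step

print(total_steps, round(seconds_per_step, 2))
```

The step time is only indicative, since it is derived from a single throughput report early in training.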
merrymercy commented 1 year ago

If you want to use advanced parallelization options, please refer to this OPT example: https://github.com/alpa-projects/alpa/tree/main/examples/opt_finetune and this PR: https://github.com/alpa-projects/alpa/pull/858