openxla / xla

A machine learning compiler for GPUs, CPUs, and ML accelerators
Apache License 2.0

Call to ptxas becomes defunct, causing XLA to hang #11824

Open zjjott opened 6 months ago

zjjott commented 6 months ago

I'm running Llama-2-1.7b-hf + FSDP + XLA, but ps shows a defunct ptxas child process:

523777 517263  0  80  0 -  0 -  10:22 ?  00:00:00 [ptxas] <defunct>

I attached gdb to the parent process 517263 and got this backtrace:

#0  0x00007f121032bb4d in read () from /usr/lib64/libc.so.6
#1  0x00007f121032a88b in __spawnix () from /usr/lib64/libc.so.6
#2  0x00007f121032afef in __spawni () from /usr/lib64/libc.so.6
#3  0x00007f121032a6ab in posix_spawnp@@GLIBC_2.15 () from /usr/lib64/libc.so.6
#4  0x00007f1062914a4d in tsl::SubProcess::Start() () from /opt/conda/lib/python3.8/site-packages/_XLAC.cpython-38-x86_64-linux-gnu.so
#5  0x00007f106290ca7f in stream_executor::CompileGpuAsmUsingPtxAs(int, int, char const*, stream_executor::GpuAsmOpts, bool) () from /opt/conda/lib/python3.8/site-packages/_XLAC.cpython-38-x86_64-linux-gnu.so
#6  0x00007f105c63528c in xla::gpu::NVPTXCompiler::CompileGpuAsmOrGetCachedResult(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, stream_executor::CudaComputeCapability, xla::HloModuleConfig const&, std::basic_string_view<char, std::char_traits<char> >, bool, xla::Compiler::CompileOptions const&) () from /opt/conda/lib/python3.8/site-packages/_XLAC.cpython-38-x86_64-linux-gnu.so
#7  0x00007f105c637f31 in xla::gpu::NVPTXCompiler::CompileTargetBinary(xla::HloModuleConfig const&, llvm::Module*, std::variant<stream_executor::CudaComputeCapability, stream_executor::RocmComputeCapability>, bool, xla::HloModule const*, xla::Compiler::CompileOptions const&) () from /opt/conda/lib/python3.8/site-packages/_XLAC.cpython-38-x86_64-linux-gnu.so
#8  0x00007f105c792f99 in xla::gpu::GpuCompiler::CompileSingleModule(xla::HloModuleConfig const&, std::variant<stream_executor::CudaComputeCapability, stream_executor::RocmComputeCapability>, xla::HloModule const*, llvm::Module*, bool, xla::Compiler::CompileOptions const&, std::optional<int>) () from /opt/conda/lib/python3.8/site-packages/_XLAC.cpython-38-x86_64-linux-gnu.so
#9  0x00007f105c79ac8e in xla::gpu::GpuCompiler::CompileToTargetBinary(xla::HloModuleConfig const&, llvm::Module*, std::variant<stream_executor::CudaComputeCapability, stream_executor::RocmComputeCapability>, stream_executor::StreamExecutor*, xla::Compiler::CompileOptions const&, xla::HloModule const*) () from /opt/conda/lib/python3.8/site-packages/_XLAC.cpython-38-x86_64-linux-gnu.so
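The trace shows the compiling thread blocked in read() inside glibc's posix_spawn path, called from tsl::SubProcess::Start(), while the ptxas child is already <defunct> (exited but not yet reaped). One way to see what that read() is waiting on, using the PIDs from the ps output above:

# Which syscall the parent (or a specific thread) is blocked in; for read(2)
# the first argument printed after the syscall number is the file descriptor.
cat /proc/517263/syscall
cat /proc/517263/task/*/syscall

# What that descriptor points to (e.g. a pipe set up during the spawn,
# or ptxas output captured by SubProcess).
ls -l /proc/517263/fd

# Confirm the ptxas child is a zombie waiting to be reaped.
grep -E 'State|PPid' /proc/523777/status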

Using this script:

https://gist.github.com/zjjott/9d26b31f99c5aaad7db6417a5bcc3ff9

Run command:

DATASET_PATH=PATH_OF_alpaca_data.json
PRETRAINED_MODEL_DIR=PATH_OF_Llama-2-1.7b-hf
PJRT_DEVICE=CUDA
PER_DEVICE_TRAIN_BATCH_SIZE=8
GPU_NUM_DEVICES=8
export XLA_GPU_MEMORY_FRACTION=0.7
export XLA_GPU_MEMORY_PREALLOCATE=false
export XLA_GPU_MEMORY_ALLOCATOR_KIND=3
export PJRT_DEVICE=CUDA
export PJRT_ALLOCATOR_PREALLOCATE=false
export PJRT_ALLOCATOR_FRACTION=0.7
export PJRT_ALLOCATOR_CUDA_ASYNC=true
export TF_CPP_MIN_LOG_LEVEL=0
export TF_CPP_VMODULE="hlo_pass_pipeline=5,cuda_asm_compiler=5,nvptx_compiler=5"
export XLA_FLAGS="--xla_disable_hlo_passes=rematerialization,gpu-convert-async-collectives-to-sync"
torchrun --nnodes 1 --nproc-per-node 8 llama_benchmark.py \
    --not_save_model \
    --dataset_path $DATASET_PATH \
    --config_name $PRETRAINED_MODEL_DIR \
    --tokenizer_name $PRETRAINED_MODEL_DIR \
    --num_train_epochs 6 \
    --block_size 512 \
    --learning_rate 2e-5 \
    --weight_decay 0.0 \
    --warmup_ratio 0.03 \
    --lr_scheduler_type linear \
    --per_device_train_batch_size $PER_DEVICE_TRAIN_BATCH_SIZE \
    --seed 42 \
    --preprocessing_num_workers 4 \
    --dataloader_num_workers 8 \
    --ignore_mismatched_sizes \
    --ignore_dryrun_on_load_strategy \
    --output_dir ./outputs \
    --random_log_n_training_samples 0 \
    --logging_steps 10 \
    --report_to all --distributed_method fsdp --using_xla

Commit versions (I upgraded from the January version to the latest, but the issue persists): pytorch: 7cd7a7aa8e0942da627095b23b94dc89f5a54943, torchxla: 58a412c, openxla: 1acf05e

asm_compiler.cc:234] Using /usr/local/cuda/bin/ptxas with version 11.8.89

Environment: CUDA 11.8, driver version 470.82.01, device: A100

PTX content (excerpt):

cat /tmp/tempfile-gpuxdn011071209191.sa128-6d8cb21a-41448-616cf7d8789e7
//
// Generated by LLVM NVPTX Back-End
//

.version 7.4
.target sm_80
.address_size 64

.visible .global .align 128 .b8 buffer_for_constant_1_0[64];

Executing ptxas manually on this file seems to succeed.
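A manual recompile of the dumped PTX can be done roughly like this (the -arch value matches the .target above; the -O level and --verbose are assumptions, since the exact flags XLA passes are only visible in the asm_compiler log):

# Recompile the dumped PTX by hand; flags are an approximation of what XLA
# passes (check the asm_compiler=5 log for the exact invocation).
/usr/local/cuda/bin/ptxas \
    -arch=sm_80 -O3 --verbose \
    -o /tmp/ptxas-out.cubin \
    /tmp/tempfile-gpuxdn011071209191.sa128-6d8cb21a-41448-616cf7d8789e7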

cheshire commented 6 months ago

> Executing ptxas manually on this file seems to succeed.

Could you turn on logging (with TF_CPP_VMODULE=asm_compiler=5 TF_CPP_MIN_LOG_LEVEL=0) to find out the exact invocation? Your manual invocation is probably missing the optimization flag (-O, iirc?).

If ptxas is hanging, it should hang for the manual invocation as well.
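Concretely, that means running with something like the following and then looking in the log for the line where asm_compiler.cc prints the ptxas command and temp-file path:

export TF_CPP_MIN_LOG_LEVEL=0
# instead of (or in addition to) the TF_CPP_VMODULE set in the script above
export TF_CPP_VMODULE=asm_compiler=5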

zjjott commented 6 months ago

> Executing ptxas manually on this file seems to succeed.
>
> Could you turn on logging (with TF_CPP_VMODULE=asm_compiler=5 TF_CPP_MIN_LOG_LEVEL=0) to find out the exact invocation? Your manual invocation is probably missing the optimization flag (-O, iirc?).
>
> If ptxas is hanging, it should hang for the manual invocation as well.

Yes, I turned on logging with TF_CPP_VMODULE=asm_compiler=5 TF_CPP_MIN_LOG_LEVEL=0. ptxas is executed with some /tmp path, but the temp file content looks fine. @cheshire