Open danieldk opened 2 weeks ago
@alexm-neuralmagic any idea? I don't know how thoroughly small group sizes are tested; this may simply be a tensor sharding issue.
The code for guarding sizes is in the Python layer of `GPTQLinearMethod`. My guess is that one of the weight matrices gets too skinny with tp=4 and this is not caught by the guard, but we should check, and perhaps move these guards into the C++ code.
I will try to reproduce
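To make the "too skinny" hypothesis concrete, here is a minimal sketch of the kind of per-shard shape check being discussed. It is not vLLM's actual guard: the function name `check_tp_shapes` is hypothetical, and the tile minimums (`min_thread_k=128`, `min_thread_n=64`) are assumptions rather than values read from the source.

```python
def check_tp_shapes(in_features, out_features, tp_size, group_size,
                    min_thread_k=128, min_thread_n=64, row_parallel=True):
    """Return a list of problems for one linear layer after TP sharding.

    Hypothetical helper: the thresholds are assumed, not taken from vLLM.
    """
    problems = []
    k, n = in_features, out_features
    # Row-parallel layers split the input (K) dimension across ranks;
    # column-parallel layers split the output (N) dimension.
    if row_parallel:
        if k % tp_size:
            problems.append(f"K={k} is not divisible by tp_size={tp_size}")
        k //= tp_size
    else:
        if n % tp_size:
            problems.append(f"N={n} is not divisible by tp_size={tp_size}")
        n //= tp_size
    if k % min_thread_k:
        problems.append(f"per-rank K={k} is not a multiple of {min_thread_k}")
    if n % min_thread_n:
        problems.append(f"per-rank N={n} is not a multiple of {min_thread_n}")
    if k % group_size:
        problems.append(f"per-rank K={k} is not a multiple of group_size={group_size}")
    return problems


if __name__ == "__main__":
    # Llama-2-70B MLP shapes: hidden_size=8192, intermediate_size=28672.
    print(check_tp_shapes(28672, 8192, tp_size=4, group_size=32))   # down_proj
    print(check_tp_shapes(8192, 28672, tp_size=4, group_size=32,
                          row_parallel=False))                      # one MLP input projection
```

Under these assumed thresholds the Llama-2-70B MLP shapes pass at tp=4, so a check of this form on its own would not flag the configuration reported here.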
I have a similar Marlin-related illegal memory access error with tp=2 on two Runpod 4090s using the `vllm/vllm-openai:v0.5.0.post1` Docker image. The error occurs during startup. I can confirm that the issue doesn't occur with `desc_act=False`. You can replicate with these CLI options:
--model TheBloke/Llama-2-70B-GPTQ --revision gptq-4bit-32g-actorder_True --dtype float16 --tensor-parallel-size 2 --max-model-len 2048 --enable-chunked-prefill --gpu-memory-utilization 0.95 --max-num-seqs 2
INFO 06-23 19:53:05 api_server.py:177] vLLM API server version 0.5.0.post1
INFO 06-23 19:53:05 api_server.py:178] args: Namespace(host='0.0.0.0', port=3000, uvicorn_log_level='info', allow_credentials=False, allowed_origins=['*'], allowed_methods=['*'], allowed_headers=['*'], api_key=None, lora_modules=None, chat_template=None, response_role='assistant', ssl_keyfile=None, ssl_certfile=None, ssl_ca_certs=None, ssl_cert_reqs=0, root_path=None, middleware=[], model='TheBloke/Llama-2-70B-GPTQ', tokenizer=None, skip_tokenizer_init=False, revision='gptq-4bit-32g-actorder_True', code_revision=None, tokenizer_revision=None, tokenizer_mode='auto', trust_remote_code=False, download_dir=None, load_format='auto', dtype='float16', kv_cache_dtype='auto', quantization_param_path=None, max_model_len=2048, guided_decoding_backend='outlines', distributed_executor_backend=None, worker_use_ray=False, pipeline_parallel_size=1, tensor_parallel_size=2, max_parallel_loading_workers=None, ray_workers_use_nsight=False, block_size=16, enable_prefix_caching=False, disable_sliding_window=False, use_v2_block_manager=False, num_lookahead_slots=0, seed=0, swap_space=4, gpu_memory_utilization=0.95, num_gpu_blocks_override=None, max_num_batched_tokens=None, max_num_seqs=2, max_logprobs=20, disable_log_stats=False, quantization=None, rope_scaling=None, rope_theta=None, enforce_eager=False, max_context_len_to_capture=None, max_seq_len_to_capture=8192, disable_custom_all_reduce=False, tokenizer_pool_size=0, tokenizer_pool_type='ray', tokenizer_pool_extra_config=None, enable_lora=False, max_loras=1, max_lora_rank=16, lora_extra_vocab_size=256, lora_dtype='auto', long_lora_scaling_factors=None, max_cpu_loras=None, fully_sharded_loras=False, device='auto', image_input_type=None, image_token_id=None, image_input_shape=None, image_feature_size=None, image_processor=None, image_processor_revision=None, disable_image_processor=False, scheduler_delay_factor=0.0, enable_chunked_prefill=True, speculative_model=None, num_speculative_tokens=None, speculative_max_model_len=None, 
speculative_disable_by_batch_size=None, ngram_prompt_lookup_max=None, ngram_prompt_lookup_min=None, model_loader_extra_config=None, preemption_mode=None, served_model_name=None, qlora_adapter_name_or_path=None, engine_use_ray=False, disable_log_requests=False, max_log_len=None)
/usr/local/lib/python3.10/dist-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
INFO 06-23 19:53:05 gptq_marlin.py:133] The model is convertible to gptq_marlin during runtime. Using gptq_marlin kernel.
2024-06-23 19:53:07,810 WARNING utils.py:580 -- Detecting docker specified CPUs. In previous versions of Ray, CPU detection in containers was incorrect. Please ensure that Ray has enough CPUs allocated. As a temporary workaround to revert to the prior behavior, set `RAY_USE_MULTIPROCESSING_CPU_COUNT=1` as an env var before starting Ray. Set the env var: `RAY_DISABLE_DOCKER_CPU_WARNING=1` to mute this warning.
2024-06-23 19:53:07,811 WARNING utils.py:592 -- Ray currently does not support initializing Ray with fractional cpus. Your num_cpus will be truncated from 54.4 to 54.
2024-06-23 19:53:07,945 INFO worker.py:1753 -- Started a local Ray instance.
INFO 06-23 19:53:08 config.py:623] Defaulting to use mp for distributed inference
INFO 06-23 19:53:08 config.py:707] Chunked prefill is enabled (EXPERIMENTAL).
INFO 06-23 19:53:08 llm_engine.py:161] Initializing an LLM engine (v0.5.0.post1) with config: model='TheBloke/Llama-2-70B-GPTQ', speculative_config=None, tokenizer='TheBloke/Llama-2-70B-GPTQ', skip_tokenizer_init=False, tokenizer_mode=auto, revision=gptq-4bit-32g-actorder_True, rope_scaling=None, rope_theta=None, tokenizer_revision=gptq-4bit-32g-actorder_True, trust_remote_code=False, dtype=torch.float16, max_seq_len=2048, download_dir=None, load_format=LoadFormat.AUTO, tensor_parallel_size=2, disable_custom_all_reduce=False, quantization=gptq_marlin, enforce_eager=False, kv_cache_dtype=auto, quantization_param_path=None, device_config=cuda, decoding_config=DecodingConfig(guided_decoding_backend='outlines'), seed=0, served_model_name=TheBloke/Llama-2-70B-GPTQ)
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:14 multiproc_worker_utils.py:215] Worker ready; awaiting tasks
INFO 06-23 19:53:14 utils.py:637] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:14 utils.py:637] Found nccl from library libnccl.so.2
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:14 pynccl.py:63] vLLM is using nccl==2.20.5
INFO 06-23 19:53:14 pynccl.py:63] vLLM is using nccl==2.20.5
Traceback (most recent call last):
File "/usr/lib/python3.10/multiprocessing/resource_tracker.py", line 209, in main
cache[rtype].remove(name)
KeyError: '/psm_61043e17'
INFO 06-23 19:53:15 custom_all_reduce_utils.py:179] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:15 custom_all_reduce_utils.py:179] reading GPU P2P access cache from /root/.config/vllm/gpu_p2p_access_cache_for_0,1.json
(VllmWorkerProcess pid=3497) WARNING 06-23 19:53:15 custom_all_reduce.py:175] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
WARNING 06-23 19:53:15 custom_all_reduce.py:175] Custom allreduce is disabled because your platform lacks GPU P2P capability or P2P test failed. To silence this warning, specify disable_custom_all_reduce=True explicitly.
INFO 06-23 19:53:16 weight_utils.py:218] Using model weights format ['*.safetensors']
INFO 06-23 19:53:16 weight_utils.py:261] No model.safetensors.index.json found in remote.
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:16 weight_utils.py:218] Using model weights format ['*.safetensors']
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:16 weight_utils.py:261] No model.safetensors.index.json found in remote.
INFO 06-23 19:53:41 model_runner.py:160] Loading model weights took 19.1485 GB
(VllmWorkerProcess pid=3497) INFO 06-23 19:53:42 model_runner.py:160] Loading model weights took 19.1485 GB
[rank0]: Traceback (most recent call last):
[rank0]: File "/usr/lib/python3.10/runpy.py", line 196, in _run_module_as_main
[rank0]: return _run_code(code, main_globals, None,
[rank0]: File "/usr/lib/python3.10/runpy.py", line 86, in _run_code
[rank0]: exec(code, run_globals)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/entrypoints/openai/api_server.py", line 196, in <module>
[rank0]: engine = AsyncLLMEngine.from_engine_args(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 398, in from_engine_args
[rank0]: engine = cls(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 349, in __init__
[rank0]: self.engine = self._init_engine(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/async_llm_engine.py", line 473, in _init_engine
[rank0]: return engine_class(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 236, in __init__
[rank0]: self._initialize_kv_caches()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/engine/llm_engine.py", line 313, in _initialize_kv_caches
[rank0]: self.model_executor.determine_num_available_blocks())
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/distributed_gpu_executor.py", line 38, in determine_num_available_blocks
[rank0]: num_blocks = self._run_workers("determine_num_available_blocks", )
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/executor/multiproc_gpu_executor.py", line 119, in _run_workers
[rank0]: driver_worker_output = driver_worker_method(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/worker.py", line 162, in determine_num_available_blocks
[rank0]: self.model_runner.profile_run()
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 844, in profile_run
[rank0]: self.execute_model(seqs, kv_caches)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/utils/_contextlib.py", line 115, in decorate_context
[rank0]: return func(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/worker/model_runner.py", line 749, in execute_model
[rank0]: hidden_states = model_executable(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 371, in forward
[rank0]: hidden_states = self.model(input_ids, positions, kv_caches,
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 288, in forward
[rank0]: hidden_states, residual = layer(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 237, in forward
[rank0]: hidden_states = self.mlp(hidden_states)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/models/llama.py", line 79, in forward
[rank0]: gate_up, _ = self.gate_up_proj(x)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
[rank0]: return self._call_impl(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py", line 1541, in _call_impl
[rank0]: return forward_call(*args, **kwargs)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/linear.py", line 298, in forward
[rank0]: output_parallel = self.quant_method.apply(self, input_, bias)
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/gptq_marlin.py", line 431, in apply
[rank0]: marlin_scales = marlin_permute_scales(
[rank0]: File "/usr/local/lib/python3.10/dist-packages/vllm/model_executor/layers/quantization/gptq_marlin.py", line 48, in marlin_permute_scales
[rank0]: s = s.reshape((-1, len(scale_perm)))[:, scale_perm]
[rank0]: RuntimeError: CUDA error: an illegal memory access was encountered
[rank0]: CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
[rank0]: For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
[rank0]: Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
[rank0]:[E ProcessGroupNCCL.cpp:1414] [PG 2 Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7bae35ccf897 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7bae35c7fb25 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7bae35da7718 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7bade9a4ae36 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7bade9a4ef38 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7bade9a545ac in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7bade9a5531c in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7bae354b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7bae369d1ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7bae36a62a04 in /usr/lib/x86_64-linux-gnu/libc.so.6)
[2024-06-23 19:53:43,178 E 7 3653] logging.cc:108: Unhandled exception: N3c1016DistBackendErrorE. what(): [PG 2 Rank 0] Process group watchdog thread terminated with exception: CUDA error: an illegal memory access was encountered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
Exception raised from c10_cuda_check_implementation at ../c10/cuda/CUDAException.cpp:43 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7bae35ccf897 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: c10::detail::torchCheckFail(char const*, char const*, unsigned int, std::string const&) + 0x64 (0x7bae35c7fb25 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #2: c10::cuda::c10_cuda_check_implementation(int, char const*, char const*, int, bool) + 0x118 (0x7bae35da7718 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10_cuda.so)
frame #3: c10d::ProcessGroupNCCL::WorkNCCL::finishedGPUExecutionInternal() const + 0x56 (0x7bade9a4ae36 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #4: c10d::ProcessGroupNCCL::WorkNCCL::isCompleted() + 0x58 (0x7bade9a4ef38 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #5: c10d::ProcessGroupNCCL::watchdogHandler() + 0x77c (0x7bade9a545ac in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #6: c10d::ProcessGroupNCCL::ncclCommWatchdog() + 0x10c (0x7bade9a5531c in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #7: <unknown function> + 0xdc253 (0x7bae354b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #8: <unknown function> + 0x94ac3 (0x7bae369d1ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #9: clone + 0x44 (0x7bae36a62a04 in /usr/lib/x86_64-linux-gnu/libc.so.6)
Exception raised from ncclCommWatchdog at ../torch/csrc/distributed/c10d/ProcessGroupNCCL.cpp:1418 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::string) + 0x57 (0x7bae35ccf897 in /usr/local/lib/python3.10/dist-packages/torch/lib/libc10.so)
frame #1: <unknown function> + 0xe32e33 (0x7bade96d7e33 in /usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so)
frame #2: <unknown function> + 0xdc253 (0x7bae354b0253 in /usr/lib/x86_64-linux-gnu/libstdc++.so.6)
frame #3: <unknown function> + 0x94ac3 (0x7bae369d1ac3 in /usr/lib/x86_64-linux-gnu/libc.so.6)
frame #4: clone + 0x44 (0x7bae36a62a04 in /usr/lib/x86_64-linux-gnu/libc.so.6)
[2024-06-23 19:53:43,224 E 7 3653] logging.cc:115: Stack trace:
/usr/local/lib/python3.10/dist-packages/ray/_raylet.so(+0x1033f2a) [0x7bacd04e3f2a] ray::operator<<()
/usr/local/lib/python3.10/dist-packages/ray/_raylet.so(+0x1036f72) [0x7bacd04e6f72] ray::TerminateHandler()
/usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xae20c) [0x7bae3548220c]
/usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xae277) [0x7bae35482277]
/usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xae1fe) [0x7bae354821fe]
/usr/local/lib/python3.10/dist-packages/torch/lib/libtorch_cuda.so(+0xe32ee4) [0x7bade96d7ee4] c10d::ProcessGroupNCCL::ncclCommWatchdog()
/usr/lib/x86_64-linux-gnu/libstdc++.so.6(+0xdc253) [0x7bae354b0253]
/usr/lib/x86_64-linux-gnu/libc.so.6(+0x94ac3) [0x7bae369d1ac3]
/usr/lib/x86_64-linux-gnu/libc.so.6(clone+0x44) [0x7bae36a62a04] __clone
*** SIGABRT received at time=1719172423 on cpu 88 ***
PC: @ 0x7bae369d39fc (unknown) pthread_kill
@ 0x7bae3697f520 (unknown) (unknown)
[2024-06-23 19:53:43,225 E 7 3653] logging.cc:343: *** SIGABRT received at time=1719172423 on cpu 88 ***
[2024-06-23 19:53:43,225 E 7 3653] logging.cc:343: PC: @ 0x7bae369d39fc (unknown) pthread_kill
[2024-06-23 19:53:43,225 E 7 3653] logging.cc:343: @ 0x7bae3697f520 (unknown) (unknown)
Fatal Python error: Aborted
Extension modules: numpy.core._multiarray_umath, numpy.core._multiarray_tests, numpy.linalg._umath_linalg, numpy.fft._pocketfft_internal, numpy.random._common, numpy.random.bit_generator, numpy.random._bounded_integers, numpy.random._mt19937, numpy.random.mtrand, numpy.random._philox, numpy.random._pcg64, numpy.random._sfc64, numpy.random._generator, torch._C, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._nn, torch._C._sparse, torch._C._special, charset_normalizer.md, simplejson._speedups, requests.packages.charset_normalizer.md, requests.packages.chardet.md, yaml._yaml, psutil._psutil_linux, psutil._psutil_posix, msgpack._cmsgpack, google._upb._message, setproctitle, uvloop.loop, ray._raylet, sentencepiece._sentencepiece, ujson, regex._regex, scipy._lib._ccallback_c, numba.core.typeconv._typeconv, numba._helperlib, numba._dynfunc, numba._dispatcher, numba.core.runtime._nrt_python, numba.np.ufunc._internal, numba.experimental.jitclass._box, pyarrow.lib, pandas._libs.tslibs.ccalendar, pandas._libs.tslibs.np_datetime, pandas._libs.tslibs.dtypes, pandas._libs.tslibs.base, pandas._libs.tslibs.nattype, pandas._libs.tslibs.timezones, pandas._libs.tslibs.fields, pandas._libs.tslibs.timedeltas, pandas._libs.tslibs.tzconversion, pandas._libs.tslibs.timestamps, pandas._libs.properties, pandas._libs.tslibs.offsets, pandas._libs.tslibs.strptime, pandas._libs.tslibs.parsing, pandas._libs.tslibs.conversion, pandas._libs.tslibs.period, pandas._libs.tslibs.vectorized, pandas._libs.ops_dispatch, pandas._libs.missing, pandas._libs.hashtable, pandas._libs.algos, pandas._libs.interval, pandas._libs.lib, pyarrow._compute, pandas._libs.ops, pandas._libs.hashing, pandas._libs.arrays, pandas._libs.tslib, pandas._libs.sparse, pandas._libs.internals, pandas._libs.indexing, pandas._libs.index, pandas._libs.writers, pandas._libs.join, pandas._libs.window.aggregations, pandas._libs.window.indexers, pandas._libs.reshape, pandas._libs.groupby, pandas._libs.json, 
pandas._libs.parsers, pandas._libs.testing, pyarrow._parquet, pyarrow._fs, pyarrow._azurefs, pyarrow._hdfs, pyarrow._gcsfs, pyarrow._s3fs, multidict._multidict, yarl._quoting_c, aiohttp._helpers, aiohttp._http_writer, aiohttp._http_parser, aiohttp._websocket, frozenlist._frozenlist, xxhash._xxhash, pyarrow._json, markupsafe._speedups, PIL._imaging (total: 102)
Your current environment
🐛 Describe the bug
The GPTQ Marlin kernel fails with an invalid read on a model with `group_size=32` and `desc_act` with `-tp 4` (`-tp 2` works fine). For example:
To make the bug easier to analyze, I made this minimal example that simulates the problematic shapes:
The issue only seems to be triggered by some `g_idx` arrays (hence the randomization + fixed seed). The issue doesn't occur with e.g. `TP=2` or `GROUP_SIZE=64`.
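As a rough illustration of why `desc_act` can interact with sharding, the sketch below builds a `g_idx` array with a fixed seed and counts how many distinct groups each tensor-parallel shard of the input rows refers to. It is pure Python and touches neither the Marlin kernel nor vLLM; the helper names `make_g_idx` and `groups_per_shard`, the sizes, and the modelling of act-order as a random row permutation are all assumptions made for illustration.

```python
import random


def make_g_idx(k, group_size, desc_act, seed=0):
    """Map each of the k input rows to a quantization group index."""
    g_idx = [row // group_size for row in range(k)]
    if desc_act:
        # Assumption: act-order is modelled as a random permutation of rows.
        random.Random(seed).shuffle(g_idx)
    return g_idx


def groups_per_shard(g_idx, tp_size):
    """Count the distinct groups referenced by each contiguous row shard."""
    shard = len(g_idx) // tp_size
    return [len(set(g_idx[r * shard:(r + 1) * shard])) for r in range(tp_size)]


if __name__ == "__main__":
    K, GROUP_SIZE, TP = 1024, 32, 4  # illustrative sizes, not the real model's
    for desc_act in (False, True):
        counts = groups_per_shard(make_g_idx(K, GROUP_SIZE, desc_act), TP)
        print(f"desc_act={desc_act}: distinct groups per shard = {counts}")
```

Without `desc_act`, each shard refers only to its own contiguous block of groups. With a permuted `g_idx`, a shard can refer to groups from anywhere in the full range, so any code that assumes a shard's scales form a contiguous slice would index outside it.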