triton-lang / triton

Development repository for the Triton language and compiler
https://triton-lang.org/
MIT License
12.93k stars 1.57k forks source link

RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument #4128

Open radna0 opened 3 months ago

radna0 commented 3 months ago

When running the following code with the matmulfreellm package and triton-nightly:

import os

# Disable the tokenizers library's internal parallelism; this must be set
# before transformers (and thus tokenizers) is imported to take effect.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
import mmfreelm  # noqa: F401  -- unused name; imported only for its side effects
# (presumably registering the MMfreeLM / HGRN-bit architecture with the
# transformers Auto* classes) -- TODO confirm.
from transformers import AutoModelForCausalLM, AutoTokenizer

# Change here to our open-sourced model
name = "ridger/MMfreeLM-370M"
tokenizer = AutoTokenizer.from_pretrained(name)
# .cuda() also targets the AMD GPU on ROCm builds of PyTorch (HIP backend).
model = AutoModelForCausalLM.from_pretrained(name).cuda().half()
input_prompt = "In a shocking finding, scientist discovered a herd of unicorns living in a remote, "

# Keep the full tokenizer output so the attention mask can be passed along
# with the ids. Without it (and without an explicit pad token id), generate()
# warns that results may be unreliable.
inputs = tokenizer(input_prompt, return_tensors="pt")
input_ids = inputs.input_ids.cuda()
attention_mask = inputs.attention_mask.cuda()

outputs = model.generate(
    input_ids,
    attention_mask=attention_mask,
    # Same value generate() otherwise falls back to, now set explicitly.
    pad_token_id=tokenizer.eos_token_id,
    max_length=32,
    do_sample=True,
    top_p=0.4,
    temperature=0.6,
)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True)[0])

I get this error:


/opt/conda/envs/py_3.9/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "/root/matmulfreellm/generate.py", line 13, in <module>
    outputs = model.generate(
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 316, in generate
    return super().generate(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 1525, in generate
    return self.sample(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 2622, in sample
    outputs = self(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 377, in forward
    outputs = self.model(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 253, in forward
    hidden_states, attentions, past_key_values = layer(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 101, in forward
    hidden_states = self.attn_norm(hidden_states)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 615, in forward
    return rms_norm_fn(
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 543, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/matmulfreellm/mmfreelm/utils.py", line 9, in wrapper
    return fn(ctx,
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 471, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 203, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 209, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 120, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/testing.py", line 103, in do_bench
    fn()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
    self.fn.run(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 548, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument
ajassani commented 3 months ago

The autotuner tries to launch a kernel with num_warps=32. Since the warp (wavefront) size is 64 on AMD Instinct devices, that config needs 32 × 64 = 2048 threads per block, which exceeds the maximum of 1024 threads per block, so the launch fails. We made a temporary fix while porting mamba that prunes such configs: https://github.com/state-spaces/mamba/blob/ddce0c1334536dd04c523ccce08928f3611d2627/mamba_ssm/ops/triton/layer_norm.py#L128C5-L128C17

In future Triton releases, this config pruning should happen within Triton.

radna0 commented 3 months ago

Here's the code. I tried updating the autotune configs, but it still fails with the same error:

root@r4-0:/home/matmulfreellm# python generate.py
/home/matmulfreellm/mmfreelm/modules/fused_norm_gate.py:75: UserWarning: 'torch._C._CudaDeviceProperties' object has no attribute 'gcnArchName', warp size set to 32 based on device name: Radeon RX Vega
  warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
/home/matmulfreellm/mmfreelm/modules/layernorm.py:111: UserWarning: 'torch._C._CudaDeviceProperties' object has no attribute 'gcnArchName', warp size set to 32 based on device name: Radeon RX Vega
  warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
/home/matmulfreellm/mmfreelm/ops/fusedbitnet.py:70: UserWarning: 'torch._C._CudaDeviceProperties' object has no attribute 'gcnArchName', warp size set to 32 based on device name: Radeon RX Vega
  warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "/home/matmulfreellm/generate.py", line 11, in <module>
    outputs = model.generate(input_ids, max_length=32,  do_sample=True, top_p=0.4, temperature=0.6)
  File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 316, in generate
    return super().generate(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 1525, in generate
    return self.sample(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 2622, in sample
    outputs = self(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 377, in forward
    outputs = self.model(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 253, in forward
    hidden_states, attentions, past_key_values = layer(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 101, in forward
    hidden_states = self.attn_norm(hidden_states)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 643, in forward
    return rms_norm_fn(
  File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 571, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/matmulfreellm/mmfreelm/utils.py", line 9, in wrapper
    return fn(ctx,
  File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 499, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 239, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 209, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 120, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/testing.py", line 103, in do_bench
    fn()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
    self.fn.run(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 548, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument