Open · radna0 opened 5 months ago
The autotune code tries to launch a kernel with `num_warps=32`. Since the warp size is 64 on AMD Instinct devices, the total threads per block is 32 * 64 = 2048, which exceeds the maximum of 1024 threads per block. We made a temporary fix for this while porting Mamba: https://github.com/state-spaces/mamba/blob/ddce0c1334536dd04c523ccce08928f3611d2627/mamba_ssm/ops/triton/layer_norm.py#L128C5-L128C17
In future Triton releases, this config pruning should happen within Triton itself.
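Until Triton prunes these configs itself, the workaround is to filter out autotune configs whose total thread count exceeds the device limit before benchmarking. A minimal sketch of that pruning logic (the `Config` dataclass here is a hypothetical stand-in for `triton.Config`, and the 1024-thread limit is the one reported in this issue):

```python
from dataclasses import dataclass

@dataclass
class Config:
    # Minimal stand-in for triton.Config, holding only the field we prune on.
    num_warps: int

def prune_configs(configs, warp_size, max_threads_per_block=1024):
    # Keep only configs whose total thread count (num_warps * warp_size)
    # fits within the device's max threads per block.
    return [c for c in configs if c.num_warps * warp_size <= max_threads_per_block]

configs = [Config(num_warps=w) for w in (1, 2, 4, 8, 16, 32)]
# With warp size 64 (AMD Instinct), num_warps=32 means 2048 threads > 1024,
# so that config is dropped; num_warps up to 16 (1024 threads) survives.
pruned = prune_configs(configs, warp_size=64)
```

In the Mamba fix linked above, an equivalent filter is applied to the config list passed to the autotuner before the kernel is launched.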
Here's the output. I tried updating the autotune configs myself, but it still fails with the same error:
root@r4-0:/home/matmulfreellm# python generate.py
/home/matmulfreellm/mmfreelm/modules/fused_norm_gate.py:75: UserWarning: 'torch._C._CudaDeviceProperties' object has no attribute 'gcnArchName', warp size set to 32 based on device name: Radeon RX Vega
warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
/home/matmulfreellm/mmfreelm/modules/layernorm.py:111: UserWarning: 'torch._C._CudaDeviceProperties' object has no attribute 'gcnArchName', warp size set to 32 based on device name: Radeon RX Vega
warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
/home/matmulfreellm/mmfreelm/ops/fusedbitnet.py:70: UserWarning: 'torch._C._CudaDeviceProperties' object has no attribute 'gcnArchName', warp size set to 32 based on device name: Radeon RX Vega
warnings.warn(f"{e}, warp size set to {warp_size} based on device name: {device_name}", UserWarning)
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
File "/home/matmulfreellm/generate.py", line 11, in <module>
outputs = model.generate(input_ids, max_length=32, do_sample=True, top_p=0.4, temperature=0.6)
File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 316, in generate
return super().generate(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 1525, in generate
return self.sample(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 2622, in sample
outputs = self(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 377, in forward
outputs = self.model(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 253, in forward
hidden_states, attentions, past_key_values = layer(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 101, in forward
hidden_states = self.attn_norm(hidden_states)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 643, in forward
return rms_norm_fn(
File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 571, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs) # type: ignore[misc]
File "/home/matmulfreellm/mmfreelm/utils.py", line 9, in wrapper
return fn(ctx,
File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 499, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/home/matmulfreellm/mmfreelm/modules/layernorm.py", line 239, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 209, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 120, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/testing.py", line 103, in do_bench
fn()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
self.fn.run(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 548, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
This happens when running the code above with the matmulfreellm package and triton-nightly.
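For reference, the `UserWarning`s above come from a fallback that guesses the warp size from the device name when `gcnArchName` is missing. A rough sketch of such a heuristic keyed on the architecture name instead (the cutoffs are my assumptions, not the project's code: AMD GCN/CDNA compute parts execute in wave64, while RDNA consumer parts and NVIDIA GPUs use 32):

```python
def warp_size_from_arch(arch_name: str) -> int:
    # Heuristic only: AMD GCN/CDNA GPUs (gfx8xx/gfx9xx, e.g. Vega, MI100, MI250)
    # execute in wave64; RDNA GPUs (gfx10xx and later) and NVIDIA GPUs use 32.
    if arch_name.startswith(("gfx8", "gfx9")):
        return 64
    return 32

print(warp_size_from_arch("gfx90a"))   # CDNA2 (MI250) -> 64
print(warp_size_from_arch("gfx1100"))  # RDNA3 -> 32
```

Note that by this heuristic a Radeon RX Vega (gfx900, a GCN part) would get wave64, whereas the name-based fallback in the warnings above picked 32; that mismatch may be worth double-checking as part of this issue.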