Wintoplay opened this issue 2 weeks ago
Hi, sorry, we are not able to support ROCm devices right now, since we are using Triton, which only supports CUDA.
Try the Triton git main branch; the latest code supports RDNA3 and CDNA2/3:
https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/driver.py#L434
@ridgerchu Please add ROCm support for Triton. There's an official repo that supports Triton with ROCm: https://github.com/ROCm/triton. Here is an example of the Flash Attention v2 implementation via Triton with ROCm support: https://github.com/ROCm/triton/blob/triton-mlir/python/tutorials/06-fused-attention.py
@radna0 I don't have an AMD GPU to test with, but if you do, maybe put in a PR to add ROCm support?
Will gladly do!
Might anyone be of help regarding this? This is with Triton-nightly 3.0.0:
```
root@r4-0:~/matmulfreellm# python generate.py
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
warnings.warn(
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
File "/root/matmulfreellm/generate.py", line 13, in <module>
outputs = model.generate(
File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 316, in generate
return super().generate(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
return func(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 1525, in generate
return self.sample(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 2622, in sample
outputs = self(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 377, in forward
outputs = self.model(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 253, in forward
hidden_states, attentions, past_key_values = layer(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 101, in forward
hidden_states = self.attn_norm(hidden_states)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
return self._call_impl(*args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
return forward_call(*args, **kwargs)
File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 615, in forward
return rms_norm_fn(
File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 543, in rms_norm_fn
return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
return super().apply(*args, **kwargs)  # type: ignore[misc]
File "/root/matmulfreellm/mmfreelm/utils.py", line 9, in wrapper
return fn(ctx,
File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 471, in forward
y, mean, rstd, residual_out = _layer_norm_fwd(
File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 203, in _layer_norm_fwd
_layer_norm_fwd_1pass_kernel[(M,)](
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 209, in <lambda>
return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in run
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 120, in _bench
return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/testing.py", line 103, in do_bench
fn()
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
self.fn.run(
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 548, in run
kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]: Code: 1, Messsage: invalid argument
```
I might be wrong, but it seems like a driver error...?
It seems like it gets all the way to instantiating the layers, but then it tries to run some sort of benchmark... and then, when it tries to run the kernel on the GPU, it crashes here:

```
File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
self.launch(*args, **kwargs)
```
Idk if the changes have been merged upstream, but are you using the ROCm fork? This one: https://github.com/ROCm/triton, not this one: https://github.com/triton-lang/triton
There are 3 versions of Triton that I have tried:
- Triton-nightly 3.0.0: the error I get is above.
- Triton 2.1.0, the ROCm/triton repo/branch you just mentioned @taylor-shift: there are errors regarding the launcher implementation. I think it was not implemented until Triton 2.2, which is also the version required by the matmulfreellm package.
- Triton 2.3.1, the latest stable version: I get the same error as @Wintoplay.
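Since behaviour differs so much across these releases, it may help to assert the installed version up front before debugging kernels. A small sketch (the `meets_minimum` helper and the `(2, 2)` floor are my own illustration, based on the Triton >= 2.2 requirement mentioned above; it only compares the numeric major/minor components):

```python
from importlib.metadata import PackageNotFoundError, version

def meets_minimum(installed: str, required=(2, 2)) -> bool:
    """Compare a version string like '2.3.1' or '3.0.0+git...' against a
    (major, minor) floor, ignoring any local '+...' suffix."""
    numeric = installed.split("+")[0]
    parts = tuple(int(p) for p in numeric.split(".")[:2])
    return parts >= required

def installed_triton_ok() -> bool:
    try:
        return meets_minimum(version("triton"))
    except PackageNotFoundError:
        return False  # Triton not installed at all

print("Triton >= 2.2 installed:", installed_triton_ok())
```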
@radna0 I only have the Triton 3 that comes with PyTorch for ROCm. Btw, I wonder if matmulfreellm can be implemented in JAX efficiently.
I haven't used JAX, so I don't know. I ran the official rocm-triton Docker image and then ran my test cases from there.
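On the JAX question: the core "matmul-free" operation should port easily in principle, since ternary weights turn the dot product into signed additions. A rough sketch of that idea (my own illustration, not the repo's kernels; written with NumPy, but `jax.numpy` exposes the same `where`/`sum` API):

```python
import numpy as np

def ternary_matvec(w, x):
    """'Matvec' for ternary weights w with entries in {-1, 0, +1}.

    Each weight is -1, 0, or +1, so the dot product reduces to signed
    additions of selected inputs, with no multiplications conceptually.
    (Illustration only; real kernels fuse this with quantization.)
    """
    added = np.where(w == 1, x, 0.0)        # inputs whose weight is +1
    subtracted = np.where(w == -1, x, 0.0)  # inputs whose weight is -1
    return (added - subtracted).sum(axis=1)

w = np.array([[1, 0, -1],
              [0, 1, 1]])
x = np.array([2.0, 3.0, 5.0])
print(ternary_matvec(w, x))  # matches w @ x
```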
https://github.com/ridgerchu/matmulfreellm/issues/17
Still, there are some verified concerns from @ridgerchu saying that the current matmulfreellm repo can only do training, not inference, because they would still have to update the package to work with BitBLAS. I have also checked with the BitBLAS team, and they said there's a branch that works with HIP:
https://github.com/microsoft/BitBLAS/issues/55
Thanks for all your attention on BitBLAS, dudes :). We've updated the citation in our repo:
```bibtex
@inproceedings{ladder-osdi24,
  author    = {Lei Wang and Lingxiao Ma and Shijie Cao and Quanlu Zhang and Jilong Xue and Yining Shi and Ningxin Zheng and Ziming Miao and Fan Yang and Ting Cao and Yuqing Yang and Mao Yang},
  title     = {Ladder: Enabling Efficient Low-Precision Deep Learning Computing through Hardware-aware Tensor Transformation},
  booktitle = {18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)},
  year      = {2024},
  url       = {https://www.usenix.org/conference/osdi24/presentation/wang-lei},
}
```
Hey @LeiWang1999, thanks for your info! We will update the reference when the next version of our paper is released!