ridgerchu / matmulfreellm

Implementation for MatMul-free LM.
Apache License 2.0

ROCm (7900 XTX) GPU fail #3

Open · Wintoplay opened this issue 2 weeks ago

Wintoplay commented 2 weeks ago

[screenshot of the error]

ridgerchu commented 2 weeks ago

Hi, sorry, we are currently unable to support ROCm devices, since we are using Triton, which only supports CUDA.

chicheng commented 2 weeks ago

Try the Triton git main branch; the latest code supports RDNA3 and CDNA2/CDNA3:

https://github.com/triton-lang/triton/blob/main/third_party/amd/backend/driver.py#L434
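
If it helps, here is a minimal smoke test (my own sketch, not from this repo) to confirm that the installed Triton build can compile and launch any kernel on the 7900 XTX before debugging the full model. On ROCm builds of PyTorch, device "cuda" maps to HIP:

```python
# Minimal Triton launch smoke test for a ROCm install (hypothetical helper,
# not part of matmulfreellm). If this fails, the Triton/ROCm setup itself is
# the problem rather than anything in the model code.
import torch
import triton
import triton.language as tl

@triton.jit
def copy_kernel(src_ptr, dst_ptr, n_elements, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n_elements
    tl.store(dst_ptr + offs, tl.load(src_ptr + offs, mask=mask), mask=mask)

x = torch.arange(1024, device="cuda", dtype=torch.float32)  # "cuda" == HIP on ROCm
y = torch.empty_like(x)
copy_kernel[(triton.cdiv(x.numel(), 256),)](x, y, x.numel(), BLOCK=256)
torch.cuda.synchronize()
assert torch.equal(x, y)
print("Triton kernel launch OK on", torch.cuda.get_device_name(0))
```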

radna0 commented 2 weeks ago

@ridgerchu Please add ROCm support for Triton. There is an official repo that supports Triton with ROCm: https://github.com/ROCm/triton. Here is an example of the Flash Attention v2 implementation via Triton with ROCm support: https://github.com/ROCm/triton/blob/triton-mlir/python/tutorials/06-fused-attention.py

taylor-shift commented 2 weeks ago

@radna0 I don't have an AMD GPU to test with, but if you do, maybe put in a PR to add ROCm support?

radna0 commented 2 weeks ago

Will gladly do!

radna0 commented 2 weeks ago

With Triton nightly 3.0.0, I get the failure below. Could anyone help with this?


```
root@r4-0:~/matmulfreellm# python generate.py
/opt/conda/envs/py_3.9/lib/python3.9/site-packages/huggingface_hub/file_download.py:1132: FutureWarning: `resume_download` is deprecated and will be removed in version 1.0.0. Downloads always resume when possible. If you want to force a new download, use `force_download=True`.
  warnings.warn(
The attention mask and the pad token id were not set. As a consequence, you may observe unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results.
Setting `pad_token_id` to `eos_token_id`:2 for open-end generation.
Traceback (most recent call last):
  File "/root/matmulfreellm/generate.py", line 13, in <module>
    outputs = model.generate(
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 316, in generate
    return super().generate(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 1525, in generate
    return self.sample(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/transformers/generation/utils.py", line 2622, in sample
    outputs = self(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 377, in forward
    outputs = self.model(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 253, in forward
    hidden_states, attentions, past_key_values = layer(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/models/hgrn_bit/modeling_hgrn_bit.py", line 101, in forward
    hidden_states = self.attn_norm(hidden_states)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 615, in forward
    return rms_norm_fn(
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 543, in rms_norm_fn
    return LayerNormFn.apply(x, weight, bias, residual, eps, prenorm, residual_in_fp32, True)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/root/matmulfreellm/mmfreelm/utils.py", line 9, in wrapper
    return fn(ctx,
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 471, in forward
    y, mean, rstd, residual_out = _layer_norm_fwd(
  File "/root/matmulfreellm/mmfreelm/modules/layernorm.py", line 203, in _layer_norm_fwd
    _layer_norm_fwd_1pass_kernel[(M,)](
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 209, in <lambda>
    return lambda *args, **kwargs: self.run(grid=grid, warmup=False, *args, **kwargs)
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in run
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 143, in <dictcomp>
    timings = {config: self._bench(*args, config=config, **kwargs) for config in pruned_configs}
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 120, in _bench
    return do_bench(kernel_call, warmup=self.num_warmups, rep=self.num_reps, quantiles=(0.5, 0.2, 0.8))
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/testing.py", line 103, in do_bench
    fn()
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/autotuner.py", line 105, in kernel_call
    self.fn.run(
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/runtime/jit.py", line 548, in run
    kernel.run(grid_0, grid_1, grid_2, stream, kernel.function, kernel.packed_metadata, launch_metadata,
  File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
RuntimeError: Triton Error [HIP]:  Code: 1, Messsage: invalid argument
```
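
(Side note, separate from the crash: the attention_mask / pad_token_id warnings at the top of that log come from transformers and can be silenced by passing both explicitly to generate(). A minimal sketch, assuming tokenizer, prompt, and model names along the lines of generate.py:)

```python
# Hedged sketch, not the repo's generate.py: pass an explicit mask and pad
# token so transformers does not have to guess them.
inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
outputs = model.generate(
    input_ids=inputs.input_ids,
    attention_mask=inputs.attention_mask,
    pad_token_id=tokenizer.eos_token_id,
    max_new_tokens=32,
)
```
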
taylor-shift commented 1 week ago

I might be wrong, but this seems like a driver error?

It seems like it gets all the way through instantiating the layers, but then the Triton autotuner benchmarks the kernel configs, and when it actually launches the kernel on the GPU, it crashes here:

File "/opt/conda/envs/py_3.9/lib/python3.9/site-packages/triton/backends/amd/driver.py", line 418, in __call__
    self.launch(*args, **kwargs)
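
One way to narrow that down: the autotuner on that layernorm kernel benchmarks several num_warps configs, and depending on how the AMD backend maps warps to threads (RDNA3 can run wave32 or wave64), the largest configs may exceed the per-block thread limit and fail at launch with exactly this kind of "invalid argument". A probe like this (my own sketch, nothing from the repo) shows which num_warps values launch at all:

```python
# Probe which num_warps values HIP accepts on this GPU (hypothetical test).
# Depending on the wavefront size the backend picks, large num_warps values
# can exceed the per-block thread limit and fail with "invalid argument".
import torch
import triton
import triton.language as tl

@triton.jit
def probe_kernel(x_ptr, BLOCK: tl.constexpr):
    offs = tl.arange(0, BLOCK)
    tl.store(x_ptr + offs, tl.load(x_ptr + offs) + 1.0)

x = torch.zeros(1024, device="cuda")
for w in [1, 2, 4, 8, 16, 32]:
    try:
        probe_kernel[(1,)](x, BLOCK=1024, num_warps=w)
        torch.cuda.synchronize()
        print(f"num_warps={w}: OK")
    except RuntimeError as e:
        print(f"num_warps={w}: FAILED ({e})")
```

If only the large configs fail, trimming the autotuner's config list for AMD would be a plausible workaround until the kernels are tuned for it.
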

I don't know if the changes have been merged upstream, but are you using the ROCm fork (https://github.com/ROCm/triton) rather than the upstream repo (https://github.com/triton-lang/triton)?

radna0 commented 1 week ago

There are three versions of Triton that I have tried:

- Triton nightly 3.0.0: the error I get is above.
- Triton 2.1.0, i.e. the ROCm/triton repo/branch you just mentioned, @taylor-shift: there are errors around the launcher implementation. I think it was not implemented until Triton 2.2, which is also the version the matmulfreellm package requires.
- Triton 2.3.1, the latest stable release: I get the same error as @Wintoplay (a quick way to check which build is actually active is sketched after this list).
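
A sketch of that quick check; the backends registry below is a Triton 3.x layout, so older builds will hit the ImportError branch:

```python
# Check which Triton build is active and whether it ships the AMD backend.
import triton
print("triton", triton.__version__)
try:
    from triton.backends import backends  # backend plugin registry (Triton 3.x)
    print("backends:", sorted(backends))  # expect 'amd' on a ROCm-enabled build
except ImportError:
    print("no triton.backends module -- pre-3.0 package layout")
```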

[screenshot of the error]

Wintoplay commented 1 week ago

@radna0 I only have the Triton 3 that comes with PyTorch for ROCm. By the way, I wonder whether matmulfreellm could be implemented efficiently in JAX.
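
For what it's worth, the functional part is easy to express in JAX; the open question is efficiency, since a naive ternary matmul still lowers to a dense matmul under XLA unless you write a custom (e.g. Pallas) kernel. A rough sketch of the BitLinear-style ternary op as I read the paper (my code, not the repo's):

```python
# Functional JAX sketch of a ternary (BitNet-b1.58-style) linear layer:
# weights are scaled by their mean absolute value and rounded into {-1, 0, +1},
# so the "matmul" degenerates to sign-selected additions in principle.
import jax
import jax.numpy as jnp

def ternary_quantize(w, eps=1e-8):
    gamma = jnp.abs(w).mean() + eps          # per-tensor scale
    return jnp.clip(jnp.round(w / gamma), -1.0, 1.0), gamma

def ternary_linear(x, w):
    w_q, gamma = ternary_quantize(w)
    return (x @ w_q) * gamma                 # XLA still emits a dense matmul here

x = jax.random.normal(jax.random.PRNGKey(0), (2, 8))
w = jax.random.normal(jax.random.PRNGKey(1), (8, 4))
print(ternary_linear(x, w).shape)            # (2, 4)
```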

radna0 commented 1 week ago

I haven't used JAX, so I don't know. I ran the official rocm-triton Docker image and then ran my test cases from there.

[screenshot]

https://github.com/ridgerchu/matmulfreellm/issues/17

There are still some verified concerns from @ridgerchu: the current matmulfreellm repo can only do training, not inference, because the package still has to be updated to work with BitBLAS. I have also checked with the BitBLAS team, and they said there is a branch that works with HIP:

https://github.com/microsoft/BitBLAS/issues/55
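
For anyone following up, here is a usage sketch modeled on the BitBLAS README example; the exact kwargs are my reading of its API at the time and may differ across versions, and on AMD you would need the HIP-enabled branch referenced in that issue:

```python
# Sketch after the BitBLAS README (kwargs are assumptions, verify against the
# version you install); low-bit weights are packed once, then reused.
import bitblas
import torch

config = bitblas.MatmulConfig(
    M=1, N=2048, K=1024,
    A_dtype="float16",    # activation dtype
    W_dtype="int2",       # low-bit weight dtype (ternary-capable)
    accum_dtype="float16",
    out_dtype="float16",
    layout="nt",          # A non-transposed, W transposed
    with_bias=False,
)
matmul = bitblas.Matmul(config=config)

A = torch.rand((1, 1024), dtype=torch.float16, device="cuda")
W = torch.randint(-1, 2, (2048, 1024), dtype=torch.int8, device="cuda")
W_packed = matmul.transform_weight(W)  # pack int8 values into the int2 format
C = matmul(A, W_packed)
print(C.shape)  # (1, 2048)
```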

[screenshot]

LeiWang1999 commented 1 week ago

Thanks for all your attention to BitBLAS, dudes :). We've updated the citation in our repo:

```bibtex
@inproceedings{ladder-osdi24,
  author    = {Lei Wang and Lingxiao Ma and Shijie Cao and Quanlu Zhang and Jilong Xue and Yining Shi and Ningxin Zheng and Ziming Miao and Fan Yang and Ting Cao and Yuqing Yang and Mao Yang},
  title     = {Ladder: Enabling Efficient Low-Precision Deep Learning Computing through Hardware-aware Tensor Transformation},
  booktitle = {18th USENIX Symposium on Operating Systems Design and Implementation (OSDI 24)},
  year      = {2024},
  url       = {https://www.usenix.org/conference/osdi24/presentation/wang-lei},
}
```

ridgerchu commented 1 week ago

Hey @LeiWang1999, thanks for the info! We will update the reference when the next version of our paper is released!