BlinkDL / ChatRWKV

ChatRWKV is like ChatGPT, but powered by the RWKV (100% RNN) language model, and open source.
Apache License 2.0
9.42k stars, 697 forks

M1 Max MPS F32 / F16 (Issue #5)

Closed: okpatil4u closed this issue 1 year ago

okpatil4u commented 1 year ago

Hello @BlinkDL,

As per your recommendation, I was able to run this on MPS at half precision. It gets stuck on MPS at full precision.

On a 64 GB M1 Max, the 14B model gives pretty good results on the CPU, but it is quite slow. After a few changes to get it working on MPS, it is very fast, but the results are terrible. For example, on MPS F16 it generates:

User: +gen Here is a short story in which Jeff Bezos, Elon Musk, and Bill Gates fight in a tournament:
ChDefPSt
QTheThisSCheckReviewQBackgroundInformationThisPaulQSp1QThe1------[How{The#QTheHSamAfterQMfileCharacterGQDemEQAfterThe//QWeQ"ThisIntroductionCorListQQQ(EAtBackgroundEnAn[BrWithDirectGWomenNQTheOh1Last#OnQGlQQWilliamWhy9                        QTheQAfterMelCheckQTQ/*BlYouFieldAQThe/*PrThe 

On the other hand, with the same code on CPU F32, it generates:

User: +gen Here is a short story in which Jeff Bezos, Elon Musk, and Bill Gates fight in a tournament:

There are four kings in the chess world, and every four years a World Championship takes place. In 2018, Bezos defeated Musk in the semi-finals, while Gates took down both of them. So now it’s Bezos and Musk who are playing each other.

What am I missing?

The most significant change was in the sample_logits function, where I used the same probability computation for both MPS and CPU. The rest of the changes only switched the device from CUDA to MPS.
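For reference, here is a minimal pure-Python sketch of top-p (nucleus) sampling that always does the probability math on the CPU, which is one way to rule out the sampling step when switching devices. It is not ChatRWKV's sample_logits: the name sample_top_p, its defaults, and the assumption that the logits have already been copied into a Python list are all illustrative.

```python
import math
import random


def sample_top_p(logits, temperature=1.0, top_p=0.85, rng=random):
    """Pick a token id from raw logits with nucleus (top-p) sampling.

    Hypothetical helper. `logits` is a plain list of floats, so all of the
    probability math runs in Python floats on the CPU, no matter which
    device produced the logits.
    """
    if temperature <= 0:
        raise ValueError("temperature must be positive")
    # Numerically stable softmax: subtract the max logit before exponentiating.
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    probs = [e / total for e in exps]

    # Keep the smallest set of most likely tokens whose mass reaches top_p.
    order = sorted(range(len(probs)), key=probs.__getitem__, reverse=True)
    kept, mass = [], 0.0
    for i in order:
        kept.append(i)
        mass += probs[i]
        if mass >= top_p:
            break

    # Apply temperature as p ** (1 / T); dividing by the top probability first
    # keeps the largest weight at 1.0 so small temperatures cannot underflow to 0.
    p_top = probs[kept[0]]
    weights = [(probs[i] / p_top) ** (1.0 / temperature) for i in kept]
    return rng.choices(kept, weights=weights, k=1)[0]
```

Assuming out is the logits tensor returned by model.forward, a call would look like sample_top_p(out.float().cpu().tolist(), temperature=1.0, top_p=0.85). Copying to the CPU costs a transfer per token, but it leaves the model's forward pass as the only code running on MPS.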

BlinkDL commented 1 year ago

"Very fast" here means the inference is buggy and not actually computing the correct result. MPS is known to still be buggy.

okpatil4u commented 1 year ago

Thanks @BlinkDL. How do I debug this? Where should I start? Can you give me a few pointers?
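A reasonable place to start is to run the same prompt on CPU F32 (as a reference) and on MPS, then compare the logits that model.forward returns for the first step. If they already disagree there, the problem is in the forward pass on MPS rather than in sampling. The helper below is a hypothetical sketch: compare_logits and its 1e-2 tolerance are illustrative choices, and both inputs are assumed to be plain Python lists (for example from out.float().cpu().tolist()).

```python
import math


def compare_logits(ref, test, atol=1e-2):
    """Compare a reference logits vector (e.g. CPU F32) against another backend's.

    Hypothetical helper. Returns the max absolute difference (inf if any value
    is NaN/inf), whether the argmax token agrees, and whether everything is
    finite and within `atol`.
    """
    if len(ref) != len(test):
        raise ValueError(f"length mismatch: {len(ref)} vs {len(test)}")
    finite = all(math.isfinite(x) for x in ref) and all(math.isfinite(x) for x in test)
    max_abs = max(abs(a - b) for a, b in zip(ref, test)) if finite else math.inf
    argmax_ref = max(range(len(ref)), key=ref.__getitem__)
    argmax_test = max(range(len(test)), key=test.__getitem__)
    return {
        "max_abs_diff": max_abs,
        "argmax_match": argmax_ref == argmax_test,
        "ok": finite and max_abs <= atol,
    }
```

For example, compare_logits(out_cpu.float().cpu().tolist(), out_mps.float().cpu().tolist()), where out_cpu and out_mps are the logits from the two runs. An F16 run will likely need a looser atol than F32.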

FreeBlues commented 1 year ago

@okpatil4u Can you show how you got it working on MPS? When I try to run it on MPS (Intel x86 MacBook Pro + eGPU [RX 6800 16 GB]), it returns a huge (negative) token id, as shown below:

In chat.py:

args.strategy = 'mps fp32'

Error log:

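Bob: Can penguins fly?
Alice: Penguins cannot fly. A penguin's wings are short and flat, more like a pair of paddles for swimming. The penguin's body structure and feather density are also better suited to swimming in water than to flying.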

Bob: hi

 len(tokens), tokens[]: 8 [26845, 27, 14260, 187, 187, 2422, 547, 27]
Alice:
 token, out: -9223372036854775808 tensor([ -5.4473, -25.4831,  -6.7508,  ...,  -7.4079,  -6.1975,  -4.1842],
       device='mps:0')

 len(tokens), tokens[]: 1 [-9223372036854775808]
Traceback (most recent call last):
  File "/Users/ppt/Github/ChatRWKV/v2/chat.py", line 474, in <module>
    on_message(msg)
  File "/Users/ppt/Github/ChatRWKV/v2/chat.py", line 387, in on_message
    out = run_rnn([token], newline_adj=newline_adj)
  File "/Users/ppt/Github/ChatRWKV/v2/chat.py", line 163, in run_rnn
    out, model_state = model.forward(tokens[:CHUNK_LEN], model_state)
  File "/Users/ppt/Github/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 563, in forward
    x = w['emb.weight'][tokens if seq_mode else tokens[0]]
RuntimeError: Expected !is_symbolic() to be true, but got false.  (Could this error message be improved?  If so, please report an enhancement request to PyTorch.)
(ChatRWKV) ppt@pptdeMacBook-Pro v2 % python chat.py

ChatRWKV v2 https://github.com/BlinkDL/ChatRWKV

Chinese - mps:0 fp32 - /Users/ppt/Github/ChatRWKV/v2/prompt/default/Chinese-2.py
Loading model - ./fsx/BlinkDL/HF-MODEL/rwkv-4-pile-1b5/RWKV-4-Pile-1B5-EngChn-testNovel-done-ctx2048-20230225
Traceback (most recent call last):
  File "/Users/ppt/Github/ChatRWKV/v2/chat.py", line 133, in <module>
    model = RWKV(model=args.MODEL_NAME, strategy=args.strategy)
  File "/Users/ppt/miniconda/envs/ChatRWKV/lib/python3.10/site-packages/torch/jit/_script.py", line 292, in init_then_script
    original_init(self, *args, **kwargs)
  File "/Users/ppt/Github/ChatRWKV/v2/../rwkv_pip_package/src/rwkv/model.py", line 84, in __init__
    raise ValueError("Invalid strategy. Please read https://pypi.org/project/rwkv/")
ValueError: Invalid strategy. Please read https://pypi.org/project/rwkv/
(ChatRWKV) ppt@pptdeMacBook-Pro v2 %
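In the first traceback, the sampled token is -9223372036854775808, which is exactly -2**63 (the minimum int64 value). The sampling step on this device therefore returned an invalid id, and the failure only surfaced later at the embedding lookup w['emb.weight'][...] inside model.forward. A small guard called right after sampling and before run_rnn would fail fast with a clearer message. This is a hypothetical sketch: check_token is not part of ChatRWKV, and vocab_size is an assumption you would take from your model or tokenizer.

```python
INT64_MIN = -(2 ** 63)


def check_token(token, vocab_size):
    """Raise early if a sampled token id cannot index the embedding table.

    Hypothetical helper. A value equal to INT64_MIN suggests the sampling
    step produced garbage rather than a real token id.
    """
    token = int(token)
    if not 0 <= token < vocab_size:
        hint = " (INT64_MIN: sampling likely failed on this device)" if token == INT64_MIN else ""
        raise ValueError(f"sampled token {token} outside [0, {vocab_size}){hint}")
    return token
```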