huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0
133.75k stars 26.73k forks

CFG for RWKV models #26089

Closed KnutJaegersberg closed 11 months ago

KnutJaegersberg commented 1 year ago

System Info

Who can help?

@younesbelkada

Information

Tasks

Reproduction

I got this error in the textgen-webui using HF-transformers-converted RWKV models, i.e. RWKV/rwkv-4-1b5-pile, but I think it is a Transformers issue rather than a webui one: "TypeError: RwkvForCausalLM.forward() got an unexpected keyword argument 'past_key_values'". Perhaps it's not implemented yet? I used the options for CFG (and also contrastive search, though not in the same run). For CFG I got the error message below, while contrastive search just didn't work, without any error message.

Output generated in 0.65 seconds (0.00 tokens/s, 0 tokens, context 49, seed 499979764)
Traceback (most recent call last):
  File "/run/media/knut/HD/text-generation-webui/modules/callbacks.py", line 56, in gentask
    ret = self.mfunc(callback=_callback, *args, **self.kwargs)
  File "/run/media/knut/HD/text-generation-webui/modules/text_generation.py", line 321, in generate_with_callback
    shared.model.generate(**kwargs)
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/utils/_contextlib.py", line 115, in decorate_context
    return func(*args, **kwargs)
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 1648, in generate
    return self.sample(
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/generation/utils.py", line 2743, in sample
    next_token_scores = logits_processor(input_ids, next_token_logits)
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 97, in __call__
    scores = processor(input_ids, scores)
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 1655, in __call__
    logits = self.get_unconditional_logits(input_ids)
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/transformers/generation/logits_process.py", line 1640, in get_unconditional_logits
    out = self.model(
  File "/home/knut/miniconda3/envs/textgen/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
TypeError: RwkvForCausalLM.forward() got an unexpected keyword argument 'past_key_values'
Output generated in 0.63 seconds (0.00 tokens/s, 0 tokens, context 49, seed 1428846449)
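The failure mode in the traceback can be illustrated without loading a model. The following toy classes (hypothetical, not the real transformers code) show why the call fails: RWKV's forward takes a recurrent `state` argument rather than the `past_key_values` kwarg that the generation utilities pass to transformer models, so Python raises a TypeError for the unexpected keyword.

```python
# Toy illustration (hypothetical classes): RWKV's forward signature takes
# `state`, not `past_key_values`, so passing the cache kwarg used for
# transformer models raises TypeError.
class TransformerLM:
    def forward(self, input_ids, past_key_values=None):
        return "logits"

class RwkvLM:
    def forward(self, input_ids, state=None):  # no past_key_values kwarg
        return "logits"

TransformerLM().forward([1, 2], past_key_values=None)  # accepted
try:
    RwkvLM().forward([1, 2], past_key_values=None)
except TypeError as e:
    print(e)  # ... got an unexpected keyword argument 'past_key_values'
```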

Expected behavior

Generate tokens with CFG applied, and likewise with contrastive search.

amyeroberts commented 1 year ago

cc @ArthurZucker @gante

ArthurZucker commented 1 year ago

Hey! Could you provide a full reproducer? past_key_values should be supported, as it's required for fast generation using use_cache=True!

gante commented 1 year ago

Hey @KnutJaegersberg 👋

The root issue is that RWKV, being an RNN at its core, does not have a growing key-value cache (past_key_values) that can be sliced. Instead, it has the state of the recurrent neural net, which is updated at each generation step.

Since the implementations of CFG and contrastive search (and some other methods) rely on the ability to slice the cache to remove old data, there is no immediate solution for RWKV.
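The structural difference can be shown with a toy sketch (plain Python stand-ins, not the actual cache classes): a transformer's past_key_values is a per-token collection you can slice to "rewind", while an RNN state is a fixed-size summary with no per-token entries to remove.

```python
# A transformer cache grows by one entry per token, so rewinding is a slice.
transformer_cache = []
for token in [5, 7, 9]:
    transformer_cache.append(("key", "value", token))
rewound = transformer_cache[:-1]  # drop the last token's cache entry: easy

# An RNN state stays the same size regardless of sequence length; old tokens
# are mixed in irreversibly, so there is nothing to slice.
rnn_state = 0.0
for token in [5, 7, 9]:
    rnn_state = rnn_state * 0.5 + token
# No slice of `rnn_state` undoes the last token; the state must instead be
# recomputed from the earlier tokens.
print(len(rewound), rnn_state)
```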

You could probably implement equivalent versions of these techniques for models that keep a state (as opposed to a growing cache) by recomputing the RWKV state as needed :)
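One way to sketch such a state-based CFG variant (a hypothetical toy, with `toy_step` standing in for one RWKV forward step and made-up logits) is to keep two recurrent states, conditional and unconditional, and advance both every generation step instead of slicing a cache:

```python
import math

def toy_step(state, token):
    # Stand-in for one RWKV step: returns (new_state, logits over 3 tokens).
    new_state = (state + token) * 0.5
    logits = [math.sin(new_state + t) for t in range(3)]
    return new_state, logits

def cfg_generate(prompt, uncond_prompt, steps, scale=1.5):
    # Build both states by running each prompt through the model once.
    cond_state = uncond_state = 0.0
    for t in prompt:
        cond_state, cond_logits = toy_step(cond_state, t)
    for t in uncond_prompt:
        uncond_state, uncond_logits = toy_step(uncond_state, t)
    out = []
    for _ in range(steps):
        # CFG combination: uncond + scale * (cond - uncond), per logit.
        guided = [u + scale * (c - u) for c, u in zip(cond_logits, uncond_logits)]
        nxt = max(range(3), key=lambda i: guided[i])  # greedy pick
        out.append(nxt)
        # Advance BOTH states with the chosen token; no cache slicing needed.
        cond_state, cond_logits = toy_step(cond_state, nxt)
        uncond_state, uncond_logits = toy_step(uncond_state, nxt)
    return out

print(cfg_generate([0, 1], [0], steps=3))
```

The key design point is that the unconditional branch is carried along as a second state rather than recovered by slicing a shared cache, which matches how an RNN-like model would have to support CFG.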

github-actions[bot] commented 11 months ago

This issue has been automatically marked as stale because it has not had recent activity. If you think this still needs to be addressed please comment on this thread.

Please note that issues that do not follow the contributing guidelines are likely to be ignored.