eth-sri / language-model-arithmetic

Controlled Text Generation via Language Model Arithmetic

Beam search implementation #11

Closed Yuancheng-Xu closed 1 week ago

Yuancheng-Xu commented 1 week ago

Model arithmetic is a great codebase for combining signals from different LLMs and discriminators. Given the trend of doing search at inference time (such as MCTS, or something in this paper), I am wondering if you have any plans to implement a search algorithm, such as beam search, based on this codebase. Thanks!

JasperDekoninck commented 1 week ago

Currently we do not have any plans to implement MCTS-based algorithms or beam search. Unfortunately, adding these search-based algorithms is not trivial with this specific library, due to our use of speculative sampling and key-value caching. Essentially, implementing a search-based algorithm would require bypassing everything related to speculative sampling in our codebase, as well as a new (but not too difficult) implementation of key-value caching.

Yuancheng-Xu commented 1 week ago

Thanks for the reply! Could you elaborate a little bit on why search algorithms require a new implementation of key-value caching?

JasperDekoninck commented 1 week ago

I haven't thought it through exactly, but we essentially have to dive a bit into the PromptedLLM class to see why. Lines 290-490 there deal with the cache: storing it appropriately, swapping axes where necessary (some models store their cache transposed relative to others), loading it in when the input tokens overlap, and so on. For search-based algorithms, this code would have to handle multiple stored caches, one per beam in beam search, for example. It would also have to appropriately delete caches and offload them to the CPU to ensure GPU memory does not get overloaded. It's not really difficult, just an extra step needed simply to implement something like beam search; a rough sketch is below.
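For illustration only, here is a minimal sketch of what per-beam cache handling could look like. The class name and structure are hypothetical (nothing like it exists in the library), and it assumes Hugging Face-style `past_key_values`, i.e. a tuple of layers where each layer is a tuple of key/value tensors:

```python
import torch

class BeamCacheManager:
    """Hypothetical helper: keeps one key-value cache per beam and
    offloads inactive caches to the CPU to avoid exhausting GPU memory."""

    def __init__(self, num_beams):
        # One past_key_values structure per beam; None until the first step.
        self.caches = [None] * num_beams

    def store(self, beam_idx, past_key_values):
        # Offload to CPU immediately so only the active beam's cache
        # occupies GPU memory.
        self.caches[beam_idx] = tuple(
            tuple(t.to("cpu", non_blocking=True) for t in layer)
            for layer in past_key_values
        )

    def load(self, beam_idx, device="cuda"):
        # Move the requested beam's cache back to the GPU before the
        # next forward pass.
        cache = self.caches[beam_idx]
        if cache is None:
            return None
        return tuple(
            tuple(t.to(device, non_blocking=True) for t in layer)
            for layer in cache
        )

    def reorder(self, beam_indices):
        # After each decoding step, beams are re-ranked; duplicate or
        # drop caches so cache i matches the i-th surviving beam.
        self.caches = [self.caches[i] for i in beam_indices]
```

The axis swapping and token-overlap logic mentioned above would still have to live inside `store`/`load`, which is exactly the part of PromptedLLM that would need to become beam-aware.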

The more important issue is the speculative sampling. While the run_eager parameter in PromptedLLM and do_speculation in ModelArithmetic.generate_text should essentially disable it, I did not take beam search into account when implementing them. Several remnants of speculative sampling will therefore still be lingering in the code and would prevent this option from working for beam search.
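For reference, this is roughly how those two switches would be used together. The prompt, model name, and exact constructor arguments are placeholders following the repository README, so treat this as a sketch rather than a verified recipe:

```python
from model_arithmetic import ModelArithmetic, PromptedLLM

# run_eager and do_speculation are the two switches mentioned above;
# together they are intended to disable speculative sampling.
llm = PromptedLLM("You are a helpful assistant.", run_eager=True)
arithmetic = ModelArithmetic(llm, default_model="meta-llama/Llama-2-13b-chat-hf")
output = arithmetic.generate_text("Once upon a time", do_speculation=False)
```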

I now remember though that I did think about implementing the beam search algorithm after we submitted the paper. At the time, I found that the easiest implementation would be to drop everything related to speculative sampling from the code (which would simplify it a lot) and then start from there. Because of this, I did not go for it :)

Yuancheng-Xu commented 1 week ago

Thank you for the detailed response!! Indeed, the speculative sampling code is the main obstacle here.