FasterDecoding / REST

REST: Retrieval-Based Speculative Decoding, NAACL 2024
Apache License 2.0

What are the meanings of each parameter of the reader.search() function (k / choices / long)? #12

Closed. yangbohust closed this issue 2 days ago.

yangbohust commented 6 months ago

https://github.com/FasterDecoding/REST/blob/6aed6ad5beb11849adfe671e874c239461ee8b84/DraftRetriever/src/lib.rs#L209-L215

Sorry, I am not familiar with Rust. What are the meanings of each parameter of the reader.search() function? I'm very much looking forward to your reply. Thanks.

zhenyuhe00 commented 6 months ago

Hi, "py_substring" is the query (Type: python List) to search in the datastore. "k" is the maximum number of returned sequences. I set it to 5000 for each thread by default for efficiency. "choices" is the number of draft tokens. "long" is the cut length of each return sequence. If you have any further questions, please feel free to contact me.

yangbohust commented 6 months ago

Hi, I tried your datastore construction method with the qwen-7b model (the "vocab_size" of qwen-7b is 151936) and then tested it on the HumanEval dataset. I found that almost none of the draft tokens retrieved from the datastore were correct.

code:

from datasets import load_dataset
# from transformers import AutoTokenizer
from tokenization_qwen import QWenTokenizer

import draftretriever
from tqdm import tqdm
import argparse
parser = argparse.ArgumentParser()

parser.add_argument(
    "--model-path",
    type=str,
    default="/home/models/qwen-7b",
    help="The path to the weights. This can be a local folder or a Hugging Face repo ID.",
)
parser.add_argument(
    "--large-datastore",
    # argparse's type=bool treats any non-empty string as True, so a store_true flag is safer
    action="store_true",
    default=False,
    help="Whether to use a large datastore",
)
args = parser.parse_args()
print(args)

tokenizer = QWenTokenizer.from_pretrained(args.model_path)
segment = 1 # Maximum number of segments: 144
data_files = []
for i in range(segment):
    # Shard names are zero-padded to five digits, e.g. data-00000-of-00144.parquet
    data_files.append(f"data-{i:05d}-of-00144.parquet")
print("data_files:", data_files)

dataset = load_dataset('bigcode/the-stack-dedup', data_dir='data/python', split='train', data_files=data_files)

datastore_path = './datastore_stack_large.idx' if args.large_datastore else f'./datastore_stack_small_the-stack-dedup_python_0_{segment-1}.idx'

writer = draftretriever.Writer(
    index_file_path=datastore_path,
    max_chunk_len=512 * 1024 * 1024,
    vocab_size=tokenizer.vocab_size,
)

total_length = len(dataset)
print("number of samples: ", total_length)

for sample in tqdm(dataset, total=len(dataset)):
    token_list = tokenizer.encode(sample['content'])
    writer.add_entry(token_list)

writer.finalize()

Is my datastore constructed incorrectly, or can this method not be used with the qwen-7b model?

zhenyuhe00 commented 6 months ago

I assume REST should apply to any LLM. However, in this codebase we've only provided code for LLaMA-based architectures, in modeling_llama_kv.py (changes marked with [MODIFIED]). For other architectures such as Qwen, you may modify modeling_qwen.py yourself; it should only take a few lines of code.

If you have any further questions, please feel free to contact me.

yangbohust commented 6 months ago

The "vocab_size" of the qwen-7b model is 151936, which exceeds the maximum value of 65536 that can be represented by 2 bytes. I have made the following modifications to the lib.rs file. Please help me check whether the modifications are correct and whether there are any I missed something, thank you very much~

Original code: https://github.com/FasterDecoding/REST/blob/6aed6ad5beb11849adfe671e874c239461ee8b84/DraftRetriever/src/lib.rs#L107-L128

Modified code:

    fn dump_data(
        &mut self,
    ) -> PyResult<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        // Each buffer entry is now written as 4 bytes (u32) instead of 2 (u16),
        // so the recorded byte length is buffer.len() * 4.
        self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 4) as u32)?;

        for &item in &self.buffer {
            self.index_file.write_u32::<LittleEndian>(item as u32)?;
        }

        let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size);
        self.index_file.write_u32::<LittleEndian>((suffix_array.len() * 4) as u32)?;
        for suffix in suffix_array {
            self.index_file.write_i32::<LittleEndian>(suffix)?;
        }
        self.buffer.clear();

        Ok(())
    }

Original code: https://github.com/FasterDecoding/REST/blob/6aed6ad5beb11849adfe671e874c239461ee8b84/DraftRetriever/src/lib.rs#L191-L194

Modified code:

            // Read token ids back 4 bytes (u32) at a time, matching the widened write path.
            for i in (0..data_u8.len()).step_by(4) {
                let int = LittleEndian::read_u32(&data_u8[i..i+4]) as i32;
                data.push(int);
            }

yangbohust commented 6 months ago

When building the datastore for the qwen-7b model, the following warning is reported. Will it affect the construction and retrieval results of the datastore?

Token indices sequence length is longer than the specified maximum sequence length for this model (53946 > 32768). Running this sequence through the model will result in indexing errors

zhenyuhe00 commented 6 months ago

The "vocab_size" of the qwen-7b model is 151936, which exceeds the maximum value of 65536 that can be represented by 2 bytes. I have made the following modifications to the lib.rs file. Please help me check whether the modifications are correct and whether there are any I missed something, thank you very much~

Original code:

https://github.com/FasterDecoding/REST/blob/6aed6ad5beb11849adfe671e874c239461ee8b84/DraftRetriever/src/lib.rs#L107-L128

Modified code:

    fn dump_data(
        &mut self,
    ) -> PyResult<()> {
        if self.buffer.is_empty() {
            return Ok(());
        }

        self.index_file.write_u32::<LittleEndian>((self.buffer.len() * 4) as u32)?;

        for &item in &self.buffer {
            self.index_file.write_u32::<LittleEndian>(item as u32)?;
        }

        let suffix_array = construct_suffix_array(&self.buffer, self.vocab_size);
        self.index_file.write_u32::<LittleEndian>((suffix_array.len() * 4) as u32)?;
        for suffix in suffix_array {
            self.index_file.write_i32::<LittleEndian>(suffix)?;
        }
        self.buffer.clear();

        Ok(())
    }

Original code:

https://github.com/FasterDecoding/REST/blob/6aed6ad5beb11849adfe671e874c239461ee8b84/DraftRetriever/src/lib.rs#L191-L194

Modified code:

            for i in (0..data_u8.len()).step_by(4) {
                let int = LittleEndian::read_u32(&data_u8[i..i+4]) as i32;
                data.push(int);
            }

I think the modifications are reasonable.
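
One way to sanity-check the widened format (a sketch against the Python API used earlier in this thread; the file name and token ids are arbitrary, and the Reader constructor and search() return format are assumptions to verify against the binding): after rebuilding the DraftRetriever extension, write an entry containing token ids above 65535 and confirm they come back intact.

import draftretriever

# Build a tiny datastore containing token ids that do not fit in 2 bytes (> 65535).
writer = draftretriever.Writer(
    index_file_path="./u32_check.idx",
    max_chunk_len=512 * 1024 * 1024,
    vocab_size=151936,  # qwen-7b vocabulary size
)
writer.add_entry([70000, 70001, 70002, 70003, 70004])
writer.finalize()

# If the 4-byte read/write path is correct, searching for a prefix of that entry
# should retrieve the original large token ids rather than truncated values.
reader = draftretriever.Reader(index_file_path="./u32_check.idx")
print(reader.search([70000, 70001], k=10, choices=8, long=10))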

zhenyuhe00 commented 6 months ago

> When building the datastore for the qwen-7b model, the following warning is reported. Will it affect the construction and retrieval results of the datastore?
>
> Token indices sequence length is longer than the specified maximum sequence length for this model (53946 > 32768). Running this sequence through the model will result in indexing errors

I assume this message only indicates that the context length of Qwen is 32768. Tokenizing sequences longer than 32768 should not be a problem for datastore construction; alternatively, you can truncate the sequences as in https://github.com/FasterDecoding/REST/issues/11.
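
For instance, one way to keep individual entries bounded in length (a sketch based on the construction loop posted above; the cap of 32768 tokens is only an example, not a value taken from issue #11):

MAX_TOKENS_PER_ENTRY = 32 * 1024  # example cap; choose a length that suits your datastore

for sample in tqdm(dataset, total=len(dataset)):
    token_list = tokenizer.encode(sample['content'])
    # Split very long files into fixed-size chunks before adding them to the datastore,
    # so that no single entry exceeds the chosen cap.
    for start in range(0, len(token_list), MAX_TOKENS_PER_ENTRY):
        writer.add_entry(token_list[start:start + MAX_TOKENS_PER_ENTRY])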

If you have any further questions, please feel free to contact me.

yangbohust commented 6 months ago

Hello, have you looked at the LOOKAHEAD DECODING paper (Break the Sequential Dependency of LLM Inference Using Lookahead Decoding)? LOOKAHEAD DECODING generates draft tokens through multiple rounds of Jacobi iteration at fixed points, which makes full use of the history the model has already generated, while your method retrieves draft tokens from a datastore. What are the advantages and disadvantages of the REST and LOOKAHEAD DECODING approaches, and which scenarios is each suited to?

1) I compared LOOKAHEAD DECODING and REST on the HumanEval dataset and found that LOOKAHEAD DECODING achieves a better speedup, i.e., more draft tokens are guessed correctly.

2) Adapting the qwen-7b model to REST: the datastore was built from a single parquet file (data-00000-of-00144.parquet); test dataset: HumanEval.jsonl.gz.

Settings: max-new-token = 512, temperature = 0, top-p = 0, num-draft = 64, max-token-span = 16.

Running rest_test.py gives the following results:

accept_lengths_tree_average:  1.8346164873655157
accept_lengths_tree_average_micro:  1.6333190578158459
avg_time_per_token:  0.021387115
avg_time_per_token_micro:  0.02107867338941321
******************************

Does accept_lengths_tree_average_micro mean that, on average, ~1.633 tokens are generated per decoding round?

There are a total of 45271 decoding steps:

- accept_length = 0: 29064 times (64.2%)
- accept_length = 1: 10775 times (23.8%)
- accept_length = 2: 2855 times (6.3%)
- accept_length = 3: 859 times (1.897%)
- accept_length = 4: 626 times (1.3827%)
- accept_length = 5: 372 times
- accept_length = 6: 167 times
- accept_length = 7: 185 times
- accept_length = 8: 62 times
- accept_length = 9: 98 times
- accept_length = 10: 208 times

This means that 64.2% of the decoding steps did not correctly guess the draft token.

Thank you very much~

zhenyuhe00 commented 6 months ago

> 1. What are the advantages and disadvantages of the REST and LOOKAHEAD DECODING approaches, and which scenarios is each suited to?

REST requires an additional datastore for retrieval, whereas LOOKAHEAD DECODING does not. In scenarios where disk storage is limited, you may consider using LOOKAHEAD DECODING. Conversely, when disk storage is ample, REST may be a suitable option.

> 2. I compared LOOKAHEAD DECODING and REST on the HumanEval dataset and found that LOOKAHEAD DECODING achieves a better speedup, i.e., more draft tokens are guessed correctly.

You may consider scaling up the size of the datastore to achieve faster speedup with REST.

> 3. Does accept_lengths_tree_average_micro mean that, on average, ~1.633 tokens are generated per decoding round?

Yes. Each decoding step yields accept_length + 1 tokens (the accepted draft tokens plus one token from the base model), so from your statistics: 1*0.642 + 2*0.238 + 3*0.063 + 4*0.01897 + 5*0.013827 + 6*372/45271 + 7*167/45271 + 8*185/45271 + 9*62/45271 + 10*98/45271 + 11*208/45271 ≈ 1.644, close to the reported 1.633.
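
The same number can be recomputed directly from the per-length counts reported above, for example:

# Recompute the micro-averaged tokens per decoding step from the reported counts.
counts = {0: 29064, 1: 10775, 2: 2855, 3: 859, 4: 626,
          5: 372, 6: 167, 7: 185, 8: 62, 9: 98, 10: 208}

total_steps = sum(counts.values())                                    # 45271
total_tokens = sum((length + 1) * n for length, n in counts.items())  # 74452
print(total_tokens / total_steps)  # ~1.6445, close to the reported accept_lengths_tree_average_micro of 1.633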