stanfordnlp / pyreft

ReFT: Representation Finetuning for Language Models
https://arxiv.org/abs/2404.03592
Apache License 2.0

[P1] Possible to do batch inference? #105

Open thistleknot opened 3 months ago

thistleknot commented 3 months ago

I'm doing this atm:

quotes_fol_ = []  # accumulated responses, one list per quote
for q_ in tqdm(rando):
    #print('quote:',q_)
    quotes_fol = []
    quotes_nodes_edges = []
    sentences = sent_tokenize(q_)
    for q in sentences:
        # tokenize and prepare the input
        prompt = prompt_no_input_template % q
        prompt = tokenizer(prompt, return_tensors="pt").to(device)

        unit_locations = torch.IntTensor([pyreft.get_intervention_locations(
            last_position=prompt["input_ids"].shape[-1], 
            first_n=first_n, 
            last_n=last_n,
            pad_mode="last",
            num_interventions=len(reft_config.representations),
            share_weights=share_weights
        )]).permute(1, 0, 2).tolist()

        # Generate with beam search
        _, reft_response = reft_model.generate(
            prompt, 
            unit_locations={"sources->base": (None, unit_locations)},
            intervene_on_prompt=True, 
            max_new_tokens=537, 
            do_sample=True,
            top_k=50,
            temperature=0.7,
            num_beams=5,  # Using beam search with 5 beams
            eos_token_id=terminators, 
            early_stopping=True
        )
        response = tokenizer.decode(reft_response[0], skip_special_tokens=True)
        #print(response)
        #out = lcel_chain.invoke({"input": response})
        #print('node/csv:',out)
        quotes_fol.append(response)

        #quotes_nodes_edges.append(out)
    quotes_fol_.append(quotes_fol)
    #quotes_nodes_edges_.append(quotes_nodes_edges)

but I'd like to avoid iterating one prompt at a time, and I'm not sure how to format unit_locations for a batch. Normally one would just call model.generate(**inputs), but since pyreft wraps the model in a custom class, I'm not sure whether that's supported (I haven't dug into the class for this specific feature).

Thought I'd ask first, and also for visibility for others who might be interested.

frankaging commented 3 months ago

@thistleknot Yes, it supports batched inference calls.

You can take a look at this function for batching: https://github.com/stanfordnlp/pyreft/blob/main/examples/loreft/compute_metrics.py#L111

In a nutshell, you need to apply left padding to your tokenizer and calculate the batched intervention locations accordingly.
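
Roughly, something like the sketch below (untested, adapting the loop from your first message rather than copying the example file verbatim; `batch_of_sentences` is a stand-in for one batch of your prompt strings). Since left padding shifts every real token to the right, you compute each prompt's locations from its unpadded length and then add that prompt's padding amount:

# Untested sketch: batched intervention locations with left padding.
# Assumes the variables from the snippet above (prompt_no_input_template,
# first_n, last_n, share_weights, reft_config, reft_model, device, and a
# tokenizer configured with padding_side="left").
prompts = [prompt_no_input_template % s for s in batch_of_sentences]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(device)
padded_len = inputs["input_ids"].shape[-1]

batched_locations = []
for i in range(len(prompts)):
    real_len = int(inputs["attention_mask"][i].sum())  # tokens excluding left padding
    offset = padded_len - real_len                     # how far left padding shifts this prompt
    locs = pyreft.get_intervention_locations(
        last_position=real_len,                        # length of the unpadded prompt
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights,
    )  # shape: [num_interventions, num_positions]
    # shift every location right by this prompt's left-padding amount
    batched_locations.append([[p + offset for p in ivn] for ivn in locs])

# [batch, num_interventions, positions] -> [num_interventions, batch, positions]
unit_locations = torch.IntTensor(batched_locations).permute(1, 0, 2).tolist()

_, reft_response = reft_model.generate(
    inputs,
    unit_locations={"sources->base": (None, unit_locations)},
    intervene_on_prompt=True,
    max_new_tokens=256,
    do_sample=True,
)
responses = tokenizer.batch_decode(reft_response, skip_special_tokens=True)

I believe the linked function does essentially the same shift, just with the per-example locations precomputed by the data collator.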

thistleknot commented 3 months ago

'calculate the batched intervention locations accordingly.'

That doesn't sound easy.

I'm not sure whether I can keep using the same last-token (-1) position per prompt as I was before... or whether it expects each location to be the token's index within the padded batch tensor.
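
For example (made-up numbers, purely to illustrate the two readings):

# made-up example: a prompt with 10 real tokens, left-padded to max_length = 115
last_loc_relative_to_prompt = 10 - 1    # = 9, what my per-sentence loop effectively uses
last_loc_relative_to_batch  = 115 - 1   # = 114, the token's index in the padded batch tensor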

thistleknot commented 3 months ago

You able to help a brother out? Here's what I have so far:


dataset = load_dataset("Abirate/english_quotes")
quotes = [q for q in dataset['train']['quote'] if (len(q) > 23 and len(q) < 140)]
#for q in quotes[0:10]:
    #print(q)

#rando = np.random.choice(quotes, 100, replace=False)
cleaned_quotes = [q.replace('“','').replace('”','') for q in quotes]

rando = random.choices(cleaned_quotes,k=100)

# Define constants
max_tokens = 115
desired_token_limit = 8192
batch_size = desired_token_limit // max_tokens

# Define the tokenizer with left-side padding
tokenizer = transformers.AutoTokenizer.from_pretrained(
    model_name_or_path, model_max_length=115,
    padding_side="left", use_fast=True,
    attn_implementation=attn_implementation
    # , add_eos_token=True, add_bos_token=True
)
tokenizer.pad_token = tokenizer.eos_token

# Position info about the interventions
share_weights = True # Whether the prefix and suffix interventions share weights.
positions = "f3+l3"  # The intervening positions of prefix tokens (f[irst]3) and suffix tokens (l[ast]3).
first_n, last_n = pyreft.parse_positions(positions)

terminators = [tokenizer.eos_token_id]

def get_intervention_locations(last_position, first_n, last_n, pad_mode, num_interventions, share_weights):
    # Placeholder function for getting intervention locations, replace with actual logic
    return [[0] * last_position for _ in range(num_interventions)]

tokenized_prompts = []
# Preprocess: Split into sentences and tokenize
for q_ in range(len(rando)):
    sentences = sent_tokenize(rando[q_])
    for s_ in sentences:
        original_prompt = prompt_no_input_template % s_
        last_position = len(tokenizer.encode(original_prompt))  # Get actual length before padding
        tokenized_prompt = tokenizer(original_prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=max_tokens)
        tokenized_prompts.append((q_, tokenized_prompt.to(device), last_position))

# Incorrect batch attempt
if True:

    # Process: Generate responses in batches
    quotes_fol_ = []

    for r in tqdm(range(0, len(tokenized_prompts), batch_size)):
        batch_prompts = tokenized_prompts[r: r + batch_size]

        input_ids = torch.cat([bp[1]['input_ids'] for bp in batch_prompts], dim=0)
        attention_masks = torch.cat([bp[1]['attention_mask'] for bp in batch_prompts], dim=0)

        unit_locations = torch.IntTensor([get_intervention_locations(
            last_position=max_tokens,
            first_n=first_n,
            last_n=last_n,
            pad_mode="last",
            num_interventions=len(reft_config.representations),
            share_weights=share_weights
        )]).repeat(input_ids.shape[0] // len(batch_prompts), 1, 1).permute(1, 0, 2).tolist()

        # Generate with beam search
        generation_args = {
            "base": {"input_ids": input_ids, "attention_mask": attention_masks},
            "unit_locations": {"sources->base": (None, unit_locations)},
            "intervene_on_prompt": True,
            "max_new_tokens": max_tokens,
            "do_sample": True,
            "top_k": 50,
            "temperature": 0.7,
            "num_beams": 5,
            "eos_token_id": terminators,
            "early_stopping": True
        }

        _, reft_response = reft_model.generate(**generation_args)

        responses = tokenizer.batch_decode(reft_response, skip_special_tokens=True)

        quotes_fol = []
        for i, response in enumerate(responses):
            quotes_fol.append(response)
            original_index = batch_prompts[i][0]
            quotes_fol_.append([original_index, quotes_fol])

    # Output the final results
    print(quotes_fol_)

That's what I've got atm, but it's not applying the control vector appropriately.
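
My best guess is that the placeholder get_intervention_locations above (which returns all zeros) plus the missing left-padding offset is the problem; if so, I think the unit_locations block would need to be something like this instead (untested, reusing the pre-padding last_position stored per prompt during preprocessing):

# Untested: per-example locations from the unpadded prompt length, shifted by the
# left-padding amount, instead of the all-zeros placeholder above.
batched_locations = []
for _, _, last_position in batch_prompts:
    offset = max_tokens - last_position            # left padding added to this prompt (assumes no truncation)
    locs = pyreft.get_intervention_locations(
        last_position=last_position,
        first_n=first_n,
        last_n=last_n,
        pad_mode="last",
        num_interventions=len(reft_config.representations),
        share_weights=share_weights,
    )
    batched_locations.append([[p + offset for p in ivn] for ivn in locs])

unit_locations = torch.IntTensor(batched_locations).permute(1, 0, 2).tolist()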