Open thistleknot opened 3 months ago
@thistleknot Yes, it supports batched inference calls.
You can take a look at this function for batching: https://github.com/stanfordnlp/pyreft/blob/main/examples/loreft/compute_metrics.py#L111
In a nutshell, you need to apply left padding to your tokenizer and calculate the batched intervention locations accordingly.
'calculate the batched intervention locations accordingly.'
that doesn't sound easy.
I'm not sure if I can use the same -1 position as I was before for each prompt... or if it's expecting it to be where it is within the batch tensor.
you able to help a brother out?
dataset = load_dataset("Abirate/english_quotes")
quotes = [q for q in dataset['train']['quote'] if (len(q) > 23 and len(q) < 140)]
#for q in quotes[0:10]:
#print(q)
#rando = np.random.choice(quotes, 100, replace=False)
cleaned_quotes = [q.replace('“','').replace('”','') for q in quotes]
rando = random.choices(cleaned_quotes,k=100)
# Define constants
max_tokens = 115
desired_token_limit = 8192
batch_size = desired_token_limit // max_tokens
# Define the tokenizer with left-side padding
tokenizer = transformers.AutoTokenizer.from_pretrained(
model_name_or_path, model_max_length=115,
padding_side="left", use_fast=True,
attn_implementation=attn_implementation
# , add_eos_token=True, add_bos_token=True
)
tokenizer.pad_token = tokenizer.eos_token
# Position info about the interventions
share_weights = True # Whether the prefix and suffix interventions share weights.
positions = "f3+l3" # The intervening positions of prefix tokens (f[irst]1) and suffix tokens (l[ast]1).
first_n, last_n = pyreft.parse_positions(positions)
terminators = [tokenizer.eos_token_id]
def get_intervention_locations(last_position, first_n, last_n, pad_mode, num_interventions, share_weights):
# Placeholder function for getting intervention locations, replace with actual logic
return [[0] * last_position for _ in range(num_interventions)]
tokenized_prompts = []
# Preprocess: Split into sentences and tokenize
for q_ in range(len(rando)):
sentences = sent_tokenize(rando[q_])
for s_ in sentences:
original_prompt = prompt_no_input_template % s_
last_position = len(tokenizer.encode(original_prompt)) # Get actual length before padding
tokenized_prompt = tokenizer(original_prompt, return_tensors="pt", padding="max_length", truncation=True, max_length=max_tokens)
tokenized_prompts.append((q_, tokenized_prompt.to(device), last_position))
#Incorrect batch
if(True):
#attempted batch
# Process: Generate responses in batches
quotes_fol_ = []
for r in tqdm(range(0, len(tokenized_prompts), batch_size)):
batch_prompts = tokenized_prompts[r: r + batch_size]
input_ids = torch.cat([bp[1]['input_ids'] for bp in batch_prompts], dim=0)
attention_masks = torch.cat([bp[1]['attention_mask'] for bp in batch_prompts], dim=0)
unit_locations = torch.IntTensor([get_intervention_locations(
last_position=max_tokens,
first_n=first_n,
last_n=last_n,
pad_mode="last",
num_interventions=len(reft_config.representations),
share_weights=share_weights
)]).repeat(input_ids.shape[0] // len(batch_prompts), 1, 1).permute(1, 0, 2).tolist()
# Generate with beam search
generation_args = {
"base": {"input_ids": input_ids, "attention_mask": attention_masks},
"unit_locations": {"sources->base": (None, unit_locations)},
"intervene_on_prompt": True,
"max_new_tokens": max_tokens,
"do_sample": True,
"top_k": 50,
"temperature": 0.7,
"num_beams": 5,
"eos_token_id": terminators,
"early_stopping": True
}
_, reft_response = reft_model.generate(**generation_args)
responses = tokenizer.batch_decode(reft_response, skip_special_tokens=True)
quotes_fol = []
for i, response in enumerate(responses):
quotes_fol.append(response)
original_index = batch_prompts[i][0]
quotes_fol_.append([original_index, quotes_fol])
# Output the final results
print(quotes_fol_)
that's what I got atm, but it's not applying the control vector appropriately
I'm doing this atm
but i'd like to escape the iteration, and I'm not sure how to format unit_locations. Normally one would do something like model.generate(**inputs), but this being pyreft, I'm not sure if that is supported as it's a custom class (I haven't delved into the class for this specific feature).
Thought I'd ask first as well as for visibility for others who might be interested.