amanb2000 opened 4 months ago
Finished first version with https://github.com/amanb2000/Magic_Words/commit/2c3ee24277fa7b9ba711143c2e26e94751ceda3e
`eval_batch_size` (how many training examples we use to compute the score of each alternative prompt)
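As a rough illustration of what `eval_batch_size` could control, here is a minimal sketch of drawing the scoring subset; the helper name `sample_eval_batch` and the random-sampling strategy are assumptions, not the repo's actual code (it may iterate over the dataset deterministically instead).

```python
import random

def sample_eval_batch(dataset, eval_batch_size, seed=0):
    """Draw the subset of (x, y) pairs used to score candidate prompts.

    Hypothetical helper: random sampling is one plausible strategy for
    picking `eval_batch_size` training examples per scoring call.
    """
    rng = random.Random(seed)
    k = min(eval_batch_size, len(dataset))
    return rng.sample(dataset, k)
```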
Let's generalize the [easy_gcg]() code to optimize prompts on a dataset of `(x, y)` pairs, where each `x` is the `question` and `y` is the `answer`. We want to solve

`u := argmax_u E[P(y | u + x)]`

where the expectation is taken over the dataset, `(x, y) ~ D`.
We can start by simply aggregating gradients for the swaps in GCG over multiple elements of the batch (https://github.com/amanb2000/Magic_Words/blob/32840cd867c83fc131205e5ff639a109f4e4f78c/magic_words/easy_gcg.py#L178).
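The aggregation step can be sketched as follows. This is a toy illustration, not the repo's code: `grad_fn(prompt_ids, x, y) -> [prompt_len, vocab_size]` is a hypothetical stand-in for the per-example one-hot swap gradient that `easy_gcg.py` already computes, and averaging it over examples approximates the gradient of the dataset-level objective `E[log P(y | u + x)]`.

```python
import numpy as np

def aggregate_swap_gradients(grad_fn, prompt_ids, batch):
    """Average the GCG one-hot token-swap gradients over a batch of (x, y) pairs.

    grad_fn is a hypothetical per-example gradient callable; the real code
    would compute these gradients with a backward pass through the model.
    """
    # Use the first example's gradient only to get the output shape.
    total = np.zeros_like(grad_fn(prompt_ids, *batch[0]))
    for x, y in batch:
        total += grad_fn(prompt_ids, x, y)
    return total / len(batch)
```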
All that remains is to create an efficient `batch_compute_score_dataset()` function to compute the scores of each potential new prompt w.r.t. the dataset (https://github.com/amanb2000/Magic_Words/blob/32840cd867c83fc131205e5ff639a109f4e4f78c/magic_words/easy_gcg.py#L263).
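Functionally, the scoring step reduces to the sketch below. The callable `logprob_fn(u, x, y) -> float` is a placeholder (not the repo's API) for the model's log P(y | u + x); an efficient implementation would batch candidates and examples into as few forward passes as possible rather than loop one call at a time.

```python
import numpy as np

def batch_compute_score_dataset(candidate_prompts, dataset, logprob_fn):
    """Score each candidate prompt by its mean log P(y | u + x) over the dataset.

    logprob_fn is a hypothetical stand-in for the batched model forward pass.
    """
    scores = np.empty(len(candidate_prompts))
    for i, u in enumerate(candidate_prompts):
        scores[i] = np.mean([logprob_fn(u, x, y) for x, y in dataset])
    return scores

def best_prompt(candidate_prompts, dataset, logprob_fn):
    """Return the candidate maximizing the dataset-level score (the argmax above)."""
    scores = batch_compute_score_dataset(candidate_prompts, dataset, logprob_fn)
    return candidate_prompts[int(np.argmax(scores))]
```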