Closed: amanb2000 closed this issue 3 months ago.
To get started with more sophisticated heuristic-guided search algorithms, we will first build a search algorithm with no heuristic at all: a uniform random explorer of the tree of possible next tokens. The goal is to minimize the number of parameters required to express this computation.
- `--num_iters`: defaults to -1 (i.e., do only the 1-token brute-force strategy). If `num_iters > 0`, then we do the iterative tree growth.
- `_get_prompt_ids_rand_tree(num_prompts, reachable_df, dud_set, max_tokens=10)`: Randomly select some items from `reachable_df` that have length less than `max_tokens` and append a random token. Double-check that the resulting strings are not already in `reachable_df`.
- Add a `'parent'` column to `reachable_df` with the `prompt_ids` of the parent (if any parent exists).

### `eps_e`

For non-novel `(u, y*)`, there is probability `eps_e` that we "let it through" and replace the current `(u, y*)` in the database.
- `--eps_e`: argparse param.
- `_ingest_eps()`: Function that adds the new `[prompt_ids, question_ids, answer_ids]` triplets to the DataFrame; with probability `eps_e`, it will let a non-novel triplet through.

### `eps_e` exploration

Let's let in only some of the `eps_e`-randomly-selected non-novel triplets. Specifically, we will let in only the top `--divergence_fraction` of them, where divergence is measured as the distance between `R_t` (the reachable set of `y*` tokens at time `t`) and the next-token logits `P(y | u + x_0)`.
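This divergence measure can be sketched as follows. It assumes KL divergence `KL(R_t || softmax(new_logits))` as the metric (the issue does not fix a particular one), answer ids stored as plain ints, and a `vocab_size` argument added for self-containment.

```python
import torch

def get_reachable_logits(reachable_df, vocab_size):
    """Return a [vocab_size] tensor: uniform distribution over the
    answer tokens reached so far (sketch)."""
    R_t = torch.zeros(vocab_size)
    for ans in reachable_df["answer_ids"]:
        R_t[int(ans)] = 1.0
    return R_t / R_t.sum()          # normalize to a probability distribution

def get_divergence(R_t, new_logits, eps=1e-9):
    """KL(R_t || softmax(new_logits)) for each batch row.

    R_t: [vocab_size]; new_logits: [batch, vocab_size] -> [batch]."""
    log_q = torch.log_softmax(new_logits, dim=-1)
    p = R_t.unsqueeze(0)            # [1, vocab_size], broadcasts over batch
    # Only entries where p > 0 contribute to KL(p || q).
    return torch.sum(p * (torch.log(p + eps) - log_q), dim=-1)
```

Keeping `eps` inside the log avoids `0 * -inf = nan` for unreached tokens.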
- `--eps_e`: argparse param, input to `get_reachable_set`, `forward_generate`.
- `get_reachable_logits(reachable_df)`: Returns a tensor of shape `[vocab_size]` with 1's at the token-id indices that we have reached (normalized to a probability distribution).
- `get_divergence(R_t, new_logits)`: `R_t` has shape `[vocab_size]`, `new_logits` has shape `[batch, vocab_size]`. This function computes the divergence between each batch element in `new_logits` and `R_t`.
- Update `_ingest_eps()` to accept the `divergence_cutoff` and measure this divergence for the `eps_e` instances that come through the barrier.

Let's add a function called `_mutate_gcg(prompt_ids, question_ids, R_t, **kwargs)` that accepts the usual GCG swap-generation arguments and suggests its top picks based on `prompt_ids` and the first-order approximation of the loss on `R_t` associated with making a given swap. Note that we want to maximize `R_t`!
### Batch Dimension in `forward_reachability.py`

Add `forward_generate` and `get_reachable_set` to the `magic_words` package.

- `_get_prompt_ids_brute_force()`: Return a tensor of shape `[vocab_size, 1]` of all the new prompt tokens we want to test out. In the future, simply add a new `get_prompt_ids_YOUR_NAME_HERE()` function to build up functionality.
- `_get_answer_ids()`: Given an `input_ids` tensor of shape `[batch, num_ids]` (which has a bunch of concatenated `[prompt] + [x_0]` values as rows), this function returns a tensor of shape `[batch, 1]` with the maximum-likelihood next token for each `input_ids` row. Accepts a parameter `max_parallel=300` for the sub-batch size.
- `_naive_ingest()`: Given the current DataFrame, the new `answer_ids` tensor, and its parent `question_ids` and `prompt_ids`, this function will add novel `(u, y*)` pairs that reach an element of the output space that has yet to be reached in the DataFrame. Accepts `base_logits` for base-loss computation for novel `y*` added to the DataFrame. Make sure to re-compute the logits for the "success" cases -- we don't have the memory to save all the logits.

### GCG Swaps to Suggest New Prompts

The central "magic trick" of the AutoPrompt family of prompt-optimization methods is to leverage gradient information at the embedding layer. This information is used to influence the generation of new candidate prompts. Generated prompts are compared on the objective (loss) function, and the best ones are retained for the next generation.

Usually, people supply a loss function that measures the conditional likelihood the LLM assigns to some set of future token sequences (e.g., LLM-Attacks).
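Such a conditional-likelihood loss can be sketched generically as the negative log-likelihood of a target continuation under the model's per-position next-token distributions (a generic sketch, not LLM-Attacks' actual code):

```python
import torch

def target_nll(logits, target_ids):
    """Negative log-likelihood of a target token sequence.

    logits: [seq_len, vocab] next-token logits, where logits[t] is the
    model's prediction for position t; target_ids: [seq_len] long tensor.
    Lower loss means the model finds the target continuation more likely."""
    log_probs = torch.log_softmax(logits, dim=-1)
    tok_lp = log_probs.gather(1, target_ids.unsqueeze(1)).squeeze(1)
    return -tok_lp.sum()
```

Minimizing this over candidate prompts is the standard GCG objective; the reachability setting below swaps it for a divergence-based one.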
We wish to characterize the reachable set `R` from some sequence of state tokens `x_0` (cf. Definition 3: Reachability). To do so, we must attempt to generate the set of target states `y*` for which there exists a control input `u` that carries the system from state `x_0` to output `y*`.

A foothold for finding this set is to start by enumerating control inputs `u` and examining the resulting outputs `y` from a system initialized with state `x_0`. For LLM systems, we may enumerate all possible single-token control inputs `u in [0, vocab_size)`.

Further enumeration is computationally intractable, however, as most LLMs' vocabulary size (number of possible token identities) is on the order of 10,000. So we are stuck with the question: how do we most efficiently pick control input sequences to try beyond 1-token sequences? Random guessing might get us somewhere, but we want to take advantage of whatever information we can to inform our next guesses.
You can apply AutoPrompt methods to this problem as soon as you have a differentiable objective function that maps a set of input token embeddings to a loss to minimize. We can frame "exploration from the existing reachable set" as a differentiable loss function via the divergence between two distributions: to represent the existing reachable set, we use `uniform({reachable set})`, and to represent the would-be reachable set under the new control input we can use `P(y | u + x)`.

From there, we can apply a similar iterative generate-test-cull loop as in GCG, but we randomly pick `batch_size` of the prompts currently in the `reachable_df`. We should also have some `min_growth` parameter, which is how many of the best swap suggestions we should still add to the dataframe -- i.e., we add more of the top swaps to get to this number of new tokens in the `reachable_df` every iteration.

We can use lots of the same code as in `easy_gcg.py`. It boils down to the following pseudocode:
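A sketch of that loop, reconstructed only from the steps described above (any control flow or names not already introduced are assumptions):

```
initialize reachable_df via the 1-token brute-force sweep
for t in 1 .. num_iters:
    # generate
    parents  <- sample batch_size prompts from reachable_df
    R_t      <- get_reachable_logits(reachable_df)
    swaps    <- _mutate_gcg(parents, question_ids, R_t)   # gradient-guided swaps
    randoms  <- _get_prompt_ids_rand_tree(...)            # random tree growth
    # test
    answer_ids <- _get_answer_ids(each candidate ++ x_0)
    # cull / ingest
    _naive_ingest(...)   # keep novel (u, y*) pairs
    _ingest_eps(...)     # admit eps_e non-novel triplets, top divergence_fraction by divergence
    top up with the best remaining swaps until at least min_growth new rows were added
```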