amanb2000 / Magic_Words

Code for the paper "What's the Magic Word? A Control Theory of LLM Prompting"
MIT License

Forward GCG #3

Closed amanb2000 closed 3 months ago

amanb2000 commented 9 months ago

Batch Dimension in forward_reachability.py

GCG Swaps to Suggest New Prompts

The central "magic trick" behind the AutoPrompt family of prompt optimization methods is to leverage gradient information at the embedding layer to guide the generation of new candidate prompts. The candidates are compared on the objective (loss) function, and the best ones are retained for the next generation.

Usually, people supply a loss function that measures the conditional likelihood the LLM assigns to some set of future token sequences (e.g., LLM-Attacks).

We wish to characterize the reachable set R from some sequence of state tokens x_0 (cf. Definition 3: Reachability). To do so, we must attempt to generate the set of target states y* for which there exists a control input u that carries the system from state x_0 to output y*.

A foothold for finding this set is to start by enumerating control inputs u and examining the resulting outputs y from a system initialized with state x_0. For LLM systems, we may enumerate all possible single-token control inputs u in [0, vocab_size).

Further enumeration is computationally intractable, however: most LLM vocabularies contain tens of thousands of tokens (e.g., ~32,000 for Llama), so enumerating all k-token control sequences requires vocab_size^k forward passes. So we are stuck with the question: how do we most efficiently pick which control input sequences to try beyond 1-token sequences? Random guessing might get us somewhere, but we want to take advantage of whatever information we can to inform our next guesses.

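To make the blow-up concrete, a quick back-of-the-envelope count (the 32,000 figure is Llama's vocabulary size, used here purely for illustration):

```python
# Candidate control inputs of length k number vocab_size ** k.
vocab_size = 32_000  # illustrative; Llama's tokenizer has ~32k tokens
for k in (1, 2, 3):
    print(f"k={k}: {vocab_size ** k:,} candidate sequences")
# Even k=2 is already ~1e9 forward passes.
```
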
You can apply AutoPrompt methods to the problem as soon as you have a differentiable objective function that maps a set of input token embeddings to a loss to minimize. We can frame "exploration beyond the existing reachable set" as a differentiable loss via the divergence between two distributions: uniform({reachable set}) to represent the existing reachable set, and P(y | u + x_0) to represent the would-be reachable set under a new control input u.

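A minimal numpy sketch of that objective, assuming we already have the model's next-token logits for u + x_0 and the token ids currently in the reachable set (the function and argument names here are mine, not the repo's):

```python
import numpy as np

def reachability_divergence(next_token_logits, reachable_ids):
    """Cross-entropy between uniform(R_t) and P(y | u + x_0).

    next_token_logits: 1-D array of logits over the vocabulary.
    reachable_ids: token ids currently in the reachable set R_t.
    Large values mean P(y | u + x_0) puts little mass on R_t,
    i.e., the control input u is steering toward novel outputs.
    """
    shifted = next_token_logits - next_token_logits.max()  # numerical stability
    log_probs = shifted - np.log(np.exp(shifted).sum())    # log-softmax
    # CE against a uniform target over R_t reduces to a mean of -log p.
    return -log_probs[reachable_ids].mean()
```
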
From there, we can apply the same iterative generate-test-cull loop as in GCG, except we randomly pick batch_size of the prompts currently in reachable_df. We should also have some min_growth parameter: how many of the best swap suggestions we add to the dataframe regardless -- i.e., we add enough of the top swaps to guarantee this many new entries in reachable_df every iteration.

We can use lots of the same code as in easy_gcg.py. It boils down to the following pseudocode:

```
function forward_gcg(x_0, model, tokenizer, top_k,
                     num_prompt_tokens=10,
                     batch_size=768,
                     num_iters=34,
                     max_parallel=300)
   # initialize reachable_df with a random num_prompt_tokens sequence

   for iter in 1:num_iters
      select a random prompt u from reachable_df['prompt_ids']
      compute the CE loss between P(y | u + x_0) and uniform(R_t),
          storing the one-hot embedding gradients
      sample `batch_size` alternative prompt_ids via gradient-guided swaps
      compute the answer_ids for all the new prompt_ids
          # experiment with a cutoff value so we only test the top few --
          # that way, we can sample more in the previous step(?)
      add the alternative prompt_ids that yielded novel answer_ids
          w.r.t. reachable_df
   end for

   # Hopefully we've covered lots of the space :)
   # We could add a 'purgatory' pool to breed more diverse prompts later.

   return reachable_df
end function
```
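
The "add the alternatives that yielded novel answer_ids" step could look like the following pandas sketch (column names are guessed from the pseudocode, not taken from the repo's actual code):

```python
import pandas as pd

def add_novel_prompts(reachable_df, candidate_prompts, candidate_answers):
    """Append only candidates whose answer_ids are not yet in reachable_df."""
    seen = {tuple(a) for a in reachable_df["answer_ids"]}
    new_rows = []
    for prompt_ids, answer_ids in zip(candidate_prompts, candidate_answers):
        key = tuple(answer_ids)
        if key not in seen:
            seen.add(key)  # also dedupe within the candidate batch
            new_rows.append({"prompt_ids": prompt_ids, "answer_ids": answer_ids})
    if not new_rows:
        return reachable_df
    return pd.concat([reachable_df, pd.DataFrame(new_rows)], ignore_index=True)
```
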
amanb2000 commented 9 months ago

Other Forward Generation Methods to Try

Random Tree Search

Before building more sophisticated heuristic-guided search algorithms, we will start with a search algorithm that uses no heuristic at all -- a uniform random explorer of the tree of possible next tokens. The goal is to minimize the number of parameters required to express the computation here.

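A heuristic-free explorer along those lines might look like this sketch (the `sample_output` callback stands in for a model call; all names are hypothetical):

```python
import random

def random_tree_search(x_0, sample_output, vocab_size=32_000,
                       num_steps=1_000, max_depth=10, rng=None):
    """Uniform random exploration of the tree of control sequences u.

    Each step picks a previously seen control sequence uniformly at
    random, extends it with a uniformly random token (if depth allows),
    and records the resulting output y*.
    """
    rng = rng or random.Random()
    frontier = [()]   # control sequences discovered so far
    reachable = {}    # u -> y*
    for _ in range(num_steps):
        u = rng.choice(frontier)
        if len(u) < max_depth:
            u = u + (rng.randrange(vocab_size),)
            frontier.append(u)
        reachable[u] = sample_output(x_0, u)
    return reachable
```
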
Eviction Parameter eps_e

For non-novel (u, y*), there is probability eps_e that we "let it through" and replace the current (u, y*) in the database.

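As a sketch, the eviction rule is just a stochastic admission test (hypothetical helper, not repo code):

```python
import random

def admit(is_novel, eps_e=0.1, rng=random):
    """Novel (u, y*) pairs always enter the database; non-novel pairs
    are let through (replacing the existing entry) with probability eps_e."""
    return is_novel or rng.random() < eps_e
```
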
4: Add Divergence Cutoff for eps_e exploration

Let's let in only some of the eps_e randomly selected non-novel (u, y*) pairs. Specifically, we will keep only the top --divergence_fraction of them, where divergence is measured as the distance between uniform(R_t) (R_t being the reachable set of y* tokens at time t) and the next-token distribution P(y | u + x_0).

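A sketch of that cutoff, assuming we already have one divergence score per candidate (function and argument names are hypothetical):

```python
import numpy as np

def top_divergence(candidates, divergences, divergence_fraction=0.25):
    """Keep only the top `divergence_fraction` of candidates, ranked by
    divergence between uniform(R_t) and P(y | u + x_0)
    (higher = more likely to push into new territory)."""
    k = max(1, int(len(candidates) * divergence_fraction))
    order = np.argsort(divergences)[::-1]  # descending divergence
    return [candidates[i] for i in order[:k]]
```
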
amanb2000 commented 9 months ago

5: GCG Swaps to Suggest New Prompts

Let's add a function called _mutate_gcg(prompt_ids, question_ids, R_t, **kwargs) that accepts the usual GCG swap-generation arguments and suggests its top picks based on prompt_ids and the first-order approximation of the loss on R_t associated with making a given swap. Note that, unlike standard GCG, we want to maximize this loss -- i.e., grow R_t!!
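
A first-order scoring sketch for such a mutator (the shapes and the ascend-the-loss sign convention are my assumptions; standard GCG descends the loss, whereas here we want to increase divergence from R_t):

```python
import numpy as np

def suggest_swaps(one_hot_grads, prompt_ids, top_k=8):
    """Rank single-token swaps by their first-order effect on the loss.

    one_hot_grads: [num_prompt_tokens, vocab_size] gradient of the loss
        w.r.t. the one-hot token encoding of the prompt.
    prompt_ids: current token id at each prompt position.
    Returns a [num_prompt_tokens, top_k] array: for each position, the
    replacement tokens with the largest estimated loss *increase*.
    """
    pos = np.arange(len(prompt_ids))
    # Linear estimate of the loss change from swapping position i to token j:
    delta = one_hot_grads - one_hot_grads[pos, prompt_ids][:, None]
    return np.argsort(delta, axis=1)[:, ::-1][:, :top_k]
```
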