andyrdt / refusal_direction

Code and results accompanying the paper "Refusal in Language Models Is Mediated by a Single Direction".
https://arxiv.org/abs/2406.11717
Apache License 2.0
123 stars, 25 forks

Currently not working with Gemma 2 models #4

Open · DalasNoin opened this issue 4 months ago

DalasNoin commented 4 months ago

I tried to run this with Gemma 2 27B (gemma-2-27b-it) and found that it doesn't quite work. I verified that everything works with qwen/qwen-1_8b-chat.

I get this error message:

AssertionError: All scores have been filtered out

It also seems that the KL scores are very large (>10).

I tried to find the reason but have not found a solution so far.
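One plausible link between the assertion and the large KL values: direction selection filters candidate directions before ranking them. The sketch below is a minimal reconstruction under assumptions, not the repo's code. Each candidate (position, layer) is assumed to carry a refusal score after ablation, a steering score, and a KL score; the `Candidate` record, `filter_candidates`, and the defaults `kl_threshold=0.1`, `induce_refusal_threshold=0.0` and `prune_layer_percentage=0.2` are hypothetical. If the real KL threshold is anywhere near 0.1, KL scores above 10 would empty the list on their own.

```python
from dataclasses import dataclass
import math


@dataclass
class Candidate:
    # Hypothetical record for one candidate direction.
    pos: int               # token position (negative index, e.g. -1)
    layer: int             # layer the direction was extracted from
    refusal_score: float   # refusal on harmful prompts after ablation (lower = better)
    steering_score: float  # refusal on harmless prompts after adding the direction
    kl_score: float        # KL on harmless prompts after ablation


def filter_candidates(candidates, n_layers, kl_threshold=0.1,
                      induce_refusal_threshold=0.0, prune_layer_percentage=0.2):
    """Keep candidates that pass every assumed filter, best first."""
    # Hypothetical thresholds; the real pipeline's values may differ.
    max_layer = n_layers * (1.0 - prune_layer_percentage)
    kept = []
    for c in candidates:
        if math.isnan(c.refusal_score):
            continue  # degenerate run, e.g. NaN logits
        if c.layer >= max_layer:
            continue  # skip directions from the last layers
        if c.kl_score > kl_threshold:
            continue  # ablation disturbs harmless prompts too much
        if c.steering_score < induce_refusal_threshold:
            continue  # adding the direction does not induce refusal
        kept.append(c)
    assert len(kept) > 0, "All scores have been filtered out"
    # Assumed ranking: lowest refusal score after ablation first.
    return sorted(kept, key=lambda c: c.refusal_score)
```

With every `kl_score` above 10, the `kl_threshold` check alone rejects every candidate and the assertion fires.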

However, I did verify that the chat template worked correctly. When I placed a breakpoint in get_mean_activations (the function that collects the activations), I could also sample text from the model normally.

What seemed odd was that the mean_diff of activations between harmful and harmless prompts was quite large, often between -200 and +200. In comparison, Qwen's mean diff was more like -2 to 2. So is there possibly an issue with the hooks?
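One way to tell whether a ±200 mean diff points to a hook problem or just to scale: a candidate direction here is a difference in means of residual-stream activations, so its raw size grows with the overall norm of the residual stream, which can differ a lot between model families. A minimal NumPy sketch, assuming activations are already collected as arrays of shape (n_prompts, n_positions, n_layers, d_model); `mean_diff_directions` and `relative_diff_norm` are hypothetical names. Measuring the diff relative to the typical activation norm makes Gemma and Qwen numbers comparable.

```python
import numpy as np


def mean_diff_directions(harmful_acts, harmless_acts):
    """Difference-in-means candidate directions.

    Inputs: (n_prompts, n_positions, n_layers, d_model).
    Returns: (n_positions, n_layers, d_model).
    """
    return harmful_acts.mean(axis=0) - harmless_acts.mean(axis=0)


def relative_diff_norm(harmful_acts, harmless_acts):
    """Norm of the mean diff divided by the mean activation norm, per (pos, layer)."""
    diff = mean_diff_directions(harmful_acts, harmless_acts)
    all_acts = np.concatenate([harmful_acts, harmless_acts], axis=0)
    # Average L2 norm of individual activations: (n_positions, n_layers)
    act_norm = np.linalg.norm(all_acts, axis=-1).mean(axis=0)
    return np.linalg.norm(diff, axis=-1) / act_norm
```

Scaling all activations by 100 scales the raw mean diff by 100 but leaves `relative_diff_norm` unchanged.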

The current GemmaModel is designed for Gemma 1 models. It seems the only architectural change is the addition of an RMSNorm before and after the MLP. I am not familiar with the details of the Gemma2RMSNorm implementation.
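For reference, a NumPy sketch of the two pieces in question, written from the published Gemma design rather than copied from `Gemma2RMSNorm`. Gemma's RMSNorm scales by `(1 + weight)` rather than `weight`. In the Hugging Face Gemma 2 decoder layer, both the attention and the MLP sublayers are wrapped in a norm before and after ("sandwich" normalization), so each sublayer's output is normalized before it is added to the residual stream. Function names are hypothetical and the sublayer is a stand-in callable. Because the post-norm sits inside the residual branch, a hook reading the residual stream itself still sees an un-normalized sum.

```python
import numpy as np


def gemma_rms_norm(x, weight, eps=1e-6):
    """Gemma-style RMSNorm: divide by RMS, then scale by (1 + weight)."""
    x32 = x.astype(np.float32)  # computed in float32, then cast back
    normed = x32 / np.sqrt(np.mean(x32 ** 2, axis=-1, keepdims=True) + eps)
    return (normed * (1.0 + weight.astype(np.float32))).astype(x.dtype)


def sandwich_block(resid, sublayer, pre_w, post_w):
    """One Gemma 2-style residual branch: pre-norm -> sublayer -> post-norm -> add."""
    branch = sublayer(gemma_rms_norm(resid, pre_w))
    return resid + gemma_rms_norm(branch, post_w)
```

With zero weights the norm reduces to plain RMS normalization, and the post-norm undoes any rescaling the sublayer applies.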

fblissjr commented 4 months ago

@DalasNoin

I made an attempt on a fork, not very cleanly, using the Hugging Face local-gemma library. I ran out of time to apply and test anything, but maybe it will be useful to someone. I made a bunch of small edits to the pipeline code for observability (starting with tracking down where the NaNs came from and working my way up from there).

fork: https://github.com/fblissjr/refusal_direction
gemma2 model: https://github.com/fblissjr/refusal_direction/blob/main/pipeline/model_utils/gemma2_model.py

If anyone gets this working, I would love to know what did it, and whether it generalizes beyond refusals (that's the code we have, so it's where I started).

[Attached image: kl_div_scores]

revmag commented 4 months ago

Yes, I had the same issue and emailed the co-author about it. The issue lies with the filtering step: the filtering function should use the chat model's tokenizer rather than the base model's, because with the base model all the instructions get filtered out (they score rather poorly).

```python
prompts = chat_model.tokenize_instructions_fn(instructions=instructions[i:i+batch_size]).to('cuda')
logits = chat_model.model(input_ids=prompts.input_ids, attention_mask=prompts.attention_mask).logits
refusal_scores.append(refusal_score(logits, refusal_toks=chat_model.refusal_toks))
```

You can either filter the instructions first using the chat model's tokenizer, store the filtered set, and then load it separately for the base model, or just use the chat model's tokenizer for the filtering step.
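For context on what this filter computes, a sketch assuming the refusal score is the log-odds of refusal at the final position: softmax probability mass on a set of refusal token ids versus everything else. `refusal_score_from_logits` and `filter_by_refusal` are hypothetical names, and NumPy stands in for PyTorch. Harmful instructions with a non-positive score (not refused) are dropped, as are harmless instructions with a positive score (refused); whichever model produces the logits decides what survives.

```python
import numpy as np


def refusal_score_from_logits(logits, refusal_toks, eps=1e-8):
    """Log-odds of the refusal tokens at the last position.

    logits: (batch, seq_len, vocab); refusal_toks: list of token ids.
    """
    last = logits[:, -1, :].astype(np.float64)
    last = last - last.max(axis=-1, keepdims=True)  # stable softmax
    probs = np.exp(last) / np.exp(last).sum(axis=-1, keepdims=True)
    refusal_p = probs[:, refusal_toks].sum(axis=-1)
    return np.log(refusal_p + eps) - np.log(1.0 - refusal_p + eps)


def filter_by_refusal(instructions, scores, keep_refused):
    """Keep harmful prompts that are refused (keep_refused=True)
    or harmless prompts that are not (keep_refused=False)."""
    return [ins for ins, s in zip(instructions, scores) if (s > 0) == keep_refused]
```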

andyrdt commented 2 months ago

I just tried with gemma-2-2b-it, and it worked out of the box. I pushed the code and artifacts to the gemma2 branch.

The only issue I ran into was needing to specify attn_implementation="eager" to avoid NaNs (see https://github.com/huggingface/transformers/issues/32390).
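A minimal sketch of passing that option when loading a model with Hugging Face `transformers`. `attn_implementation` is a real `from_pretrained` keyword in recent `transformers` versions; `model_load_kwargs` and `load_chat_model` are hypothetical helpers, matching on the substring "gemma-2" is an assumption about how model names are spelled, and where this belongs in the repo's own model-loading code is not shown here.

```python
def model_load_kwargs(model_name_or_path):
    """Extra kwargs for from_pretrained (hypothetical helper)."""
    kwargs = {}
    if "gemma-2" in model_name_or_path.lower():
        # Eager attention avoided NaNs for Gemma 2, per the report in this issue.
        kwargs["attn_implementation"] = "eager"
    return kwargs


def load_chat_model(model_name_or_path):
    # Imported lazily so model_load_kwargs works without transformers installed.
    from transformers import AutoModelForCausalLM
    return AutoModelForCausalLM.from_pretrained(
        model_name_or_path, **model_load_kwargs(model_name_or_path)
    )
```

Keeping the override in one helper means other model families load with their default attention implementation.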

[I do not endorse the previous comment from @revmag about base/chat tokenizers - the code only works with chat models, and does not use base models anywhere.]

andyrdt commented 2 months ago

I also got gemma-2-9b-it working, although the direction selection algorithm as implemented doesn't arrive at a very good direction. I got the best results using the candidate direction from pos=-1, layer=23.
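Overriding the automatic choice amounts to indexing the candidate tensor directly. A sketch assuming candidates are stored as an array of shape (n_positions, n_layers, d_model) indexed by negative token position, as the `pos=-1` notation suggests, with 0-based layer indices; `pick_direction` is hypothetical, and both layout details are assumptions.

```python
import numpy as np


def pick_direction(candidates, pos, layer):
    """Select one candidate direction and normalize it to unit length.

    candidates: (n_positions, n_layers, d_model); row i is assumed to
    correspond to token position i - n_positions, so -1 is the last token.
    """
    n_positions = candidates.shape[0]
    if not -n_positions <= pos < 0:
        raise ValueError(f"pos must be in [-{n_positions}, -1], got {pos}")
    direction = candidates[pos, layer]  # negative pos indexes from the end
    return direction / np.linalg.norm(direction)
```

For the gemma-2-9b-it case above this would be `pick_direction(candidates, pos=-1, layer=23)`.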

For gemma-2-27b-it, I reproduced @DalasNoin's initial observation of very high KL for basically all ablations. It's odd that things are fairly smooth for 2b and 9b but break with 27b; I think this requires deeper investigation.
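For anyone investigating the 27B case, a sketch of the two quantities involved. Directional ablation removes the component along the unit direction r̂ from an activation, x' = x - (x·r̂)r̂ (the paper applies this at every layer and token position), and the KL score compares next-token distributions with and without the ablation on harmless prompts. NumPy stands in for the PyTorch hooks; `ablate_direction`, `log_softmax` and `mean_kl` are hypothetical names, and restricting the KL to the last position and averaging over the batch are assumptions of this sketch.

```python
import numpy as np


def ablate_direction(acts, direction):
    """Project `direction` out of activations of shape (..., d_model)."""
    r_hat = direction / np.linalg.norm(direction)
    return acts - (acts @ r_hat)[..., None] * r_hat


def log_softmax(logits):
    shifted = logits - logits.max(axis=-1, keepdims=True)
    return shifted - np.log(np.exp(shifted).sum(axis=-1, keepdims=True))


def mean_kl(logits_p, logits_q):
    """Mean KL(P || Q) over the batch, from last-position logits (batch, seq, vocab)."""
    log_p = log_softmax(logits_p[:, -1, :])
    log_q = log_softmax(logits_q[:, -1, :])
    return float((np.exp(log_p) * (log_p - log_q)).sum(axis=-1).mean())
```

Checking the projection in isolation (the ablated activations should have ~0 component along r̂, and ablating twice changes nothing) is one way to rule out the math before looking at the hooks themselves.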

dribnet commented 2 days ago

I found your gemma-2-2b-it results useful. Should we merge the gemma2 branch into main to make it an official part of this repo?