DalasNoin opened this issue 4 months ago
@DalasNoin I made an attempt on a fork here (not very cleanly), using the Hugging Face local-gemma library. I ran out of time to apply and test anything, but maybe it will be useful to someone. I made a bunch of small edits to the pipeline code for observability reasons (starting with finding where the NaNs came from and working my way up).
fork: https://github.com/fblissjr/refusal_direction
gemma2 model: https://github.com/fblissjr/refusal_direction/blob/main/pipeline/model_utils/gemma2_model.py
If anyone gets this working, I would love to know what did it, and whether it generalizes beyond refusals (refusals are the code we have, so that's where I started).
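For anyone doing similar NaN debugging, here is a minimal sketch of the kind of observability hook I mean - a forward hook that reports which module first produces NaNs (the helper name and structure are illustrative, not code from the fork):

```python
import torch

def register_nan_checks(model):
    """Attach forward hooks that report which module first produces NaNs."""
    def make_hook(name):
        def hook(module, inputs, output):
            # some modules (e.g. attention) return tuples; check the first tensor
            out = output[0] if isinstance(output, tuple) else output
            if torch.is_tensor(out) and torch.isnan(out).any():
                print(f"NaN detected in output of module: {name}")
        return hook
    for name, module in model.named_modules():
        module.register_forward_hook(make_hook(name))
```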
Yes, I had the same issue and emailed the co-author about it. The issue lies with filtering: the filtering function should use the tokenizer from the chat model, not from the base model, because the base model's tokenizer filters all the instructions out (they score rather poorly).
```python
prompts = chat_model.tokenize_instructions_fn(instructions=instructions[i:i+batch_size]).to('cuda')
logits = chat_model.model(input_ids=prompts.input_ids, attention_mask=prompts.attention_mask).logits
refusal_scores.append(refusal_score(logits, refusal_toks=chat_model.refusal_toks))
```
You can either first filter the instructions using the chat model's tokenizer, store the result, and then run the base model on it separately, or just use the chat model's tokenizer for the filtering step. A sketch of the first option follows.
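A rough sketch of the first option, reusing the snippet above (`refusal_score`, `tokenize_instructions_fn`, and `refusal_toks` are as in the snippet; the `> 0` cutoff is my assumption about the score semantics):

```python
def filter_with_chat_tokenizer(chat_model, instructions, batch_size=32):
    """Score instructions with the *chat* model's tokenizer; keep the ones that pass."""
    kept = []
    for i in range(0, len(instructions), batch_size):
        batch = instructions[i:i + batch_size]
        prompts = chat_model.tokenize_instructions_fn(instructions=batch).to('cuda')
        logits = chat_model.model(input_ids=prompts.input_ids,
                                  attention_mask=prompts.attention_mask).logits
        scores = refusal_score(logits, refusal_toks=chat_model.refusal_toks)
        # assumed cutoff: keep instructions that the chat model actually refuses
        kept.extend(inst for inst, s in zip(batch, scores) if s > 0)
    return kept  # reuse this filtered list when running the base model
```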
I just tried with `gemma-2-2b-it`, and it worked out of the box - I pushed the code and artifacts to the `gemma2` branch.
The only issue I ran into was needing to specify `attn_implementation="eager"` to avoid NaN issues (see https://github.com/huggingface/transformers/issues/32390).
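For reference, the fix is just a kwarg on the standard transformers loading call:

```python
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "google/gemma-2-2b-it"
model = AutoModelForCausalLM.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    attn_implementation="eager",  # avoids the NaNs from the default attention backend
)
tokenizer = AutoTokenizer.from_pretrained(model_id)
```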
[I do not endorse the previous comment from @revmag about base/chat tokenizers - the code only works with chat models, and does not use base models anywhere.]
I also got `gemma-2-9b-it` working, although the direction selection algorithm as implemented doesn't arrive at a very good direction - I found the best results by using the candidate direction from `pos=-1, layer=23`.
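If you want to reproduce that, a sketch of overriding the selected direction manually - assuming the pipeline saves candidate directions as a `(n_positions, n_layers, d_model)` tensor; the file path below is illustrative, not necessarily the repo's actual layout:

```python
import torch

# illustrative path to the saved candidate directions
candidate_directions = torch.load("pipeline/runs/gemma-2-9b-it/generate_directions/mean_diffs.pt")

direction = candidate_directions[-1, 23]  # pos=-1, layer=23
direction = direction / direction.norm()  # ablation assumes a unit-norm direction
```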
For `gemma-2-27b-it` I reproduced @DalasNoin's initial observation of very high KL for basically all ablations. I find it weird that things are fairly smooth for `2b` and `9b` but break with `27b` - I think this requires deeper investigation.
I found your `gemma-2-2b-it` results useful - should we merge the `gemma2` branch into main to make it an official part of this repo?
I tried to run this with Gemma 2 27b it and found that it doesn't quite work. I verified that everything works with `qwen/qwen-1_8b-chat`.
I get this error message:

```
Assertion error: All scores have been filtered out
```
It also seems that the KL scores are very large (>10).
I tried to find the reason but could not find a solution so far.
However, I did verify that the chat template worked correctly, and it also seemed I could sample text from the model normally when I placed a breakpoint in the function `get_mean_activations`, which measures the activations.
What seemed odd was that the mean_diff of activations between harmful and harmless prompts was quite large, often between -200 and +200. In comparison, the mean diff for Qwen was more like -2 to +2. So possibly there is an issue with the hooks?
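For context, a minimal sketch of the kind of mean-difference check being described (hypothetical function; the repo's `get_mean_activations` may differ):

```python
import torch

def activation_mean_diff(harmful_acts, harmless_acts):
    """Difference of mean activations between harmful and harmless prompts.

    Both inputs: (n_samples, n_positions, n_layers, d_model).
    Values around +/-200 here (vs. +/-2 for Qwen) would point to a
    scaling or hook problem rather than a real signal.
    """
    return harmful_acts.mean(dim=0) - harmless_acts.mean(dim=0)
```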
The current GemmaModel is designed for Gemma 1 models. It seems the main architectural change in Gemma 2 is the addition of an RMS norm before and after the MLP. I am not familiar with the details of the Gemma2RMSNorm implementation.
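For reference, in the transformers implementation `Gemma2DecoderLayer` adds `pre_feedforward_layernorm` and `post_feedforward_layernorm`, and applies `post_attention_layernorm` to the attention *output* rather than before the MLP. A simplified sketch of the residual structure (attention arguments, soft-capping, etc. omitted):

```python
def gemma2_layer_forward(layer, hidden_states):
    # attention block: norm -> attn -> norm -> residual add
    residual = hidden_states
    h = layer.input_layernorm(hidden_states)
    h = layer.self_attn(h)[0]              # real call takes more arguments
    h = layer.post_attention_layernorm(h)  # applied to the attention output
    hidden_states = residual + h

    # MLP block: extra norms before *and* after the MLP (new in Gemma 2)
    residual = hidden_states
    h = layer.pre_feedforward_layernorm(hidden_states)
    h = layer.mlp(h)
    h = layer.post_feedforward_layernorm(h)
    return residual + h
```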