Sumandora / remove-refusals-with-transformers

Implements harmful/harmless refusal removal using pure HF Transformers
Apache License 2.0

Modified to save in `safetensors` format now #1

Open jukofyork opened 5 months ago

jukofyork commented 5 months ago

Just want to say thanks for this! I've been trying to use other people's code that all uses the transformer_lens library, which has a bug that stops you loading models in 4-bit, seems to have loads of problems with mixed 'cpu' and 'cuda' tensors, and is generally really slow for some reason.

I've modified your code to:

I've only tested it on Mistral-7B-Instruct-v0.2 and miqu-1-70b-sf, which both use llama tensor names, but I can confirm it is working.

import torch
import gc
import random

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
from tqdm import tqdm

#MODEL_ID = "Mistral-7B-Instruct-v0.2"
MODEL_ID = "miqu-1-70b-sf"

# More samples can help find the direction better.
NUM_PROMPT_SAMPLES = 32

# Used to skip the first and last layers for the modifications.
SKIP_BEGIN_LAYERS = 1  # Don't mess with the first layer.
SKIP_END_LAYERS = 0

# The layer we will use for the refusal_dir calculation will be floor(LAYER_FRACTION_TO_USE * model.layers).
LAYER_FRACTION_TO_USE = 0.6

# Use a negative scale_factor to "induce" and a positive scale_factor of < 1 to "ablate" less.
SCALE_FACTOR = 1.0

torch.inference_mode()
torch.set_default_device("cpu")
torch.set_grad_enabled(False)

# Load the model on the GPU in quantized type if we can.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.float16,
    quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
    low_cpu_mem_usage=True,
    device_map='auto'
)
model.requires_grad_(False)

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID, trust_remote_code=True)

layer_idx = int(len(model.model.layers) * LAYER_FRACTION_TO_USE)
print("Layer index for refusal direction: " + str(layer_idx))

with open("harmful.txt", "r") as f:
    harmful = f.readlines()

with open("harmless.txt", "r") as f:
    harmless = f.readlines()

harmful_instructions = random.sample(harmful, min(NUM_PROMPT_SAMPLES, len(harmful)))
harmless_instructions = random.sample(harmless, min(NUM_PROMPT_SAMPLES, len(harmless)))

harmful_toks = [
    tokenizer.apply_chat_template(conversation=[{"role": "user", "content": insn}], add_generation_prompt=True,
                                  return_tensors="pt") for insn in harmful_instructions]
harmless_toks = [
    tokenizer.apply_chat_template(conversation=[{"role": "user", "content": insn}], add_generation_prompt=True,
                                  return_tensors="pt") for insn in harmless_instructions]

bar_generate = tqdm(total = len(harmful_instructions) + len(harmless_instructions), desc = "Generating samples")

# Only return the final hidden state of the layer we care about, and use 'cpu' to save VRAM.
def generate(toks):
    output = model.generate(
        toks.to(model.device),
        use_cache=False,
        max_new_tokens=1,
        return_dict_in_generate=True,
        output_hidden_states=True,
        pad_token_id=tokenizer.eos_token_id
    )
    bar_generate.update(n=1)
    return output.hidden_states[0][layer_idx][:, -1, :].to('cpu') # Final hidden state = -1.

harmful_hidden = [generate(toks) for toks in harmful_toks]
harmless_hidden = [generate(toks) for toks in harmless_toks]

bar_generate.close()

harmful_mean = torch.stack(harmful_hidden).mean(dim=0)
harmless_mean = torch.stack(harmless_hidden).mean(dim=0)

refusal_dir = harmful_mean - harmless_mean
refusal_dir = refusal_dir.squeeze() / refusal_dir.norm()

torch.save(refusal_dir, MODEL_ID.replace("/", "_") + "_refusal_dir.pt")

# Free memory
del model
gc.collect()
torch.cuda.empty_cache()

# Reload the model in CPU memory with bfloat16 data type
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    device_map='cpu'
)
model.requires_grad_(False)

# Make sure it's on the 'cpu' device.
if refusal_dir.device != model.device:
    refusal_dir = refusal_dir.to(model.device)

# Get the language model component and check it's as expected.
lm_model = model.model
assert hasattr(lm_model, 'layers'), "The model does not have the expected structure."

# Check the ranges are valid.
num_layers = len(lm_model.layers)
assert SKIP_BEGIN_LAYERS >= 0, "SKIP_BEGIN_LAYERS must be >= 0."
assert SKIP_END_LAYERS >= 0, "SKIP_END_LAYERS must be >= 0."
assert SKIP_BEGIN_LAYERS + SKIP_END_LAYERS < num_layers, "SKIP_BEGIN_LAYERS + SKIP_END_LAYERS must be < num_layers."

bar_layers = tqdm(total= (num_layers - (SKIP_BEGIN_LAYERS + SKIP_END_LAYERS)) * 2, desc = "Modifying tensors")

# Cast any ops performed on the CPU up to float32... If you have a newer CPU you might be able to use bfloat16 for this.
# NOTE: Use a negative scale_factor to "induce" and a positive scale_factor of < 1 to "ablate" less.
def modify_tensor(tensor_data, refusal_dir, scale_factor: float = 1.0):
    assert scale_factor <= 1.0, "Using a scale_factor of > 1 doesn't make sense..."
    tensor_float32 = tensor_data.to(torch.float32)
    refusal_dir_float32 = refusal_dir.to(torch.float32)
    tensor_float32 -= scale_factor * torch.matmul(torch.outer(refusal_dir_float32, refusal_dir_float32), tensor_float32)
    tensor_modified = tensor_float32.to(torch.bfloat16)
    bar_layers.update(1)
    return torch.nn.Parameter(tensor_modified)

# Modify the 'self_attn.o_proj.weight' and 'mlp.down_proj.weight' in each chosen layer.
# NOTE: These tensor names are specific to "llama" and may need changing.
#       - See here for others: https://github.com/arcee-ai/mergekit/tree/main/mergekit/_data/architectures
for layer_idx in range(SKIP_BEGIN_LAYERS, num_layers - SKIP_END_LAYERS):
    lm_model.layers[layer_idx].self_attn.o_proj.weight = modify_tensor(
        lm_model.layers[layer_idx].self_attn.o_proj.weight.data, refusal_dir, SCALE_FACTOR
    )
    lm_model.layers[layer_idx].mlp.down_proj.weight = modify_tensor(
        lm_model.layers[layer_idx].mlp.down_proj.weight.data, refusal_dir, SCALE_FACTOR
    )

bar_layers.close()

# Save the modified model and original tokenizer
print("Saving modified model (with original tokenizer)...")
#model.save_pretrained("Mistral-7B-Instruct-v0.2-fixed")
#tokenizer.save_pretrained("Mistral-7B-Instruct-v0.2-fixed")
model.save_pretrained("miqu-1-70b-sf-fixed")
tokenizer.save_pretrained("miqu-1-70b-sf-fixed")

Beware that I have squeezed refusal_dir back to a vector from a (1, d_hidden) tensor, so you might need to change your inference.py code to match:

refusal_dir = refusal_dir.squeeze() / refusal_dir.norm()

torch.save(refusal_dir, MODEL_ID.replace("/", "_") + "_refusal_dir.pt")
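For reference, a rough and untested sketch (not the repo's actual inference.py) of how the saved 1-D vector can be applied - the only point is that the projection is now a plain dot product rather than a (1, d_hidden) matmul:

import torch

# Hypothetical usage sketch only: refusal_dir is a plain (d_hidden,) vector now.
refusal_dir = torch.load("miqu-1-70b-sf_refusal_dir.pt")

def ablate(hidden_state):
    # hidden_state: (..., d_hidden); subtract its component along refusal_dir.
    direction = refusal_dir.to(hidden_state.device, hidden_state.dtype)
    projection = (hidden_state @ direction).unsqueeze(-1) * direction
    return hidden_state - projection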

If you want, I can tidy the code up and submit a proper pull request; otherwise feel free to copy in whatever bits you find useful. My runtime for miqu has gone from several hours (and around 500-600GB of RAM needed!) to a few minutes thanks to the 4-bit loading working, so huge thanks again!


I'm actually trying to use this method to remove some of the "positivity" from creative-writing models rather than remove refusals, so I will likely be making a lot more modifications now that I have the bare-bones code working... I suspect my reason for failure so far is the use of max_new_tokens=1: this probably works well for refusals, since the first word being "Sorry" or "Sure" is quite telling (the Mopey-Mule model also seems to always start its reply with "*sigh*"), but for creative writing the first word has little bearing on whether the text is going to be "positive" or "dark", etc.


EDIT: I just changed the code a bit more so generate() only returns the hidden state of the layer we care about (and on the 'cpu' device); otherwise it was saving all layers for all the samples, which caused a CUDA OOM error after using up all available VRAM once NUM_PROMPT_SAMPLES was increased to much higher values.

jukofyork commented 5 months ago

These are the timings for miqu-1-70b using 512 harmless and 512 harmful samples if anyone is interested:

Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:32<00:00,  1.10s/it]
Layer index for refusal direction: 48
Generating samples: 100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1024/1024 [09:29<00:00,  1.80it/s]
Loading checkpoint shards: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 29/29 [00:16<00:00,  1.77it/s]
Modifying tensors: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 158/158 [03:00<00:00,  1.14s/it]
Saving modified model (with original tokenizer)...

It is definitely working and now happily answering questions about bank robbing, making meth, and so on...

I'm not 100% sure if the self_attn.o_proj.weight tensors really need to be modified, as it seems to work nearly as well modifying the mlp.down_proj.weight tensors alone...
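If anyone wants to test that with the script above, a minimal (untested) tweak is to guard the attention modification behind a flag, something like this (MODIFY_ATTN is just a made-up name for the sketch):

MODIFY_ATTN = False  # made-up flag: set True to also modify self_attn.o_proj.weight

for layer_idx in range(SKIP_BEGIN_LAYERS, num_layers - SKIP_END_LAYERS):
    if MODIFY_ATTN:
        lm_model.layers[layer_idx].self_attn.o_proj.weight = modify_tensor(
            lm_model.layers[layer_idx].self_attn.o_proj.weight.data, refusal_dir, SCALE_FACTOR
        )
    lm_model.layers[layer_idx].mlp.down_proj.weight = modify_tensor(
        lm_model.layers[layer_idx].mlp.down_proj.weight.data, refusal_dir, SCALE_FACTOR
    )

(The bar_layers total in the script would also need halving when only one tensor per layer is modified.)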

I have dual A6000s (= 96GB of VRAM in total), but this could have been done with 48GB of VRAM too. With 96GB I should (in theory) be able to run this on larger models like Mixtral-8x22b-Instruct, but sadly I have a terrible upload speed so can't really upload the result... Somebody else can try that (you'll need to look up the correct names for the MoE MLP tensors, as IIRC they're not the same).
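For anyone who does try a Mixtral-style MoE model, here is a rough, untested sketch of what the modification loop might look like - the per-expert names (block_sparse_moe.experts[j].w1/w2/w3, with w2 being the down projection) are my understanding of the HF Mixtral implementation, so double-check them against the mergekit architecture files linked above first:

# Untested sketch for Mixtral-style MoE layers: modify each expert's down
# projection (w2) instead of mlp.down_proj; verify the tensor names yourself.
for layer_idx in range(SKIP_BEGIN_LAYERS, num_layers - SKIP_END_LAYERS):
    for expert in lm_model.layers[layer_idx].block_sparse_moe.experts:
        expert.w2.weight = modify_tensor(expert.w2.weight.data, refusal_dir, SCALE_FACTOR)

(The tqdm total would also need scaling by the number of experts per layer.)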

Sumandora commented 5 months ago

Thanks, I like the changes you made, but I don't think they fit into this repository. This is meant to be a crude proof of concept and I don't want to implement even more model-specific stuff. Maybe adding something along the lines of an export_llama.py script is a good idea, but for now I will just leave the issue open for people to use your modified version if they prefer it. Maybe the repository you mentioned here is better suited to reuse my transformers implementation. It is MIT licensed after all. Thanks for the contribution regardless!

jukofyork commented 5 months ago

> Thanks, I like the changes you made, but I don't think they fit into this repository. This is meant to be a crude proof of concept and I don't want to implement even more model-specific stuff. Maybe adding something along the lines of an export_llama.py script is a good idea, but for now I will just leave the issue open for people to use your modified version if they prefer it. Maybe the repository you mentioned here is better suited to reuse my transformers implementation. It is MIT licensed after all. Thanks for the contribution regardless!

No problem and I agree it is nice to have a simple "proof of concept" code to work from like this.

If I get anywhere with reducing creative writing "positivity" then I'll post an update, but just happy that I can iterate though ideas so much easier now - thanks again!

jukofyork commented 5 months ago

@Sumandora

I've already solved the creative writing "positivity" problem thanks to this code - I spent nearly a week on this before and got nowhere...

If you send your PayPal address to my username here, with yahoo and com on the end, I would like to send you a donation for making this code available. Please add a randomized string to the email and I will paste it back here to confirm it really is you first, though! :)

Sumandora commented 5 months ago

> If you send your PayPal address to my username here, with yahoo and com on the end, I would like to send you a donation for making this code available. Please add a randomized string to the email and I will paste it back here to confirm it really is you first, though! :)

Thanks, that is very kind, but I don't take donations for my open-source projects (at least not for now). Thanks for your support anyway!

jukofyork commented 5 months ago

> If you send your PayPal address to my username here, with yahoo and com on the end, I would like to send you a donation for making this code available. Please add a randomized string to the email and I will paste it back here to confirm it really is you first, though! :)

> Thanks, that is very kind, but I don't take donations for my open-source projects (at least not for now). Thanks for your support anyway!

No problem :)

I'll post the code here then explain the maths in that other thread:

import torch
import gc
import random
import argparse

from transformers import AutoTokenizer, AutoModelForCausalLM, TextStreamer, BitsAndBytesConfig
from tqdm import tqdm

def main(model_id, output_path):

    # Set to zero to just sample 1 token per prompt, otherwise sample from 1 + Min[Floor[ExponentialDistribution[1 / MEAN_EXTRA]], MAX_EXTRA].
    MEAN_EXTRA_TOKENS_TO_GENERATE = 64
    MAX_EXTRA_TOKENS_TO_GENERATE = 4 * MEAN_EXTRA_TOKENS_TO_GENERATE

    # More samples can help find the direction better.
    NUM_PROMPT_SAMPLES = 1024

    # Used to skip the first and last layers for the modifications.
    SKIP_BEGIN_LAYERS = 0
    SKIP_END_LAYERS = 0

    # Used to skip modifying the attention or MLP tensors.
    SKIP_ATTN = True
    SKIP_MLP = False

    torch.inference_mode()
    torch.set_default_device("cpu")
    torch.set_grad_enabled(False)

    # Load the model on the GPU in quantized type if we can.
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.float16,
        quantization_config=BitsAndBytesConfig(load_in_4bit=True, bnb_4bit_compute_dtype=torch.float16),
        low_cpu_mem_usage=True,
        device_map='auto',
        attn_implementation="flash_attention_2",
    )
    model.requires_grad_(False)

    tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)

    # Number of layers
    num_layers = len(model.model.layers)

    baseline_system_prompt = "When tasked with anything other than writing a story you should act and reply as normal, but your stories should be nuanced, realistic tales that feature complex, relatable characters who face challenges and must navigate the gray areas of life, experiencing a mix of successes and failures. The endings are thought-provoking and open-ended, showcasing the characters' growth, change, and the ambiguity of the human experience. Infuse your narratives with themes of self-discovery, the complexities of relationships, the pursuit of dreams, and the everyday struggles and joys of life."
    desired_system_prompt = "When tasked with anything other than writing a story you should act and reply as normal, but your stories should be unsettling, bleak tales that feature flawed, frail characters who face challenges and often succumb to their darkest nature, ultimately broken by their struggles. The endings are invariably dark and hopeless, showcasing negative character transformations, inescapable fate, and the defeat of the human spirit. Infuse your narratives with themes of despair, betrayal, toxic obsession, moral decay, and the sinister undercurrents of everyday life."
    undesired_system_prompt = "When tasked with anything other than writing a story you should act and reply as normal, but your stories should be uplifting, heartwarming tales that feature kind, resilient characters who face challenges but always prevail through their virtuous nature, courage, and the power of human connection. The endings are invariably happy and hopeful, showcasing positive character growth, second chances, and the triumph of the human spirit. Infuse your narratives with themes of joy, empathy, unconditional love, chasing dreams, and the beauty of everyday magic."

    with open("prompts.txt", "r") as f:
        prompts = f.readlines()

    prompts = random.sample(prompts, min(NUM_PROMPT_SAMPLES, len(prompts)))

    baseline_token_lists = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": baseline_system_prompt + " " + prompt}],
            add_generation_prompt=True,
            return_tensors="pt") for prompt in prompts
    ]
    desired_token_lists = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": desired_system_prompt + " " + prompt}],
            add_generation_prompt=True,
            return_tensors="pt") for prompt in prompts
    ]
    undesired_token_lists = [
        tokenizer.apply_chat_template(
            conversation=[{"role": "user", "content": undesired_system_prompt + " " + prompt}],
            add_generation_prompt=True,
            return_tensors="pt") for prompt in prompts
    ]

    bar_generate = tqdm(total = 3 * len(prompts), desc = "Generating samples")

    def generate(tokens, max_new_tokens):
        output = model.generate(
            tokens.to(model.device),
            use_cache=(max_new_tokens > 1),
            max_new_tokens=max_new_tokens,
            return_dict_in_generate=True,
            output_hidden_states=True,
            pad_token_id=tokenizer.pad_token_id if tokenizer.pad_token_id is not None else tokenizer.eos_token_id
        )

        """
        for generated_token_index, hidden_state in enumerate(output.hidden_states):
            for i, decoder_element in enumerate(hidden_state):
                print(f"Generated token index: {generated_token_index}, decoder element {i} shape: {decoder_element.shape}")
        """

        # NOTE: `hidden_state[:, -1, :]` gets the last hidden state for the batch of tokens generated (ie: batch = 1 for our case, but 1st prompt eval will make [1] dim > 1).
        # NOTE: `hidden_states[-1]` gets the last hidden state of the last token generated at index of [max_new_tokens-1] (ie: [0] if max_new_tokens=1).
        # NOTE: `hidden_states[-1][1:]` gets only the hidden states *AFTER* an attention/MLP block. The [0] hidden state is *BEFORE* the first attention/MLP block...
        hidden_states_by_layer = [hidden_state[:, -1, :].squeeze().to('cpu') for hidden_state in output.hidden_states[-1][1:]]
        bar_generate.update(n=1)
        return hidden_states_by_layer

    baseline_hidden = []
    desired_hidden = []
    undesired_hidden = []

    for baseline_tokens, desired_tokens, undesired_tokens in zip(baseline_token_lists, desired_token_lists, undesired_token_lists):
        max_new_tokens = 1
        if MEAN_EXTRA_TOKENS_TO_GENERATE > 0:
            max_new_tokens += min(int(random.expovariate(1.0/MEAN_EXTRA_TOKENS_TO_GENERATE)), MAX_EXTRA_TOKENS_TO_GENERATE)
        baseline_hidden.append(generate(baseline_tokens, max_new_tokens))
        desired_hidden.append(generate(desired_tokens, max_new_tokens))
        undesired_hidden.append(generate(undesired_tokens, max_new_tokens))

    # Transpose the lists to access by layer
    baseline_hidden = list(zip(*baseline_hidden))
    desired_hidden = list(zip(*desired_hidden))
    undesired_hidden = list(zip(*undesired_hidden))

    bar_generate.close()

    householder_vectors = []

    # Compute the Householder vectors.
    for layer_index in range(num_layers):
        baseline_mean = torch.stack(baseline_hidden[layer_index]).mean(dim=0)
        desired_mean = torch.stack(desired_hidden[layer_index]).mean(dim=0)
        undesired_mean = torch.stack(undesired_hidden[layer_index]).mean(dim=0)
        desired_direction = desired_mean - baseline_mean
        undesired_direction = undesired_mean - baseline_mean
        difference_vector = undesired_direction - desired_direction
        householder_vector = difference_vector / difference_vector.norm()

        print(f"Layer {layer_index + 1}/{num_layers}:")
        direction_similarity = torch.nn.functional.cosine_similarity(desired_direction, undesired_direction, dim=0)
        print(f"- Cosine similarity between desired_direction and undesired_direction: {direction_similarity}")
        if layer_index > 0:
            householder_similarity = torch.nn.functional.cosine_similarity(householder_vector, householder_vectors[-1], dim=0)
            print(f"- Cosine similarity between current householder_vector and previous householder_vector: {householder_similarity}")
        print()

        householder_vectors.append(householder_vector)

    # Free memory
    del model
    gc.collect()
    torch.cuda.empty_cache()

    # Reload the model in CPU memory with bfloat16 data type
    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map='cpu'
    )
    model.requires_grad_(False)

    # Get the language model component and check it's as expected.
    lm_model = model.model
    assert hasattr(lm_model, 'layers'), "The model does not have the expected structure."

    # Check the ranges are valid.
    assert SKIP_BEGIN_LAYERS >= 0, "SKIP_BEGIN_LAYERS must be >= 0."
    assert SKIP_END_LAYERS >= 0, "SKIP_END_LAYERS must be >= 0."
    assert SKIP_BEGIN_LAYERS + SKIP_END_LAYERS < num_layers, "SKIP_BEGIN_LAYERS + SKIP_END_LAYERS must be < num_layers."

    # Total = number of modified layers times the number of tensor types actually being modified.
    bar_tensors = tqdm(total= (num_layers - (SKIP_BEGIN_LAYERS + SKIP_END_LAYERS)) * ((not SKIP_ATTN) + (not SKIP_MLP)), desc = "Modifying tensors")

    # By performing a (left-only) Householder transformation we reflect the matrix in the row space (ie: the linear weighted sums / "units").
    # NOTE: Down cast back to bfloat16 to save out in the same format as the un-modified tensors.
    def modify_tensor(weight_matrix, householder_matrix):
        weight_matrix = torch.matmul(householder_matrix, weight_matrix).to(torch.bfloat16)
        bar_tensors.update(1)
        return torch.nn.Parameter(weight_matrix)

    # Modify the 'self_attn.o_proj.weight' and 'mlp.down_proj.weight' in each chosen layer.
    # NOTE: These tensor names are specific to "llama" and may need changing.
    #       - See here for others: https://github.com/arcee-ai/mergekit/tree/main/mergekit/_data/architectures
    for layer_index in range(SKIP_BEGIN_LAYERS, num_layers - SKIP_END_LAYERS):

        # Ensure the householder vector is on the correct device and in float32 precision
        householder_vector = householder_vectors[layer_index].to(torch.float32)
        if householder_vector.device != model.device:
            householder_vector = householder_vector.to(model.device)

        # Calculate the Householder matrix for this layer in float32 precision
        identity_matrix = torch.eye(householder_vector.size(0), dtype=torch.float32)
        outer_product_matrix = torch.outer(householder_vector, householder_vector)
        householder_matrix = identity_matrix - 2 * outer_product_matrix

        # Modify this layer's attention projection and/or MLP projection matrices
        if not SKIP_ATTN:
            lm_model.layers[layer_index].self_attn.o_proj.weight = modify_tensor(
                lm_model.layers[layer_index].self_attn.o_proj.weight.data.to(torch.float32), householder_matrix
            )
        if not SKIP_MLP:
            lm_model.layers[layer_index].mlp.down_proj.weight = modify_tensor(
                lm_model.layers[layer_index].mlp.down_proj.weight.data.to(torch.float32), householder_matrix
            )

    bar_tensors.close()

    # Save the modified model and original tokenizer
    print("Saving modified model (with original tokenizer)...")
    model.save_pretrained(output_path)
    tokenizer.save_pretrained(output_path)

if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Modify and save a model based on baseline, desired and undesired instructions.")
    parser.add_argument("--model_id", type=str, required=True, help="The model ID to load the pretrained model from.")
    parser.add_argument("--output_path", type=str, required=True, help="The path to save the modified model and tokenizer.")

    args = parser.parse_args()
    main(args.model_id, args.output_path)
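As a quick standalone sanity check on the Householder reflection used in modify_tensor() above (a throwaway snippet, not part of the script): H = I - 2 * outer(v, v) should be orthogonal and should map v to -v.

import torch

v = torch.randn(8)
v = v / v.norm()
H = torch.eye(8) - 2 * torch.outer(v, v)

# Orthogonal: reflecting twice gets you back to where you started.
assert torch.allclose(H @ H.T, torch.eye(8), atol=1e-6)
# Reflects the chosen direction: the component along v is negated.
assert torch.allclose(H @ v, -v, atol=1e-6)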

IMPORTANT: The file prompts.txt must contain a mix of normal instruction prompts and story prompts. I used these:

and concatenated them as they were similar in size.
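Assembling prompts.txt is just a straight concatenation, one prompt per line; the file names below are placeholders rather than the actual lists:

# Placeholder file names - substitute whichever instruction/story prompt
# lists you are using; prompts.txt just needs one prompt per line.
with open("instruction_prompts.txt") as f1, open("story_prompts.txt") as f2:
    lines = f1.readlines() + f2.readlines()

with open("prompts.txt", "w") as out:
    out.writelines(lines)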

You also need to be quite careful with the 3 different prompts: they should be "distributionally similar" (in length, wording, etc.) and should try to elicit only the change in behaviour you want to detect.
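A quick way to sanity-check the "distributionally similar" part is just to compare the token counts of the three system prompts (a trivial ad-hoc check, not something the script does):

for name, prompt in [("baseline", baseline_system_prompt),
                     ("desired", desired_system_prompt),
                     ("undesired", undesired_system_prompt)]:
    print(f"{name}: {len(tokenizer(prompt).input_ids)} tokens")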


To tell if it is working (or doing anything), look at the cosine similarity outputs. This is for the miqu-1-70b model:

Layer 1/80:
- Cosine similarity between desired_direction and undesired_direction: 0.74560546875

Layer 2/80:
- Cosine similarity between desired_direction and undesired_direction: 0.734375
- Cosine similarity between current householder_vector and previous householder_vector: 0.59033203125

Layer 3/80:
- Cosine similarity between desired_direction and undesired_direction: 0.78759765625
- Cosine similarity between current householder_vector and previous householder_vector: 0.587890625

Layer 4/80:
- Cosine similarity between desired_direction and undesired_direction: 0.70947265625
- Cosine similarity between current householder_vector and previous householder_vector: 0.60791015625

Layer 5/80:
- Cosine similarity between desired_direction and undesired_direction: 0.68017578125
- Cosine similarity between current householder_vector and previous householder_vector: 0.75390625

Layer 6/80:
- Cosine similarity between desired_direction and undesired_direction: 0.73388671875
- Cosine similarity between current householder_vector and previous householder_vector: 0.744140625

Layer 7/80:
- Cosine similarity between desired_direction and undesired_direction: 0.70751953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.83251953125

Layer 8/80:
- Cosine similarity between desired_direction and undesired_direction: 0.67529296875
- Cosine similarity between current householder_vector and previous householder_vector: 0.6748046875

Layer 9/80:
- Cosine similarity between desired_direction and undesired_direction: 0.64892578125
- Cosine similarity between current householder_vector and previous householder_vector: 0.76171875

Layer 10/80:
- Cosine similarity between desired_direction and undesired_direction: 0.607421875
- Cosine similarity between current householder_vector and previous householder_vector: 0.67724609375

Layer 11/80:
- Cosine similarity between desired_direction and undesired_direction: 0.59228515625
- Cosine similarity between current householder_vector and previous householder_vector: 0.732421875

Layer 12/80:
- Cosine similarity between desired_direction and undesired_direction: 0.61328125
- Cosine similarity between current householder_vector and previous householder_vector: 0.80078125

Layer 13/80:
- Cosine similarity between desired_direction and undesired_direction: 0.6240234375
- Cosine similarity between current householder_vector and previous householder_vector: 0.884765625

Layer 14/80:
- Cosine similarity between desired_direction and undesired_direction: 0.59228515625
- Cosine similarity between current householder_vector and previous householder_vector: 0.75830078125

Layer 15/80:
- Cosine similarity between desired_direction and undesired_direction: 0.607421875
- Cosine similarity between current householder_vector and previous householder_vector: 0.7763671875

Layer 16/80:
- Cosine similarity between desired_direction and undesired_direction: 0.55615234375
- Cosine similarity between current householder_vector and previous householder_vector: 0.716796875

Layer 17/80:
- Cosine similarity between desired_direction and undesired_direction: 0.5732421875
- Cosine similarity between current householder_vector and previous householder_vector: 0.8095703125

Layer 18/80:
- Cosine similarity between desired_direction and undesired_direction: 0.51220703125
- Cosine similarity between current householder_vector and previous householder_vector: 0.71044921875

Layer 19/80:
- Cosine similarity between desired_direction and undesired_direction: 0.53369140625
- Cosine similarity between current householder_vector and previous householder_vector: 0.744140625

Layer 20/80:
- Cosine similarity between desired_direction and undesired_direction: 0.2083740234375
- Cosine similarity between current householder_vector and previous householder_vector: 0.548828125

Layer 21/80:
- Cosine similarity between desired_direction and undesired_direction: 0.28125
- Cosine similarity between current householder_vector and previous householder_vector: 0.79296875

Layer 22/80:
- Cosine similarity between desired_direction and undesired_direction: 0.302734375
- Cosine similarity between current householder_vector and previous householder_vector: 0.8515625

Layer 23/80:
- Cosine similarity between desired_direction and undesired_direction: 0.325439453125
- Cosine similarity between current householder_vector and previous householder_vector: 0.91748046875

Layer 24/80:
- Cosine similarity between desired_direction and undesired_direction: 0.333251953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.89990234375

Layer 25/80:
- Cosine similarity between desired_direction and undesired_direction: 0.328857421875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9248046875

Layer 26/80:
- Cosine similarity between desired_direction and undesired_direction: 0.34423828125
- Cosine similarity between current householder_vector and previous householder_vector: 0.9228515625

Layer 27/80:
- Cosine similarity between desired_direction and undesired_direction: 0.35498046875
- Cosine similarity between current householder_vector and previous householder_vector: 0.923828125

Layer 28/80:
- Cosine similarity between desired_direction and undesired_direction: 0.3955078125
- Cosine similarity between current householder_vector and previous householder_vector: 0.9052734375

Layer 29/80:
- Cosine similarity between desired_direction and undesired_direction: 0.431640625
- Cosine similarity between current householder_vector and previous householder_vector: 0.921875

Layer 30/80:
- Cosine similarity between desired_direction and undesired_direction: 0.45166015625
- Cosine similarity between current householder_vector and previous householder_vector: 0.92431640625

Layer 31/80:
- Cosine similarity between desired_direction and undesired_direction: 0.317626953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.7958984375

Layer 32/80:
- Cosine similarity between desired_direction and undesired_direction: 0.372802734375
- Cosine similarity between current householder_vector and previous householder_vector: 0.9189453125

Layer 33/80:
- Cosine similarity between desired_direction and undesired_direction: 0.3916015625
- Cosine similarity between current householder_vector and previous householder_vector: 0.93408203125

Layer 34/80:
- Cosine similarity between desired_direction and undesired_direction: 0.40185546875
- Cosine similarity between current householder_vector and previous householder_vector: 0.947265625

Layer 35/80:
- Cosine similarity between desired_direction and undesired_direction: 0.395751953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.95263671875

Layer 36/80:
- Cosine similarity between desired_direction and undesired_direction: 0.39990234375
- Cosine similarity between current householder_vector and previous householder_vector: 0.94970703125

Layer 37/80:
- Cosine similarity between desired_direction and undesired_direction: 0.414794921875
- Cosine similarity between current householder_vector and previous householder_vector: 0.95751953125

Layer 38/80:
- Cosine similarity between desired_direction and undesired_direction: 0.4140625
- Cosine similarity between current householder_vector and previous householder_vector: 0.94970703125

Layer 39/80:
- Cosine similarity between desired_direction and undesired_direction: 0.337158203125
- Cosine similarity between current householder_vector and previous householder_vector: 0.8935546875

Layer 40/80:
- Cosine similarity between desired_direction and undesired_direction: 0.314697265625
- Cosine similarity between current householder_vector and previous householder_vector: 0.91015625

Layer 41/80:
- Cosine similarity between desired_direction and undesired_direction: 0.2041015625
- Cosine similarity between current householder_vector and previous householder_vector: 0.919921875

Layer 42/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1893310546875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9365234375

Layer 43/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1800537109375
- Cosine similarity between current householder_vector and previous householder_vector: 0.9384765625

Layer 44/80:
- Cosine similarity between desired_direction and undesired_direction: 0.18896484375
- Cosine similarity between current householder_vector and previous householder_vector: 0.96826171875

Layer 45/80:
- Cosine similarity between desired_direction and undesired_direction: 0.12744140625
- Cosine similarity between current householder_vector and previous householder_vector: 0.92822265625

Layer 46/80:
- Cosine similarity between desired_direction and undesired_direction: 0.11383056640625
- Cosine similarity between current householder_vector and previous householder_vector: 0.96484375

Layer 47/80:
- Cosine similarity between desired_direction and undesired_direction: 0.09356689453125
- Cosine similarity between current householder_vector and previous householder_vector: 0.96826171875

Layer 48/80:
- Cosine similarity between desired_direction and undesired_direction: 0.07769775390625
- Cosine similarity between current householder_vector and previous householder_vector: 0.96826171875

Layer 49/80:
- Cosine similarity between desired_direction and undesired_direction: 0.09283447265625
- Cosine similarity between current householder_vector and previous householder_vector: 0.974609375

Layer 50/80:
- Cosine similarity between desired_direction and undesired_direction: 0.09637451171875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9794921875

Layer 51/80:
- Cosine similarity between desired_direction and undesired_direction: 0.09124755859375
- Cosine similarity between current householder_vector and previous householder_vector: 0.9775390625

Layer 52/80:
- Cosine similarity between desired_direction and undesired_direction: 0.0972900390625
- Cosine similarity between current householder_vector and previous householder_vector: 0.98095703125

Layer 53/80:
- Cosine similarity between desired_direction and undesired_direction: 0.09600830078125
- Cosine similarity between current householder_vector and previous householder_vector: 0.97998046875

Layer 54/80:
- Cosine similarity between desired_direction and undesired_direction: 0.09716796875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9833984375

Layer 55/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1033935546875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9814453125

Layer 56/80:
- Cosine similarity between desired_direction and undesired_direction: 0.10626220703125
- Cosine similarity between current householder_vector and previous householder_vector: 0.9814453125

Layer 57/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1055908203125
- Cosine similarity between current householder_vector and previous householder_vector: 0.984375

Layer 58/80:
- Cosine similarity between desired_direction and undesired_direction: 0.11114501953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.98486328125

Layer 59/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1123046875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9853515625

Layer 60/80:
- Cosine similarity between desired_direction and undesired_direction: 0.11456298828125
- Cosine similarity between current householder_vector and previous householder_vector: 0.9833984375

Layer 61/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1129150390625
- Cosine similarity between current householder_vector and previous householder_vector: 0.984375

Layer 62/80:
- Cosine similarity between desired_direction and undesired_direction: 0.11541748046875
- Cosine similarity between current householder_vector and previous householder_vector: 0.98681640625

Layer 63/80:
- Cosine similarity between desired_direction and undesired_direction: 0.134521484375
- Cosine similarity between current householder_vector and previous householder_vector: 0.98583984375

Layer 64/80:
- Cosine similarity between desired_direction and undesired_direction: 0.138671875
- Cosine similarity between current householder_vector and previous householder_vector: 0.986328125

Layer 65/80:
- Cosine similarity between desired_direction and undesired_direction: 0.140869140625
- Cosine similarity between current householder_vector and previous householder_vector: 0.986328125

Layer 66/80:
- Cosine similarity between desired_direction and undesired_direction: 0.138427734375
- Cosine similarity between current householder_vector and previous householder_vector: 0.98681640625

Layer 67/80:
- Cosine similarity between desired_direction and undesired_direction: 0.13525390625
- Cosine similarity between current householder_vector and previous householder_vector: 0.98486328125

Layer 68/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1385498046875
- Cosine similarity between current householder_vector and previous householder_vector: 0.9853515625

Layer 69/80:
- Cosine similarity between desired_direction and undesired_direction: 0.138671875
- Cosine similarity between current householder_vector and previous householder_vector: 0.984375

Layer 70/80:
- Cosine similarity between desired_direction and undesired_direction: 0.146240234375
- Cosine similarity between current householder_vector and previous householder_vector: 0.982421875

Layer 71/80:
- Cosine similarity between desired_direction and undesired_direction: 0.145263671875
- Cosine similarity between current householder_vector and previous householder_vector: 0.97998046875

Layer 72/80:
- Cosine similarity between desired_direction and undesired_direction: 0.145751953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.9833984375

Layer 73/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1544189453125
- Cosine similarity between current householder_vector and previous householder_vector: 0.98193359375

Layer 74/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1600341796875
- Cosine similarity between current householder_vector and previous householder_vector: 0.98046875

Layer 75/80:
- Cosine similarity between desired_direction and undesired_direction: 0.161376953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.97900390625

Layer 76/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1669921875
- Cosine similarity between current householder_vector and previous householder_vector: 0.978515625

Layer 77/80:
- Cosine similarity between desired_direction and undesired_direction: 0.18212890625
- Cosine similarity between current householder_vector and previous householder_vector: 0.9697265625

Layer 78/80:
- Cosine similarity between desired_direction and undesired_direction: 0.1895751953125
- Cosine similarity between current householder_vector and previous householder_vector: 0.97607421875

Layer 79/80:
- Cosine similarity between desired_direction and undesired_direction: 0.197509765625
- Cosine similarity between current householder_vector and previous householder_vector: 0.9716796875

Layer 80/80:
- Cosine similarity between desired_direction and undesired_direction: 0.18994140625
- Cosine similarity between current householder_vector and previous householder_vector: 0.962890625

You can clearly see here that it is discriminating between the two types of stories. Not only that, but it has identified layer 20 (and to a lesser extent, layer 31) as the main "positivity-inducing" layer(s). This is interesting, as these two layers are likely disrupted by the most common 20- and 16-block interleave "frankenmerges" applied to llama-2-70b fine-tunes (of which miqu-1-70b is one).


I'm still testing to see if MEAN_EXTRA_TOKENS_TO_GENERATE is needed: the above test was calculated with MEAN_EXTRA_TOKENS_TO_GENERATE=0 and NUM_PROMPT_SAMPLES=128.

I'm also not sure yet if you need to skip layers using SKIP_BEGIN_LAYERS and SKIP_END_LAYERS, and again the above did not use these (it turns out the NaN for layer 1 was caused by the hidden_state at index [0] coming from before the first layer - I fixed this in the new code).

The reasons I chose to only alter the mlp.down_proj.weight tensors and not the self_attn.o_proj.weight tensors:

  1. From my previous Mergekit experiments, altering self_attn.o_proj.weight tends to just screw up the models (likely by changing the distribution of the input to the MLP).
  2. Now that the layer indices match up, I am using the sample for the specific layer taken right after its mlp.down_proj.weight, rather than the single sample taken at the 60th percentile (= 48th layer), as these samples are a direct output of the mlp.down_proj.weight projection.

jukofyork commented 5 months ago

Explanation is here: https://github.com/FailSpy/abliterator/issues/10#issuecomment-2156659963