how to visual outier of activation #97

Open harleyszhang opened 3 weeks ago

harleyszhang commented 3 weeks ago

this is my visual result of activation、weight、smootha_ctivation、smooth_weight. why i din't see the phenomenon that activation outliers only exist in certain channels!


and my code is there, run llama-2-7b model to visual activation and weight

import torch
import numpy as np
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
from transformers import AutoModelForCausalLM, LlamaTokenizer

# 加载模型和分词器
def load_model_and_tokenizer(model_name, device):
    tokenizer = LlamaTokenizer.from_pretrained(model_name)
    tokenizer.pad_token = tokenizer.eos_token  # 设置 eos_token 为 pad_token
    model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).to(device)
    return model, tokenizer

# 获取激活值和权重
def get_activations_and_weights(model, tokenizer, texts, device):
    inputs = tokenizer(texts, return_tensors="pt", padding=True, truncation=True).to(device)
    outputs = model(**inputs, output_hidden_states=True)
    print("outputs.hidden_states shape",outputs.hidden_states.shape)

    activation = outputs.hidden_states[4].abs()  # 使用第四层激活值
    weight = model.model.layers[0].self_attn.q_proj.weight.abs()  # 第一层权重
    return activation, weight

# 计算 SmoothQuant 缩放因子
def calculate_scales(activation, weight, alpha=0.5):
    act_max = torch.amax(activation.view(-1, activation.size(-1)), dim=0).clamp(min=1e-5)
    w_max = torch.amax(weight, dim=0).clamp(min=1e-5)
    scales = act_max.pow(alpha) / w_max.pow(1 - alpha)
    return scales

# 应用 SmoothQuant 缩放因子到激活值和权重
def apply_smoothquant_scaling(activation, weight, scales):
    smooth_activation = activation / scales.view(1, 1, -1)
    smooth_weight = weight * scales.view(1, -1)
    return smooth_activation, smooth_weight

# 检测离群值并打印通道索引
def find_outlier_channels(activation, threshold=20):
    mean = activation.mean(dim=(0, 1))
    std = activation.std(dim=(0, 1))
    z_scores = (activation - mean) / std
    outliers = torch.where(z_scores > threshold)
    unique_channels = torch.unique(outliers[2])
    print(f"离群值所在的通道索引: {unique_channels.tolist()}")

# 3D 绘图函数
def plot_3d(data, title, xlabel, ylabel, zlabel, color, ax, y_max):
    x, y = np.meshgrid(np.arange(data.shape[1]), np.arange(data.shape[0]))
    x, y = x.flatten(), y.flatten()
    z = np.zeros_like(x)
    dx = dy = 1
    dz = data.flatten()
    ax.bar3d(x, y, z, dx, dy, dz, color=color, zsort='average')
    ax.set_zlim(0, y_max)

# 主函数
def main():
    model_name = "./llm-awq/hf_weight/llama-2-7b/"
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    model, tokenizer = load_model_and_tokenizer(model_name, device)

    # 处理输入文本
    input_texts = [
        "The quick brown fox jumps over the lazy dog.",
        "Artificial intelligence is revolutionizing the world.",
        "Large language models are powerful tools for NLP tasks."
    activation, weight = get_activations_and_weights(model, tokenizer, input_texts, device)

    # 检查离群值所在通道

    # 计算 SmoothQuant 缩放因子并应用平滑转换
    scales = calculate_scales(activation, weight)
    smooth_activation, smooth_weight = apply_smoothquant_scaling(activation, weight, scales)

    # 统一 y 轴范围
    y_max = max(

    # 绘图
    fig = plt.figure(figsize=(18, 8))
    plot_titles = [
        ("Activation (Original)", activation[0], "brown"),
        ("Activation (SmoothQuant)", smooth_activation[0], "blue"),
        ("Weight (Original)", weight, "blue"),
        ("Weight (SmoothQuant)", smooth_weight, "blue")

    for i, (title, data, color) in enumerate(plot_titles, start=1):
        ax = fig.add_subplot(1, 4, i, projection='3d')
        xlabel, ylabel = ("Channel", "Token") if "Activation" in title else ("In Channel", "Out Channel")
        plot_3d(data.cpu().numpy(), title, xlabel, ylabel, "Absolute Value", color, ax, y_max)

    fig.suptitle("SmoothQuant Visualization", fontsize=16)
    plt.savefig("llama2_7b_smoothquant_visualization.png", format='png', dpi=300)

if __name__ == "__main__":

can you tell me what is my problem, and what is your visual code