sail-sg / sdft

[ACL 2024] The official codebase for the paper "Self-Distillation Bridges Distribution Gap in Language Model Fine-tuning".
https://aclanthology.org/2024.acl-long.58

Parameter shift calculation #4

Closed 447428054 closed 5 months ago

447428054 commented 5 months ago

Where can I find the script that computes the parameter shift?

rickyang1114 commented 5 months ago

Hello, thanks for your interest in our work!

The complete code for calculating the parameter shift was not preserved after the ablation study, but I have kept a few snippets that may still be useful. The snippet below extracts the trainable parameters from both the seed LM and the model after training:

import torch  # already available in loader.py; shown here for completeness

output_dict = {}
target_path = "/path/to/save"
for name, param in model.state_dict().items():
    if "lora" not in name:  # skip parameters that are frozen during LoRA fine-tuning
        continue
    output_dict[name] = param.detach().cpu()
torch.save(output_dict, target_path)  # save the LoRA parameters to a file as a dict
exit()

To use this snippet, insert it at line 222 of src/llmtuner/model/loader.py on the reproduce branch. After extracting the parameters, you can convert them to NumPy arrays (or keep them as tensors) and compute the parameter shift directly.
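That final step can be sketched as follows. This is a minimal illustration, not the original ablation code: the function name, file paths, and the assumption that both files were produced by the extraction snippet above are mine.

```python
import torch

def parameter_shift(path_before: str, path_after: str) -> float:
    """L2 norm of the difference between two saved parameter dicts."""
    before = torch.load(path_before)  # dict of name -> tensor, as saved above
    after = torch.load(path_after)
    total = 0.0
    for name, p_before in before.items():
        diff = after[name].float() - p_before.float()
        total += diff.pow(2).sum().item()  # accumulate squared L2 per tensor
    return total ** 0.5  # overall L2 norm across all saved parameters
```

Since the dicts only contain the LoRA tensors, the result is the shift over the trainable parameters.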

Hope this helps!

447428054 commented 2 months ago

@rickyang1114 Is the parameter shift the L2 norm of the difference between parameters? Would the following code compute it?

import torch
from transformers import AutoModelForCausalLM

def calParameterShift(modelpath1, modelpath2):
    model1 = AutoModelForCausalLM.from_pretrained(modelpath1)
    model2 = AutoModelForCausalLM.from_pretrained(modelpath2)

    # Put both models in evaluation mode
    model1.eval()
    model2.eval()

    # Running sum of squared L2 norms over all parameters
    l2_norm = 0.0

    # Iterate over all parameters of the two models in parallel
    for param1, param2 in zip(model1.parameters(), model2.parameters()):
        # Difference between corresponding parameters
        difference = param1 - param2
        # Accumulate the squared L2 norm
        l2_norm += torch.norm(difference).item() ** 2

    # Take the square root of the accumulated sum
    l2_norm = torch.sqrt(torch.tensor(l2_norm))

    return l2_norm.item()

rickyang1114 commented 2 months ago

That's right. Since the experiments used LoRA fine-tuning, you only need to compute the L2 norm over the LoRA parameters.
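Concretely, restricting the computation to the LoRA parameters could look like the sketch below. The function name is illustrative; the `"lora"` substring filter mirrors the one in the extraction snippet earlier in the thread.

```python
import torch

def lora_parameter_shift(model1, model2) -> float:
    """L2 norm of the parameter difference, over LoRA parameters only."""
    params2 = dict(model2.named_parameters())
    total = 0.0
    for name, p1 in model1.named_parameters():
        if "lora" not in name:  # base weights are frozen during LoRA fine-tuning
            continue
        diff = (params2[name] - p1).float()
        total += diff.pow(2).sum().item()  # squared L2, accumulated per tensor
    return total ** 0.5
```

Matching parameters by name, rather than zipping `model.parameters()`, also guards against the two models enumerating parameters in different orders.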