import torch
from transformers import AutoModelForSequenceClassification, AutoTokenizer
# Hugging Face Hub id of the reward model checkpoint, and the device it is placed on
path = "RLHFlow/ArmoRM-Llama3-8B-v0.1"
device = "cuda"
# Load the model in bfloat16 with trust_remote_code enabled
model = AutoModelForSequenceClassification.from_pretrained(
    path,
    device_map=device,
    trust_remote_code=True,
    torch_dtype=torch.bfloat16,
)
# The matching tokenizer is loaded from the same checkpoint (fast implementation requested)
tokenizer = AutoTokenizer.from_pretrained(path, use_fast=True)
# We load a random sample from the validation set of the HelpSteer dataset
prompt = 'What are some synonyms for the word "beautiful"?'
response = "Nicely, Beautifully, Handsome, Stunning, Wonderful, Gorgeous, Pretty, Stunning, Elegant"
# A single-turn conversation: the user prompt followed by the assistant response
messages = [
    {"role": "user", "content": prompt},
    {"role": "assistant", "content": response},
]
# Render the conversation with the tokenizer's chat template and move the
# resulting tensor to the same device as the model
input_ids = tokenizer.apply_chat_template(messages, return_tensors="pt").to(device)
# Inference only, so gradient tracking is disabled for the forward pass
with torch.no_grad():
    output = model(input_ids)
# Multi-objective rewards for the response
multi_obj_rewards = output.rewards.cpu().float()
# The gating layer's output is conditioned on the prompt
gating_output = output.gating_output.cpu().float()
# The preference score for the response, aggregated from the
# multi-objective rewards with the gating layer
preference_score = output.score.cpu().float()
# We apply a transformation matrix to the multi-objective rewards
# before multiplying with the gating layer's output. This mainly aims
# at reducing the verbosity bias of the original reward objectives
obj_transform = model.reward_transform_matrix.data.cpu().float()
# The final coefficients assigned to each reward objective
multi_obj_coeffs = gating_output @ obj_transform.T
# The preference score is the linear combination of the multi-objective rewards with
# the multi-objective coefficients, which can be verified by the following assertion.
# torch.allclose reduces the element-wise comparison to a single Python bool, so the
# check also works when more than one example is scored at once (asserting on the
# tensor returned by torch.isclose is only valid for exactly one element).
recombined_score = torch.sum(multi_obj_rewards * multi_obj_coeffs, dim=1)
assert torch.allclose(recombined_score, preference_score, atol=1e-3)
# Find the top-K reward objectives with coefficients of the highest magnitude
K = 3
# Sort objective indices by |coefficient| (largest first) and keep the first K per row
top_obj_dims = torch.argsort(torch.abs(multi_obj_coeffs), dim=1, descending=True)[:, :K]
# Pick the signed coefficients at those indices
top_obj_coeffs = torch.gather(multi_obj_coeffs, dim=1, index=top_obj_dims)
# The attributes of the 19 reward objectives
attributes = [
    'helpsteer-helpfulness', 'helpsteer-correctness', 'helpsteer-coherence',
    'helpsteer-complexity', 'helpsteer-verbosity', 'ultrafeedback-overall_score',
    'ultrafeedback-instruction_following', 'ultrafeedback-truthfulness',
    'ultrafeedback-honesty', 'ultrafeedback-helpfulness', 'beavertails-is_safe',
    'prometheus-score', 'argilla-overall_quality', 'argilla-judge_lm', 'code-complexity',
    'code-style', 'code-explanation', 'code-instruction-following', 'code-readability',
]
# Row of the batch to report
example_index = 0
# Print "<attribute>: <coefficient>" for each of the selected objectives of that row
for dim, coeff in zip(top_obj_dims[example_index].tolist(),
                      top_obj_coeffs[example_index].tolist()):
    attribute = attributes[dim]
    print(f"{attribute}: {round(coeff, 5)}")
# code-complexity: 0.19922
# helpsteer-verbosity: -0.10864
# ultrafeedback-instruction_following: 0.07861
# The ground-truth ratings of this example in the HelpSteer dataset are
# [3,3,4,2,2] for the five helpsteer objectives, in this order:
# helpfulness, correctness, coherence, complexity, verbosity.
# To compare against them, the first five predicted rewards of the first
# example are mapped back to the original reward space by a linear transform.
helpsteer_rewards_raw = multi_obj_rewards[0, :5]
helpsteer_rewards_pred = helpsteer_rewards_raw * 5 - 0.5
print(helpsteer_rewards_pred)
# [2.78125 2.859375 3.484375 1.3847656 1.296875 ]
Modified Files
rewardbench/models/
__init__.py: add config for ArmoRM
armorm.py: add ArmoRM pipeline
scripts/
configs/eval_config.yaml: eval config for ArmoRM
run_rm.py:
Enable TF32 (to use TensorCore on Ampere GPUs)
Add a model config choice, torch_dtype. ArmoRM is native to torch.bfloat16 (torch.float32 certainly also works, but it requires more GPU memory), and Int-8 quantization makes ArmoRM inference very slow (even slower than FP32). This new config choice allows ArmoRM to be evaluated under torch.bfloat16.
This PR adds a new reward model, ArmoRM, to RewardBench.
Description: Arbitrary-Rating Multi-Objective Reward Model (ArmoRM) with Mixture-of-Experts (MoE) Aggregation of Reward Objectives
Authors (* indicates equal contribution)
Haoxiang Wang*, Wei Xiong*, Tengyang Xie, Han Zhao, Tong Zhang
Blog: To appear soon (with implementation details)
Tech Report: To be released in June 2024
Model: ArmoRM-Llama3-8B-v0.1
Code Repository: https://github.com/RLHFlow/RLHF-Reward-Modeling/
Architecture
RewardBench LeaderBoard
Demo Code
Modified Files
rewardbench/models/
__init__.py: add config for ArmoRM
armorm.py: add ArmoRM pipeline
scripts/
configs/eval_config.yaml: eval config for ArmoRM
run_rm.py:
Enable TF32 (to use TensorCore on Ampere GPUs)
Add a model config choice, torch_dtype. ArmoRM is native to torch.bfloat16 (torch.float32 certainly also works, but it requires more GPU memory), and Int-8 quantization makes ArmoRM inference very slow (even slower than FP32). This new config choice allows ArmoRM to be evaluated under torch.bfloat16.
Evaluation Commands