NVIDIA / NeMo-Aligner

Scalable toolkit for efficient model alignment
Apache License 2.0

feat: adds REINFORCE algorithm #357

Closed. abukharin3 closed this PR 5 days ago.

abukharin3 commented 1 month ago

What does this PR do?

Adds the REINFORCE algorithm.
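
For reference, REINFORCE is a policy-gradient method: the actor samples responses to prompts, scores them with the reward model, and ascends the reward while a KL penalty keeps the policy close to the initial (SFT) model. A common form of the gradient estimator, given here as a sketch rather than the exact objective this PR implements (the baseline b(x) could be, e.g., a leave-one-out mean over the num_rollout_samples responses), is

$$\nabla_\theta J(\theta) \approx \mathbb{E}_{y \sim \pi_\theta(\cdot \mid x)}\Big[\big(R(x, y) - b(x)\big)\, \nabla_\theta \log \pi_\theta(y \mid x)\Big] \;-\; \beta\, \nabla_\theta \mathrm{KL}\big(\pi_\theta(\cdot \mid x)\,\|\,\pi_{\text{init}}(\cdot \mid x)\big),$$

where β corresponds to trainer.reinforce.initial_policy_kl_penalty in the usage example below.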

Changelog

Usage

NAME="2p_reinforce"

# PARAMETERS

RM_NEMO_FILE="/path/to/trained_rm.nemo"

ACTOR_NEMO_FILE="/path/to/sft_model.nemo"

TRAIN_DATA_PATH="/path/to/train_prompts.jsonl"
VALID_DATA_PATH="/path/to/test_prompts.jsonl"
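# NOTE (assumption, not spelled out in this PR text): with model.data.data_impl=jsonl
# the train/validation files are JSONL prompt datasets, one JSON object per line,
# typically with a "text" field holding the prompt, e.g.
#   {"text": "Write a short poem about GPUs."}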

RESULTS_DIR="/path/to/results_dir"
mkdir -p $RESULTS_DIR

GPFS="/path/to/nemo-aligner-repo"
MOUNTS="--container-mounts=MOUNTS" # mounts

CONTAINER=<<>> # use the latest NeMo Training container, Aligner will work there

PROJECT=reinforce_run

CRITIC_LOG_DIR="${RESULTS_DIR}/critic_results"
CRITIC_OUTFILE="${CRITIC_LOG_DIR}/critic_output_%j_%t.log"
CRITIC_ERRFILE="${CRITIC_LOG_DIR}/critic_error_%j_%t.err"
REWARD_PORT=5567
CRITIC_CONFIG_PATH="${GPFS}/examples/nlp/gpt/conf"
CRITIC_CONFIG_NAME="inference_rm"

CONF_DIR="${GPFS}/examples/nlp/gpt/conf"
CONFIG_NAME="gpt_reinforce_actor"

mkdir -p $CRITIC_LOG_DIR

CRITIC_NAME="${NAME}_critic"

read -r -d '' cmd_critic_inference <<EOF
cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u examples/nlp/gpt/serve_reward_model.py \
    --config-path=${CRITIC_CONFIG_PATH} \
    --config-name=${CRITIC_CONFIG_NAME} \
    trainer.num_nodes=1 \
    trainer.devices=8 \
    ++model.tensor_model_parallel_size=4 \
    rm_model_file=${RM_NEMO_FILE} \
    inference.port=${REWARD_PORT}
EOF

srun --het-group=0 -o $CRITIC_OUTFILE -e $CRITIC_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_critic_inference}" &

sleep 30
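# The fixed sleep gives the reward-model server time to come up. A sketch of a more
# robust alternative (an assumption, not part of the original script) is to poll the
# reward port on the first het-group-0 node until it accepts connections:
#   reward_host="$(scontrol show hostnames=$SLURM_JOB_NODELIST_HET_GROUP_0 | head -n1)"
#   until nc -z "$reward_host" "$REWARD_PORT"; do sleep 10; done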

ACTOR_LOG_DIR="${RESULTS_DIR}/actor_results"
CHECKPOINT_DIR="${ACTOR_LOG_DIR}/checkpoints"
TENSORBOARD_DIR="${ACTOR_LOG_DIR}/tensorboard"
# actor log files for the srun below (named by analogy with the critic ones above)
ACTOR_OUTFILE="${ACTOR_LOG_DIR}/actor_output_%j_%t.log"
ACTOR_ERRFILE="${ACTOR_LOG_DIR}/actor_error_%j_%t.err"

NUM_ROLLOUTS=16
NORMALIZE="True"
ACTOR_LR="1e-6"
ACTOR_GBS=16
KL=0.01
USE_FLASK=False
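# Rough meaning of the knobs above, inferred from the overrides they feed below
# (treat as an assumption rather than authoritative documentation):
#   NUM_ROLLOUTS -> model.reinforce.num_rollout_samples, responses sampled per step
#   ACTOR_LR / ACTOR_GBS -> actor learning rate and global batch size
#   KL -> trainer.reinforce.initial_policy_kl_penalty, strength of the KL term
#   USE_FLASK -> trainer.reinforce.batch_iterator.use_flask, flask-based batch iterator
#   NORMALIZE is set here but not referenced in the command below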

mkdir -p $ACTOR_LOG_DIR
mkdir -p $TENSORBOARD_DIR
mkdir -p $CHECKPOINT_DIR

ACTOR_NAME="${NAME}_actor"

host_reward="$(scontrol show hostnames=$SLURM_JOB_NODELIST_HET_GROUP_0 | head -n1)"

read -r -d '' cmd_reinforce <<EOF
cd ${GPFS} \
&& export PYTHONPATH="${GPFS}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& python -u examples/nlp/gpt/train_gpt_reinforce_actor.py \
    "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \
    pretrained_checkpoint.restore_from_path=\"${ACTOR_NEMO_FILE}\" \
    exp_manager.checkpoint_callback_params.save_top_k=1 \
    exp_manager.explicit_log_dir=\"${RESULTS_DIR}\" \
    trainer.reinforce.max_epochs=1 \
    trainer.reinforce.max_steps=313 \
    trainer.reinforce.val_check_interval=4 \
    trainer.num_nodes=1 \
    trainer.devices=8 \
    trainer.reinforce.trt_llm.enable=True \
    trainer.reinforce.trt_llm.reshard=True \
    trainer.reinforce.trt_llm.unload_engine_train=False \
    ++model.tensor_model_parallel_size=4 \
    ++model.reinforce.num_rollout_samples=${NUM_ROLLOUTS} \
    model.global_batch_size=${ACTOR_GBS} \
    model.micro_batch_size=1 \
    model.optim.lr=\"\\${multiply:${ACTOR_LR},1.001}\" \
    model.optim.sched.warmup_steps=0 \
    model.optim.sched.constant_steps=312 \
    model.optim.sched.min_lr=${ACTOR_LR} \
    model.optim.weight_decay=0.01 \
    model.reinforce.rollout_micro_batch_size=16 \
    model.reinforce.forward_micro_batch_size=16 \
    model.reinforce.val_rollout_micro_batch_size=8 \
    model.data.data_impl=jsonl \
    remote_rm.reward_model.ip=${host_reward} \
    remote_rm.reward_model.port=${REWARD_PORT} \
    ++model.reinforce.length_params.max_length=2048 \
    trainer.reinforce.initial_policy_kl_penalty="${KL}" \
    ++model.optim.bucket_cap_mb=200 \
    ++model.dist_ckpt_format=zarr \
    ++model.optim.overlap_grad_sync=False \
    ++model.optim.contiguous_grad_buffer=True \
    ++model.enable_nge=True \
    trainer.reinforce.batch_iterator.use_flask=${USE_FLASK} \
    trainer.reinforce.rollout_batch_seq_length=4096
EOF

srun --het-group=1 -o $ACTOR_OUTFILE -e $ACTOR_ERRFILE --container-image=${CONTAINER} $MOUNTS bash -c "${cmd_reinforce}" &

wait
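
The two srun calls target heterogeneous job groups 0 and 1, so the script above is meant to run as the body of a heterogeneous SLURM batch job. A minimal sketch of the batch header (account, partition, and time limit are placeholders/assumptions; adjust node counts and GPU flags for your cluster):

#!/bin/bash
#SBATCH -N 1 --ntasks-per-node=8 --gpus-per-node=8 -A <<ACCOUNT>> -p <<PARTITION>> -t 4:00:00
#SBATCH hetjob
#SBATCH -N 1 --ntasks-per-node=8 --gpus-per-node=8 -A <<ACCOUNT>> -p <<PARTITION>> -t 4:00:00

The first component hosts the reward-model server, the second the REINFORCE actor; the usage script follows these directives in the same sbatch file.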

Before your PR is "Ready for review"

Pre checks:

Checklist when contributing a new algorithm