NVIDIA / NeMo-Aligner

Scalable toolkit for efficient model alignment
Apache License 2.0

Rejection sampling clean #218

Open: abukharin3 opened this pull request 1 week ago.

What does this PR do?

Adds the rejection sampling (RS) algorithm, with `examples/nlp/gpt/train_gpt_rs_actor.py` as its training entry point.
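For readers new to the technique, the selection step at the heart of rejection sampling can be sketched without any framework. This is a minimal illustration, not NeMo-Aligner's implementation: it assumes that `num_rollout_per_prompt` candidate responses are generated per prompt, each is scored by a reward model, and the `num_select` highest-scoring ones are kept as fine-tuning targets. The function `rejection_sample` and its `generate`/`reward` callables are hypothetical names, and the subsequent supervised update on the kept responses is not shown.

```python
from typing import Callable, List, Sequence, Tuple


def rejection_sample(
    prompts: Sequence[str],
    generate: Callable[[str], str],          # hypothetical policy sampler
    reward: Callable[[str, str], float],     # hypothetical reward model scorer
    num_rollout_per_prompt: int = 4,
    num_select: int = 1,
) -> List[Tuple[str, str, float]]:
    """Best-of-n selection: keep the top `num_select` responses per prompt."""
    if not 1 <= num_select <= num_rollout_per_prompt:
        raise ValueError("need 1 <= num_select <= num_rollout_per_prompt")
    selected: List[Tuple[str, str, float]] = []
    for prompt in prompts:
        # Sample several candidate responses for the same prompt.
        candidates = [generate(prompt) for _ in range(num_rollout_per_prompt)]
        # Score every (prompt, response) pair with the reward model.
        scored = [(prompt, resp, reward(prompt, resp)) for resp in candidates]
        # Highest reward first; Python's sort is stable, so ties keep sample order.
        scored.sort(key=lambda item: item[2], reverse=True)
        # The survivors become targets for a supervised fine-tuning step.
        selected.extend(scored[:num_select])
    return selected


if __name__ == "__main__":
    # Toy stand-ins: a canned "policy" and a length-based "reward model".
    canned = iter(["a", "abcd", "ab", "abc"])
    best = rejection_sample(["p"], lambda p: next(canned), lambda p, r: float(len(r)))
    print(best)  # [('p', 'abcd', 4.0)]
```

A full sort per prompt is the simplest choice here; with many rollouts per prompt, `heapq.nlargest` would avoid sorting every candidate.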

Usage

The snippet below stores the rejection sampling training command in `cmd_ppo`; launching it (for example from a Slurm job script) is left to the surrounding script, which must also set the `${...}` variables. A reward model server must already be reachable at `${host_critic}:${CRITIC_PORT}`.

```bash
# Inside an unquoted heredoc, \\\$ produces a literal \$, so ${multiply:...} reaches
# Hydra/OmegaConf as an interpolation instead of being expanded by bash here.
read -r -d '' cmd_ppo <<EOF
wandb login ${WANDB_API_KEY} \
&& cd ${NEMO_RLHF_DIR} \
&& export PYTHONPATH="${NEMO_RLHF_DIR}:${PYTHONPATH}" \
&& export HYDRA_FULL_ERROR=1 \
&& export CUDA_LAUNCH_BLOCKING=1 \
&& export PYTRITON_HOME=/pytriton_cache \
&& export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:512 \
&& python -u examples/nlp/gpt/train_gpt_rs_actor.py \
    --config-path=${CONF_DIR} \
    --config-name=${CONFIG_NAME} \
    "model.data.data_prefix={train: [${TRAIN_DATA_PATH}], validation: [${VALID_DATA_PATH}], test: [${VALID_DATA_PATH}]}" \
    pretrained_checkpoint.restore_from_path=\"${ACTOR_NEMO_FILE}\" \
    exp_manager.checkpoint_callback_params.save_top_k=1 \
    exp_manager.explicit_log_dir=\"${ACTOR_LOG_DIR}\" \
    exp_manager.create_wandb_logger=True \
    exp_manager.wandb_logger_kwargs.name=\"${ACTOR_NAME}\" \
    exp_manager.wandb_logger_kwargs.project=${WANDB_PROJECT} \
    ++exp_manager.max_time_per_run=\"00:03:30:00\" \
    trainer.rs.max_epochs=1 \
    trainer.rs.max_steps=313 \
    trainer.rs.val_check_interval=4 \
    trainer.num_nodes=8 \
    trainer.devices=8 \
    ++model.tensor_model_parallel_size=4 \
    model.global_batch_size=${ACTOR_GBS} \
    model.micro_batch_size=1 \
    model.optim.lr=\"\\\${multiply:${ACTOR_LR},1.001}\" \
    model.optim.sched.warmup_steps=0 \
    model.optim.sched.constant_steps=312 \
    model.optim.sched.min_lr=${ACTOR_LR} \
    model.optim.weight_decay=0.01 \
    model.rs.num_rollout_samples=${NUM_ROLLOUTS} \
    model.rs.rollout_micro_batch_size=8 \
    model.rs.forward_micro_batch_size=8 \
    model.rs.val_rollout_micro_batch_size=8 \
    model.data.data_impl=jsonl \
    remote_critic_rm.reward_model.ip=${host_critic} \
    remote_critic_rm.reward_model.port=${CRITIC_PORT} \
    model.rs.num_rollout_per_prompt=4 \
    model.rs.num_select=1
EOF
```
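The parallelism settings in this command determine how the 64 GPUs (8 nodes with 8 devices each) are divided. The sketch below works through that arithmetic with the standard Megatron-style rules: data-parallel size = GPUs / (TP * PP), and gradient accumulation = global batch / (micro batch * DP). Several inputs are assumptions, not values from the PR: pipeline parallel size 1, the stand-in numbers for `${ACTOR_GBS}` and `${NUM_ROLLOUTS}`, and reading `num_rollout_samples` as a count of prompts. `RSLaunchSettings` and `derive_layout` are hypothetical helpers, not part of NeMo-Aligner.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class RSLaunchSettings:
    # Values copied from the launch command.
    num_nodes: int = 8                     # trainer.num_nodes
    devices: int = 8                       # trainer.devices (GPUs per node)
    tensor_model_parallel_size: int = 4    # model.tensor_model_parallel_size
    micro_batch_size: int = 1              # model.micro_batch_size
    num_rollout_per_prompt: int = 4        # model.rs.num_rollout_per_prompt
    num_select: int = 1                    # model.rs.num_select
    # Assumptions: not set in the command, hypothetical values.
    pipeline_model_parallel_size: int = 1  # assumed default
    global_batch_size: int = 64            # hypothetical value for ${ACTOR_GBS}
    num_rollout_samples: int = 512         # hypothetical value for ${NUM_ROLLOUTS}


def derive_layout(s: RSLaunchSettings) -> dict:
    """Derive data-parallel size, accumulation steps, and rollout counts."""
    world_size = s.num_nodes * s.devices
    model_parallel = s.tensor_model_parallel_size * s.pipeline_model_parallel_size
    if world_size % model_parallel != 0:
        raise ValueError("GPU count must be divisible by TP * PP")
    data_parallel = world_size // model_parallel
    # Samples consumed by one micro step across all data-parallel replicas.
    samples_per_micro_step = s.micro_batch_size * data_parallel
    if s.global_batch_size % samples_per_micro_step != 0:
        raise ValueError("global_batch_size must be divisible by micro_batch_size * DP")
    return {
        "world_size": world_size,
        "data_parallel_size": data_parallel,
        "grad_accumulation_steps": s.global_batch_size // samples_per_micro_step,
        # Assumes num_rollout_samples counts prompts, not individual responses.
        "responses_generated": s.num_rollout_samples * s.num_rollout_per_prompt,
        "responses_selected": s.num_rollout_samples * s.num_select,
    }


if __name__ == "__main__":
    print(derive_layout(RSLaunchSettings()))
```

With these numbers, tensor-parallel groups of 4 leave 16 data-parallel replicas, so a global batch of 64 at micro batch size 1 takes 4 accumulation steps.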
