cc also @kashif
@claralp depending on the batch size it could be that some of the metrics are nan; this should not affect the training etc., and special attention has been paid to make sure the loss is robust to these nans when doing back-prop.
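To illustrate that point with a toy sketch (an assumption about the mechanism, not the actual KTOTrainer code): if a micro-batch happens to contain only desired or only undesired examples, the mean over the empty split is nan, while a loss computed over the examples that are present stays finite.

```python
import torch
import torch.nn.functional as F

# Toy illustration only -- not the TRL implementation.
# Micro-batch that happens to contain only "desired" completions:
chosen_rewards = torch.tensor([0.31, 0.27, 0.45])
rejected_rewards = torch.tensor([])  # no undesired examples in this batch

print(rejected_rewards.mean())                          # tensor(nan)
print(chosen_rewards.mean() - rejected_rewards.mean())  # nan margin in the logs

# A loss computed only over the examples that are present stays finite,
# so back-prop is unaffected by the nan metrics.
loss = -F.logsigmoid(chosen_rewards).mean()
print(loss)  # finite
```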
@claralp i do not think nans in a dict should cause this to crash... do you have some crash back-traces?
@kashif there are no errors or warnings in the stdout/stderr, it just stops at some point after the nan rewards appear, so I cannot provide a stack trace here.
However, the Azure execution wrapper log shows a blocking process:
2024-03-19T03:33:30.165457Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution::process_manager: Failed blocking user process detected, process name: echo, process pid: 34, code: None success_return_code=Zero { additional_codes: [] } code=None
2024-03-19T03:33:31.167084Z ERROR Execution::wait_for_completion{parent_span=Span { name: "Execution::spawn", level: Level(Info), target: "execution_wrapper::execution", id: Id(6755674318962691), module_path: "execution_wrapper::execution", line: 163, file: "executor/execution-wrapper/src/execution/mod.rs" } process_manager=Mutex { data: ProcessManager { dangling_processes: [], user_process_groups: [34] } }}: execution_wrapper::execution: Execution process terminated by a signal, which may be due to failure in other user processes on the same node or node ran out of memory. local_rank=0 name=echo
The lifecycler log shows only a preemption signal:
2024-03-19T03:33:29.494161Z WARN run_lifecycler:run_service_and_step_through_lifecycle:step_through_lifecycle: lifecycler::lifecycle: Received abort message, exiting lifecycle abort_message=AbortMessage { error: Some(Error { code: "ReceivedPreemptionSignal", message: "{\"Compliant\":\"Job was terminated due to: Runtime received a preemption signal.\"}", target: "", node_info: None, category: UserError, error_details: [], inner_error: None }), broadcast_abort: true, request_timeout: 25 }
I think this could be the "normal" low-priority Azure preemption? :-(
Important note here: the crash only appears after the training starts showing nan values; otherwise it doesn't happen.
I even saw cases where all reward and logp metrics converge to nan values:
{'loss': 0.0, 'grad_norm': 281.6248474121094, 'learning_rate': 9.856115107913668e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.17875319719314575, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 192.08326721191406, 'learning_rate': 9.848121502797762e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0570355653762817, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.11}
{'loss': 0.0, 'grad_norm': 33.55568313598633, 'learning_rate': 9.840127897681853e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.1016669273376465, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 44.5154914855957, 'learning_rate': 9.832134292565947e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 2.197722911834717, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.12}
{'loss': 0.0, 'grad_norm': 10.592936515808105, 'learning_rate': 9.82414068745004e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 1.0713751316070557, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}
{'loss': 0.0, 'grad_norm': 61.1552734375, 'learning_rate': 9.81614708233413e-07, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3863883912563324, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.13}
Could there be anything wrong with the hyperparameter choice, @kashif?
@claralp so the main hyperparam that could affect this is the batch size as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny so that should be good... what is your batch size when you get all nans?
also does this happen if you try locally outside of Azure?
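As a rough back-of-the-envelope check (hypothetical numbers matching the imbalanced test described below, not a claim about the actual sampler), this is how often a randomly drawn batch of 8 would contain no desired example at all, which would make its rewards/chosen mean nan:

```python
import random

# Assumption: labels drawn i.i.d. per batch; 2k desired vs. 10k undesired,
# batch size 8, as in the imbalanced experiment below.
desired, undesired, batch_size, n_batches = 2_000, 10_000, 8, 10_000
p_desired = desired / (desired + undesired)

single_class_batches = 0
for _ in range(n_batches):
    labels = [random.random() < p_desired for _ in range(batch_size)]
    if all(labels) or not any(labels):
        single_class_batches += 1

# With this imbalance roughly (5/6)**8 ~ 23% of batches contain no desired
# example, so their rewards/chosen (and the margin) are logged as nan.
print(single_class_batches / n_batches)
```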
The output below is from a test with very unbalanced data, namely 2k desired completions and 10k undesired ones.
I know that a ratio between 4:3 and 1:1 is required for proper training.
This is just an experiment to see if missing pos/neg samples in a batch might be the reason behind the nan rewards.
But here I get nan losses even without nan rewards...
{'loss': 1.0431, 'grad_norm': 42.099464416503906, 'learning_rate': 1.0000000000000002e-06, 'rewards/chosen': 0.0, 'rewards/rejected': 0.0, 'rewards/margins': 0.0, 'kl': 0.0, 'logps/chosen': -37.16696548461914, 'logps/rejected': -87.62107849121094, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 41.9438362121582, 'learning_rate': 2.0000000000000003e-06, 'rewards/chosen': 0.0, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.0, 'logps/chosen': -32.92508316040039, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 29.28327178955078, 'learning_rate': 3e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.15479230880737305, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.70748519897461, 'learning_rate': 4.000000000000001e-06, 'rewards/chosen': 0.06518054008483887, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.43951892852783203, 'logps/chosen': -31.101844787597656, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 44.989227294921875, 'learning_rate': 5e-06, 'rewards/chosen': 0.3087962865829468, 'rewards/rejected': 0.23543643951416016, 'rewards/margins': 0.07335984706878662, 'kl': 1.230994462966919, 'logps/chosen': -32.83413314819336, 'logps/rejected': -74.81724548339844, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 55.32667541503906, 'learning_rate': 6e-06, 'rewards/chosen': 0.3336696922779083, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 0.3016533851623535, 'logps/chosen': -38.598453521728516, 'logps/rejected': nan, 'epoch': 0.0}
{'loss': nan, 'grad_norm': 32.44403839111328, 'learning_rate': 7e-06, 'rewards/chosen': 0.8524215817451477, 'rewards/rejected': 0.5893988609313965, 'rewards/margins': 0.2630227208137512, 'kl': 0.7648882865905762, 'logps/chosen': -35.86614227294922, 'logps/rejected': -93.13447570800781, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 26.85154914855957, 'learning_rate': 8.000000000000001e-06, 'rewards/chosen': 0.8056153059005737, 'rewards/rejected': 0.40718716382980347, 'rewards/margins': 0.39842814207077026, 'kl': 1.3891675472259521, 'logps/chosen': -34.07681655883789, 'logps/rejected': -113.53411102294922, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 25.181703567504883, 'learning_rate': 9e-06, 'rewards/chosen': nan, 'rewards/rejected': 0.9289813041687012, 'rewards/margins': nan, 'kl': 1.279036521911621, 'logps/chosen': nan, 'logps/rejected': -132.0060272216797, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 36.62141799926758, 'learning_rate': 1e-05, 'rewards/chosen': 1.4094278812408447, 'rewards/rejected': 0.8396401405334473, 'rewards/margins': 0.5697878003120422, 'kl': 2.0255985260009766, 'logps/chosen': -30.87615394592285, 'logps/rejected': -102.92286682128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.035221099853516, 'learning_rate': 9.997300215982722e-06, 'rewards/chosen': 1.5928469896316528, 'rewards/rejected': 1.5922844409942627, 'rewards/margins': 0.0005625784397125244, 'kl': 2.884922981262207, 'logps/chosen': -39.46299362182617, 'logps/rejected': -121.78970336914062, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 33.07608413696289, 'learning_rate': 9.994600431965443e-06, 'rewards/chosen': nan, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.1301448345184326, 'logps/chosen': nan, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 43.48128128051758, 'learning_rate': 9.991900647948165e-06, 'rewards/chosen': 2.113973617553711, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 3.475428819656372, 'logps/chosen': -26.679065704345703, 'logps/rejected': nan, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 31.501819610595703, 'learning_rate': 9.989200863930886e-06, 'rewards/chosen': 2.6266024112701416, 'rewards/rejected': 2.2295963764190674, 'rewards/margins': 0.3970060348510742, 'kl': 4.643209934234619, 'logps/chosen': -42.25154495239258, 'logps/rejected': -95.91471862792969, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 34.09553527832031, 'learning_rate': 9.986501079913607e-06, 'rewards/chosen': 2.7660703659057617, 'rewards/rejected': 2.6509010791778564, 'rewards/margins': 0.11516910791397095, 'kl': 4.8384199142456055, 'logps/chosen': -49.93422317504883, 'logps/rejected': -73.00190734863281, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.591957092285156, 'learning_rate': 9.983801295896329e-06, 'rewards/chosen': 3.131122350692749, 'rewards/rejected': 2.9620559215545654, 'rewards/margins': 0.1690664291381836, 'kl': 4.498130798339844, 'logps/chosen': -29.836196899414062, 'logps/rejected': -105.75230407714844, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 13.737163543701172, 'learning_rate': 9.98110151187905e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.1204824447631836, 'rewards/margins': nan, 'kl': 6.049262523651123, 'logps/chosen': nan, 'logps/rejected': -96.40724182128906, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 30.375396728515625, 'learning_rate': 9.978401727861771e-06, 'rewards/chosen': nan, 'rewards/rejected': 3.636046886444092, 'rewards/margins': nan, 'kl': 6.3599958419799805, 'logps/chosen': nan, 'logps/rejected': -97.00442504882812, 'epoch': 0.01}
{'loss': nan, 'grad_norm': 27.26076889038086, 'learning_rate': 9.975701943844493e-06, 'rewards/chosen': 4.384129524230957, 'rewards/rejected': 3.9822707176208496, 'rewards/margins': 0.40185898542404175, 'kl': 8.23063850402832, 'logps/chosen': -24.248661041259766, 'logps/rejected': -105.89572143554688, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 18.513507843017578, 'learning_rate': 9.973002159827214e-06, 'rewards/chosen': 4.265963077545166, 'rewards/rejected': 3.8863425254821777, 'rewards/margins': 0.3796207308769226, 'kl': 6.635190010070801, 'logps/chosen': -24.802963256835938, 'logps/rejected': -68.99553680419922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.997692108154297, 'learning_rate': 9.970302375809935e-06, 'rewards/chosen': 5.037494659423828, 'rewards/rejected': 4.227317810058594, 'rewards/margins': 0.8101770877838135, 'kl': 8.07493782043457, 'logps/chosen': -24.345657348632812, 'logps/rejected': -74.88150024414062, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 26.245861053466797, 'learning_rate': 9.967602591792658e-06, 'rewards/chosen': 4.526309490203857, 'rewards/rejected': 4.603299140930176, 'rewards/margins': -0.07698965072631836, 'kl': 8.698637008666992, 'logps/chosen': -22.94290542602539, 'logps/rejected': -99.22356414794922, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 22.14063835144043, 'learning_rate': 9.964902807775378e-06, 'rewards/chosen': 5.355809211730957, 'rewards/rejected': 4.891297340393066, 'rewards/margins': 0.464511513710022, 'kl': 8.954204559326172, 'logps/chosen': -23.850910186767578, 'logps/rejected': -87.7445068359375, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 25.642059326171875, 'learning_rate': 9.962203023758101e-06, 'rewards/chosen': 5.606294631958008, 'rewards/rejected': 6.807004928588867, 'rewards/margins': -1.2007099390029907, 'kl': 9.733396530151367, 'logps/chosen': -24.039264678955078, 'logps/rejected': -119.2092514038086, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 10.412492752075195, 'learning_rate': 9.959503239740822e-06, 'rewards/chosen': 5.953470230102539, 'rewards/rejected': 5.025949954986572, 'rewards/margins': 0.9275206327438354, 'kl': 10.74533462524414, 'logps/chosen': -16.727996826171875, 'logps/rejected': -80.9796142578125, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 17.695709228515625, 'learning_rate': 9.956803455723542e-06, 'rewards/chosen': nan, 'rewards/rejected': 6.109594345092773, 'rewards/margins': nan, 'kl': 11.900070190429688, 'logps/chosen': nan, 'logps/rejected': -121.30842590332031, 'epoch': 0.02}
{'loss': nan, 'grad_norm': 35.035892486572266, 'learning_rate': 9.954103671706265e-06, 'rewards/chosen': 6.687896251678467, 'rewards/rejected': nan, 'rewards/margins': nan, 'kl': 12.4317626953125, 'logps/chosen': -16.511333465576172, 'logps/rejected': nan, 'epoch': 0.02}
> @claralp so the main hyperparam that could affect this is the batch size as it needs a good mix of good and bad examples, as well as for the KL estimates... your learning rate is tiny so that should be good... what is your batch size when you get all nans?
Batch size is 8 and gradient accumulation steps is 2, as in the config above.
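For reference, a minimal sketch of where those two values go in the trainer arguments (hypothetical excerpt; the real config is the one posted above, and the other fields here are only placeholders):

```python
from trl import KTOConfig

# Sketch only -- not the full config from the issue.
# Batch size 8 with 2 gradient accumulation steps gives an effective
# batch of 16 examples per optimizer step on a single GPU.
training_args = KTOConfig(
    output_dir="kto-out",             # placeholder path
    per_device_train_batch_size=8,
    gradient_accumulation_steps=2,
    learning_rate=1e-6,               # assumption, matching the logged schedule
)
```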
> also does this happen if you try locally outside of Azure?
Currently checking this.
closed with #1499 and #1514
During training with the KTOTrainer I occasionally experience nan values as rewards. I am running the training as a job on MS Azure with one GPU (NVIDIA A100 80GB PCIe).
Ultimately these issues cause my Azure job to crash and retry...
The log output I get from the KTOTrainer:
My pip freeze:
The training script I use:
The call arguments:
Maybe @lewtun can help