aws / sagemaker-training-toolkit

Train machine learning models within a 🐳 Docker container using 🧠 Amazon SageMaker.
Apache License 2.0

SageMaker training toolkit reorders hyperparameters #221

Open · vsimkus opened this issue 2 months ago

Running a ScriptMode job on SageMaker modifies the order of the hyperparameter arguments passed to the training script.

For example, suppose my script train.py takes two arguments, --config and --batch_size. If the hyperparameters argument of the Estimator class is set to {'config': exp_config_path, 'batch_size': 10}, then I would expect the SageMaker training toolkit to invoke the script as follows:

python train.py --config exp_config_path --batch_size 10

However, the toolkit sorts the hyperparameters alphanumerically before invoking the script, resulting in the following invocation instead:

python train.py --batch_size 10 --config exp_config_path

This happens because of a single line: https://github.com/aws/sagemaker-training-toolkit/blob/628166c157751ae2a46fddc11a7a8cac765fb22c/src/sagemaker_training/mapping.py#L78

This reordering is a problem because the order of the arguments sometimes matters. For example, with jsonargparse the argument order can produce two different argument settings. If --config precedes --batch_size in the script invocation, the batch_size value is first loaded from the config file and then overridden by the command-line argument --batch_size 10. If, on the other hand, --batch_size precedes --config, the batch_size value is taken from the config file.
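A minimal sketch of this order dependence (the config file name and its contents are illustrative; assume configs/config.yaml contains the line batch_size: 32):

from jsonargparse import ArgumentParser, ActionConfigFile

parser = ArgumentParser()
parser.add_argument('--config', action=ActionConfigFile)
parser.add_argument('--batch_size', type=int, default=1)

# --config first: the config sets batch_size, then the CLI flag overrides it
args = parser.parse_args(['--config', 'configs/config.yaml', '--batch_size', '10'])
print(args.batch_size)  # 10

# --batch_size first: the config file is applied later, so its value wins
args = parser.parse_args(['--batch_size', '10', '--config', 'configs/config.yaml'])
print(args.batch_size)  # 32 (from configs/config.yaml)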

I don't really see a reason why the hyperparameters should be sorted, so I think it would be safe to remove this sorting. The fix would be as easy as removing the sort on https://github.com/aws/sagemaker-training-toolkit/blob/628166c157751ae2a46fddc11a7a8cac765fb22c/src/sagemaker_training/mapping.py#L78?plain=1 (sketched below).
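For illustration, here is a hypothetical sketch of what dropping the sort could look like. The function name to_cmd_args and the argument formatting are assumptions based on the linked file; the real implementation also handles single-character flags and value encoding, which this sketch omits:

def to_cmd_args(mapping):
    """Sketch: turn {'config': 'c.yaml', 'batch_size': 10} into
    ['--config', 'c.yaml', '--batch_size', '10'] without reordering."""
    cmd_args = []
    # before (reorders): for key in sorted(mapping.keys()):
    for key, value in mapping.items():  # dicts preserve insertion order on Python 3.7+
        cmd_args += ['--%s' % key, str(value)]
    return cmd_args

print(to_cmd_args({'config': 'configs/config.yaml', 'batch_size': 10}))
# ['--config', 'configs/config.yaml', '--batch_size', '10']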

Reproducing the bug

Here's a sketch to reproduce the bug:

from sagemaker.pytorch import PyTorch

# Set up the AWS role and a local SageMaker session as usual

estimator = PyTorch(
    role=role,
    entry_point='train.py',
    framework_version='2.0.0',  # version pins required by the estimator; adjust to your setup
    py_version='py310',
    instance_count=1,
    instance_type='local',
    sagemaker_session=local_sess,
    output_path=output_path,
    hyperparameters={'config': 'configs/config.yaml', 'batch_size': 10},
)

estimator.fit()
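For completeness, train.py can be a minimal jsonargparse entry point matching the sketch above (hypothetical; only the two arguments from this example):

from jsonargparse import ArgumentParser, ActionConfigFile

parser = ArgumentParser()
parser.add_argument('--config', action=ActionConfigFile)
parser.add_argument('--batch_size', type=int, default=1)
args = parser.parse_args()

# With the reordered invocation, batch_size is taken from configs/config.yaml
# instead of the --batch_size 10 passed via hyperparameters
print('batch_size =', args.batch_size)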

Running the estimator code above results in an invocation with reordered hyperparameters, which can be seen in the logs:

l12v045866-algo-1-09ky7  | Invoking script with the following command:
l12v045866-algo-1-09ky7  | 
l12v045866-algo-1-09ky7  | /root/miniconda3/envs/conda_env/bin/python3.10 train.py --batch_size 10 --config configs/config.yaml

Instead, I would expect the training toolkit to invoke the following command:

/root/miniconda3/envs/conda_env/bin/python3.10 train.py --config configs/config.yaml --batch_size 10