aws / sagemaker-training-toolkit

Train machine learning models within a 🐳 Docker container using 🧠 Amazon SageMaker.
Apache License 2.0
496 stars 118 forks source link

Hyperparameters and other cmd arguments are not passed to shell entrypoint in tensorflow > 2.4 #115

Closed unoebauer closed 2 years ago

unoebauer commented 2 years ago

Describe the bug

When using a shell (i.e. COMMAND) entry point for the tensorflow training estimator, the command line arguments are not passed properly into the shell script when specifying framework_version > 2.4.

I've confirmed that the issue is absent in framework versions 2.1 - 2.4 but present in 2.5 and 2.6.

After manually inspecting the running training containers in local mode, I think I can trace the issue back to the create function in process.py. Up to and including v3.9.2, subprocess.Popen() was used to spawn the process that executes the command calling the shell entrypoint. From v3.9.3 onwards, asyncio.create_subprocess_shell is used instead, and it somehow doesn't pass the arguments through to the executed shell entrypoint, even though the executed cmd remains the same.

To reproduce

The issue can be reproduced/investigated by using the following shell entrypoint (entrypoint_test.sh) and python launcher (launcher_test.py):

# Print every argument this entry point received ("$@" is expanded inside
# double quotes).
echo "entrypoint invoked with arguments $@"

# For each of $@, $* and $1: first print the literal name (single quotes
# suppress expansion), then print its expanded value.
echo "testing different cmd access methods":
echo '$@'
echo $@
echo '$*'
echo $*
echo '$1'
echo $1

# Resolve the python interpreter found on PATH; only referenced by the
# commented-out launch line below.
PYTHONEXE=`which python`

echo "calling main python entrypoint"
echo "executing /usr/bin/python3 main_test.py $@"
# The actual launch is disabled; only the command that would run is echoed.
# $PYTHONEXE main_test.py "$@"

# Block for an hour (presumably to keep the container alive for manual
# inspection -- confirm).
sleep 3600
import logging
import os
import sys

import sagemaker
import sagemaker.tensorflow

# Route log records of level INFO and above to stdout, then create the
# logger named after this module.
logging.basicConfig(level=logging.INFO, stream=sys.stdout)
log = logging.getLogger(__name__)

def main():
    """Start a local TensorFlow training job that uses a shell entry point.

    Builds a ``sagemaker.tensorflow.TensorFlow`` estimator whose entry point
    is ``entrypoint_test.sh`` (taken from the directory containing this file),
    attaches three dummy hyperparameters, and calls ``fit`` with
    ``wait=False`` under the job name ``test``.
    """
    # dummy data location for testing
    data_location = "file:///tmp"
    # directory containing this launcher; the entry point script lives here too
    source_dir = os.path.abspath(os.path.dirname(__file__))

    role = "<your-sagemaker-execution-role>"

    # Observed behavior per framework version / python version:
    #   2.1.0 / py3 -> works
    #   2.2.0 / py37 -> works
    #   2.3, 2.3.0 / py37 -> works
    #   2.4, 2.4.1 / py37 -> works
    #   2.5, 2.5.1 / py37 -> NO cmd passing in sh entrypoint
    #   2.6 / py38 -> NO cmd passing in sh entrypoint
    framework_version = "2.5"
    py_version = "py37"
    entry_point = "entrypoint_test.sh"

    instance_count = 1
    train_instance_type = "local"

    data_inputs = {"data": data_location}

    # some dummy hyperparameters to test cmd passing
    hyperparameters = {
        "test_arg1": "test_val_1",
        "test_arg2": "test_val_2",
        "test_arg3": "test_val_3",
    }

    estimator = sagemaker.tensorflow.TensorFlow(
        entry_point=entry_point,
        source_dir=source_dir,
        role=role,
        framework_version=framework_version,
        py_version=py_version,
        instance_count=instance_count,
        instance_type=train_instance_type,
        hyperparameters=hyperparameters,
    )

    estimator.fit(
        inputs=data_inputs,
        wait=False,
        job_name="test",
    )

# Launch the training job only when this file is executed as a script.
if __name__ == "__main__":
    main()

Simply copy the two files into the same directory, set an appropriate execution role, select the framework/python version and run python launcher_test.py

Expected behavior

The expected behavior would be that the command line arguments are accessible within the shell script (like in the case of framework versions <= 2.4):

jprpro2cdr-algo-1-p0sbc | SM_HP_TEST_ARG1=test_val_1
jprpro2cdr-algo-1-p0sbc | SM_HP_TEST_ARG2=test_val_2
jprpro2cdr-algo-1-p0sbc | SM_HP_TEST_ARG3=test_val_3
jprpro2cdr-algo-1-p0sbc | SM_HP_MODEL_DIR=s3://sagemaker-us-east-1-xxx/test/model
jprpro2cdr-algo-1-p0sbc | PYTHONPATH=/opt/ml/code:/usr/local/bin:/usr/local/lib/python37.zip:/usr/local/lib/python3.7:/usr/local/lib/python3.7/lib-dynload:/usr/local/lib/python3.7/site-packages
jprpro2cdr-algo-1-p0sbc |
jprpro2cdr-algo-1-p0sbc | Invoking script with the following command:
jprpro2cdr-algo-1-p0sbc |
jprpro2cdr-algo-1-p0sbc | /bin/sh -c ./entrypoint_test.sh --model_dir s3://sagemaker-us-east-1-xxx/test/model --test_arg1 test_val_1 --test_arg2 test_val_2 --test_arg3 test_val_3
jprpro2cdr-algo-1-p0sbc |
jprpro2cdr-algo-1-p0sbc |
jprpro2cdr-algo-1-p0sbc | entrypoint invoked with arguments --model_dir s3://sagemaker-us-east-1-xxx/test/model --test_arg1 test_val_1 --test_arg2 test_val_2 --test_arg3 test_val_3
jprpro2cdr-algo-1-p0sbc | testing different cmd access methods:
jprpro2cdr-algo-1-p0sbc | $@
jprpro2cdr-algo-1-p0sbc | --model_dir s3://sagemaker-us-east-1-xxx/test/model --test_arg1 test_val_1 --test_arg2 test_val_2 --test_arg3 test_val_3
jprpro2cdr-algo-1-p0sbc | $*
jprpro2cdr-algo-1-p0sbc | --model_dir s3://sagemaker-us-east-1-xxx/test/model --test_arg1 test_val_1 --test_arg2 test_val_2 --test_arg3 test_val_3
jprpro2cdr-algo-1-p0sbc | $1
jprpro2cdr-algo-1-p0sbc | --model_dir
jprpro2cdr-algo-1-p0sbc | calling main python entrypoint
jprpro2cdr-algo-1-p0sbc | executing /usr/bin/python3 main_test.py --model_dir s3://sagemaker-us-east-1-xxx/test/model --test_arg1 test_val_1 --test_arg2 test_val_2 --test_arg3 test_val_3

Screenshots or logs

Here is the example log output when using tf 2.5 (i.e. no command-line arguments are available within the shell script, as shown by the blank outputs for $@, $*, $1, etc.):

ug52jqv4wv-algo-1-qdlyn | SM_HP_TEST_ARG1=test_val_1
ug52jqv4wv-algo-1-qdlyn | SM_HP_TEST_ARG2=test_val_2
ug52jqv4wv-algo-1-qdlyn | SM_HP_TEST_ARG3=test_val_3
ug52jqv4wv-algo-1-qdlyn | SM_HP_MODEL_DIR=s3://sagemaker-us-east-1-xxx/test/model
ug52jqv4wv-algo-1-qdlyn | PYTHONPATH=/opt/ml/code:/usr/local/bin:/usr/local/lib/python37.zip:/usr/local/lib/python3.7:/usr/local/lib/python3.7/lib-dynload:/usr/local/lib/python3.7/site-packages
ug52jqv4wv-algo-1-qdlyn |
ug52jqv4wv-algo-1-qdlyn | Invoking script with the following command:
ug52jqv4wv-algo-1-qdlyn |
ug52jqv4wv-algo-1-qdlyn | /bin/sh -c ./entrypoint_test.sh --model_dir s3://sagemaker-us-east-1-xxx/test/model --test_arg1 test_val_1 --test_arg2 test_val_2 --test_arg3 test_val_3
ug52jqv4wv-algo-1-qdlyn |
ug52jqv4wv-algo-1-qdlyn |
ug52jqv4wv-algo-1-qdlyn | entrypoint invoked with arguments
ug52jqv4wv-algo-1-qdlyn | testing different cmd access methods:
ug52jqv4wv-algo-1-qdlyn | $@
ug52jqv4wv-algo-1-qdlyn |
ug52jqv4wv-algo-1-qdlyn | $*
ug52jqv4wv-algo-1-qdlyn |
ug52jqv4wv-algo-1-qdlyn | $1
ug52jqv4wv-algo-1-qdlyn | calling main python entrypoint
ug52jqv4wv-algo-1-qdlyn | executing /usr/bin/python3 main_test.py

System information

A description of your system.

Additional context

N/A

unoebauer commented 2 years ago

It seems that a simple switch to asyncio.create_subprocess_exec fixes the problem. Below is a simple test script that can be used with the entrypoint_test.sh shell script to demonstrate this:

import asyncio
from asyncio.subprocess import PIPE
import os
import six
import sys

# tested with version 4.0.0 of sagemaker_training
from sagemaker_training import process
from sagemaker_training import (
    environment,
    errors,
)

# taken from sagemaker_training.process, slightly modified to work outside
# of process.ProcessRunner class
def _create_command(_user_entry_point, _args):
    args = [
        six.moves.shlex_quote(arg)  # pylint: disable=too-many-function-args
        for arg in _args
    ]
    return ["/bin/sh", "-c", "./%s %s" % (_user_entry_point, " ".join(args))]

# taken from sagemaker_training.process, slightly modified
async def run_async_new(cmd, processes_per_host, env, cwd, stderr, **kwargs):
    """Launch ``cmd`` with ``asyncio.create_subprocess_exec`` and gather its output.

    The command fragments are passed to the child process as an argument
    vector (``*cmd``) instead of being joined into a single shell string.
    ``asyncio.gather`` is used to collect the processed stdout and stderr.

    Args:
        cmd (list): The command to be run, one list item per argument.
        processes_per_host (int): Number of processes per host; forwarded to
            ``process.watch``.
        env: Environment passed to the subprocess (e.g. ``os.environ``).
        cwd (str): The location from which to run the command.
        stderr: Value for the subprocess ``stderr`` argument, e.g. ``PIPE``
            or ``None``.
        **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.

    Returns:
        tuple: ``(return_code, output, proc)`` where ``return_code`` is the
        launched process's return code, ``output`` is the list of results of
        the two ``process.watch`` calls (stdout first, then stderr) and
        ``proc`` is the ``asyncio.subprocess.Process`` for the command.

    Exceptions raised while creating the process are not handled here and
    propagate to the caller.
    """

    # use original cmd fragments and switch from asyncio.create_subprocess_shell to
    # asyncio.create_subprocess_exec
    proc = await asyncio.create_subprocess_exec(
        *cmd, env=env, cwd=cwd, stdout=PIPE, stderr=stderr, **kwargs
    )

    # NOTE(review): proc.stderr is None unless stderr=PIPE was requested --
    # confirm that process.watch tolerates a None stream.
    output = await asyncio.gather(
        process.watch(proc.stdout, processes_per_host),
        process.watch(proc.stderr, processes_per_host),
    )
    # The output streams reaching EOF does not guarantee that the child's exit
    # status has been collected yet; wait for it so the return code is never
    # reported as None. The pipes are already drained, so this cannot deadlock
    # on a full pipe buffer.
    return_code = await proc.wait()
    return return_code, output, proc

# taken from sagemaker_training.process, slightly modified
def create_new(
    cmd,
    error_class,
    processes_per_host,
    cwd=None,
    env=None,
    capture_error=False,
    **kwargs,
):
    """Run the given command in an asyncio subprocess and return its results.

    Args:
        cmd (list): The command to be run.
        error_class (cls): The class to use when raising an exception.
        processes_per_host (int): Number of processes per host.
        cwd (str): The location from which to run the command (default: None).
            If None, this defaults to the ``code_dir`` of the environment.
        env: Environment for the subprocess (default: None). A falsy value
            falls back to ``os.environ``.
        capture_error (bool): Whether or not to direct stderr to a stream
            that can later be read (default: False).
        **kwargs: Extra arguments that are passed to the asyncio create subprocess constructor.

    Returns:
        tuple: ``(rc, output, proc)`` -- the process's return code, the
        gathered stdout/stderr output and the ``asyncio.subprocess.Process``
        for the given command.

    Raises:
        error_class: If any exception is raised while creating or running the
            process. The original exception is passed as the sole argument
            and chained as ``__cause__``.
    """
    try:
        stderr = PIPE if capture_error else None
        return asyncio.run(
            run_async_new(
                cmd,
                processes_per_host,
                env=env or os.environ,
                cwd=cwd or environment.code_dir,
                stderr=stderr,
                **kwargs,
            )
        )
    except Exception as e:  # pylint: disable=broad-except
        # Native exception chaining keeps the original error and its traceback
        # reachable through __cause__ without the six compatibility shim.
        raise error_class(e) from e

def main():
    """Run the same entry point command through both process launchers.

    Builds the command for ``entrypoint_test.sh`` with three dummy arguments,
    executes it first via ``process.create`` and then via ``create_new``, and
    prints a banner before each run.
    """

    def launch_options():
        # Re-evaluated for every launch, so each call sees the current
        # working directory.
        return {
            "error_class": errors.ExecuteUserScriptError,
            "processes_per_host": 1,
            "cwd": os.getcwd(),
            "env": None,
            "capture_error": False,
        }

    # entry point script and the dummy arguments handed to it
    entry_point = "entrypoint_test.sh"
    dummy_args = [
        "--test_arg1",
        "test_val_1",
        "--test_arg2",
        "test_val_2",
        "--test_arg3",
        "test_val_3",
    ]

    print("Generating entrypoint execution cmd")
    command = _create_command(entry_point, dummy_args)
    print(f"cmd: {command}\n")

    # launcher shipped with sagemaker_training
    print(f"Executing cmd {command} with asyncio.create_subprocess_shell (will fail):\n")
    process.create(command, **launch_options())

    # replacement launcher defined in this script
    print(f"\nExecuting cmd {command} with asyncio.create_subprocess_exec (will succeed):\n")
    create_new(command, **launch_options())

# Run the comparison only when this file is executed as a script.
if __name__ == "__main__":
    main()

I've opened PR #116 that proposes that change.

satishpasumarthi commented 2 years ago

Fixed in PR https://github.com/aws/sagemaker-training-toolkit/pull/122