awslabs / sagemaker-battlesnake-ai

Starter pack to build an AI for Battlesnake with Amazon SageMaker. More content on the wiki:
https://github.com/awslabs/sagemaker-battlesnake-ai/wiki
Apache License 2.0
89 stars 53 forks

SageMaker SDK argument incompatible in 2_PolicyTraining.ipynb #40

Closed michaelhsieh42 closed 3 years ago

michaelhsieh42 commented 3 years ago

2_PolicyTraining.ipynb breaks in SageMaker Studio with the "SageMaker JumpStart Tensorflow 1.0" image due to a SageMaker SDK version incompatibility.

The notebook and solution were deployed through SageMaker JumpStart. The 7th cell fails because of a missing argument.

estimator.fit()
---------------------------------------------------------------------------
AttributeError                            Traceback (most recent call last)
<timed exec> in <module>

/opt/conda/envs/sagemaker-soln/lib/python3.7/site-packages/sagemaker/rl/estimator.py in __init__(self, entry_point, toolkit, toolkit_version, framework, source_dir, hyperparameters, image_uri, metric_definitions, **kwargs)
    145             :class:`~sagemaker.estimator.EstimatorBase`.
    146         """
--> 147         self._validate_images_args(toolkit, toolkit_version, framework, image_uri)
    148 
    149         if not image_uri:

/opt/conda/envs/sagemaker-soln/lib/python3.7/site-packages/sagemaker/rl/estimator.py in _validate_images_args(cls, toolkit, toolkit_version, framework, image_uri)
    389                 raise AttributeError(
    390                     "Please provide `{}` or `image_uri` parameter.".format(
--> 391                         "`, `".join(not_found_args)
    392                     )
    393                 )

AttributeError: Please provide `toolkit`, `toolkit_version`, `framework` or `image_uri` parameter.

This happens because the input arguments passed to the estimator became obsolete with the SageMaker Python SDK v1 -> v2 upgrade.
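
To illustrate why an obsolete keyword ends in this particular error, here is a simplified stand-in for the check named in the traceback. It is a sketch, not the SDK's actual code: the function name `validate_image_args` is hypothetical, and the logic is reconstructed only from the signature and error message shown above. The assumption is that an unrecognized keyword such as `image_name` is absorbed by `**kwargs`, so `image_uri` stays unset and the toolkit arguments are reported as missing.

```python
def validate_image_args(toolkit=None, toolkit_version=None, framework=None,
                        image_uri=None, **kwargs):
    """Hypothetical stand-in for the argument check seen in the traceback."""
    # An explicit image_uri is enough on its own.
    if image_uri:
        return
    # Otherwise all three toolkit arguments must be present.
    required = {
        "toolkit": toolkit,
        "toolkit_version": toolkit_version,
        "framework": framework,
    }
    not_found_args = [name for name, value in required.items() if not value]
    if not_found_args:
        raise AttributeError(
            "Please provide `{}` or `image_uri` parameter.".format(
                "`, `".join(not_found_args)
            )
        )


# Passing the v1-style keyword: it lands in **kwargs and is never inspected.
try:
    validate_image_args(image_name="my-image")
except AttributeError as err:
    message = str(err)
    print(message)
```

Under these assumptions the printed message matches the one in the report, even though an image was supplied, just under the old keyword name.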

The SageMaker JumpStart Tensorflow 1.0 kernel has SDK version:

sagemaker.__version__
'2.45.0'

which takes image_uri instead of image_name.

image_name = '462105765813.dkr.ecr.{region}.amazonaws.com/sagemaker-rl-ray-container:ray-0.8.2-tf-{device}-py36'.format(region=region, device=device)
estimator = RLEstimator(entry_point="train-mabs.py",
                        source_dir='training/training_src',
                        dependencies=["training/common/sagemaker_rl", "inference/inference_src/", "../BattlesnakeGym/"],
                        image_name=image_name,
                        role=role,
                        training_instance_type=instance_type,
                        training_instance_count=1,
...
...
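
One possible workaround, sketched under stated assumptions: translate the v1 keyword to its v2 name before the estimator is constructed, based on the installed SDK version. The helper name `adapt_estimator_kwargs` and the `V1_TO_V2` table are hypothetical and not part of the SageMaker SDK; only the `image_name` to `image_uri` rename is taken from this report, and any other renamed arguments would need to be added after checking the SDK documentation.

```python
# Hypothetical rename table; only this entry is confirmed by the report above.
V1_TO_V2 = {"image_name": "image_uri"}


def adapt_estimator_kwargs(kwargs, sdk_version):
    """Return a copy of kwargs with v1 names renamed when the SDK major is >= 2."""
    major = int(sdk_version.split(".")[0])
    if major < 2:
        return dict(kwargs)
    adapted = {}
    for name, value in kwargs.items():
        new_name = V1_TO_V2.get(name, name)
        if new_name in adapted or (new_name != name and new_name in kwargs):
            raise ValueError(
                "both old and new spellings of `{}` were given".format(new_name)
            )
        adapted[new_name] = value
    return adapted


# Example with placeholder values (not the real image or role).
v1_style = {"image_name": "my-image", "role": "my-role"}
print(adapt_estimator_kwargs(v1_style, "2.45.0"))
print(adapt_estimator_kwargs(v1_style, "1.72.0"))
```

In a notebook the result could be unpacked into the constructor, for example `RLEstimator(entry_point=..., **adapt_estimator_kwargs(extra_kwargs, sagemaker.__version__))`, where `extra_kwargs` is a placeholder for the remaining arguments.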

michaelhsieh42 commented 3 years ago

Sorry for the premature issue. The code in the repo does use image_uri; it is the code deployed through SageMaker JumpStart that still uses image_name.