aws / amazon-sagemaker-examples

Example 📓 Jupyter notebooks that demonstrate how to build, train, and deploy machine learning models using 🧠 Amazon SageMaker.
https://sagemaker-examples.readthedocs.io
Apache License 2.0

Mapping data_location to train script for scikit_bring_your_own #1386

Open ksachdeva11 opened 4 years ago

ksachdeva11 commented 4 years ago

Hi,

How can we pass an S3 path after the model is hosted on SageMaker and use that path to fetch the CSV in the train code? How can we map data_location directly to the train code?

account = sess.boto_session.client('sts').get_caller_identity()['Account']
region = sess.boto_session.region_name
image = '{}.dkr.ecr.{}.amazonaws.com/sagemaker-decision-trees:latest'.format(account, region)

tree = sage.estimator.Estimator(image,
                       role, 1, 'ml.c4.2xlarge',
                       output_path="s3://{}/output".format(sess.default_bucket()),
                       sagemaker_session=sess)

tree.fit(data_location)

Train code:

def train():
    print('Starting the training.')
    try:
        # Read in any hyperparameters that the user passed with the training job
        with open(param_path, 'r') as tc:
            trainingParams = json.load(tc)

        # Take the set of files and read them all into a single pandas dataframe
        input_files = [ os.path.join(training_path, file) for file in os.listdir(training_path) ]
        if len(input_files) == 0:
            raise ValueError(('There are no files in {}.\n' +
                              'This usually indicates that the channel ({}) was incorrectly specified,\n' +
                              'the data specification in S3 was incorrectly specified or the role specified\n' +
                              'does not have permission to access the data.').format(training_path, channel_name))
        raw_data = [ pd.read_csv(file, header=None) for file in input_files ]
        train_data = pd.concat(raw_data)

        # labels are in the first column
        train_y = train_data.iloc[:, 0]
        train_X = train_data.iloc[:, 1:]

        # Here we only support a single hyperparameter. Note that hyperparameters are always passed in as
        # strings, so we need to do any necessary conversions.
        max_leaf_nodes = trainingParams.get('max_leaf_nodes', None)
        if max_leaf_nodes is not None:
            max_leaf_nodes = int(max_leaf_nodes)

        # Now use scikit-learn's decision tree classifier to train the model.
        clf = tree.DecisionTreeClassifier(max_leaf_nodes=max_leaf_nodes)
        clf = clf.fit(train_X, train_y)

        # save the model
        with open(os.path.join(model_path, 'decision-tree-model.pkl'), 'wb') as out:
            pickle.dump(clf, out)
        print('Training complete.')
    except Exception as e:
        # Write out an error file. This will be returned as the failureReason in the
        # DescribeTrainingJob result.
        trc = traceback.format_exc()
        with open(os.path.join(output_path, 'failure'), 'w') as s:
            s.write('Exception during training: ' + str(e) + '\n' + trc)
        # Printing this causes the exception to be in the training job logs, as well.
        print('Exception during training: ' + str(e) + '\n' + trc, file=sys.stderr)
        # A non-zero exit code causes the training job to be marked as Failed.
        sys.exit(255)

if __name__ == '__main__':
    train()

    # A zero exit code causes the job to be marked as Succeeded.
    sys.exit(0)

Does fit call the train method directly when we execute tree.fit(data_location)?

chuyang-deng commented 4 years ago

Do you intend to use the data for inference or for training? I ask because you mentioned "after the model is hosted on SageMaker".

If it is for inference, the inference payload must first be downloaded to a local path that the Python SDK can read.

If it is for training, your training script reads the same input that you specified in the fit method: SageMaker downloads the training data to designated paths in the container. For more information on the paths used for each input format, see this page: https://docs.aws.amazon.com/sagemaker/latest/dg/your-algorithms-training-algo-running-container.html
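
To illustrate the training case, here is a minimal sketch of how a train script can locate the data that fit made available. It assumes File input mode, where each channel is copied to /opt/ml/input/data/<channel_name> inside the container, and it assumes the channel is named "training" (the name the sample's train script uses, and as far as I can tell the default when a plain S3 URI string is passed to fit). The helper names channel_path and read_channel_rows are hypothetical, and the standard csv module stands in for pandas to keep the sketch dependency-free.

```python
import csv
import os

# Assumption: in File input mode SageMaker places each channel's files under
# <prefix>/input/data/<channel_name> inside the training container.
DEFAULT_PREFIX = "/opt/ml"


def channel_path(channel_name="training", prefix=DEFAULT_PREFIX):
    """Return the local directory where a channel's files are expected."""
    return os.path.join(prefix, "input", "data", channel_name)


def read_channel_rows(channel_name="training", prefix=DEFAULT_PREFIX):
    """Read every CSV file in the channel directory into one list of rows."""
    directory = channel_path(channel_name, prefix)
    rows = []
    # Sort the file names so the row order is deterministic.
    for name in sorted(os.listdir(directory)):
        with open(os.path.join(directory, name), newline="") as handle:
            rows.extend(csv.reader(handle))
    return rows
```

Under these assumptions the train script never sees the S3 path itself. The path given to fit only tells SageMaker what to copy into the channel directory before the container's train program starts, so the script reads local files, for example with read_channel_rows("training").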