aws / sagemaker-experiments

Experiment tracking and metric logging for Amazon SageMaker notebooks and model training.
Apache License 2.0

Set tags for TrialComponent when creating it using the Tracker class #171

Open redfungus opened 1 year ago

redfungus commented 1 year ago

Is your feature request related to a problem? Please describe.
Yes. Since the Tracker class is the suggested way to create TrialComponent objects, it would be useful to be able to set tags through Tracker.create(...). TrialComponent.create(...) already supports setting tags, but Tracker.create(...) has no parameter for passing them through.

Describe the solution you'd like
I would like to be able to call Tracker.create(..., tags=my_tags) and have the tags passed through to the TrialComponent, e.g. TrialComponent.create(..., tags=my_tags).
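For reference, `my_tags` would take the shape SageMaker uses for resource tags: a list of `{"Key": ..., "Value": ...}` dictionaries. The helper below, `to_sagemaker_tags`, is a hypothetical convenience (not part of smexperiments) that builds such a list from a plain dict, assuming the `tags` value is forwarded unchanged to the SageMaker `CreateTrialComponent` API.

```python
from typing import Dict, List


def to_sagemaker_tags(tags: Dict[str, str]) -> List[Dict[str, str]]:
    """Convert {"team": "ml"} into [{"Key": "team", "Value": "ml"}].

    Hypothetical helper: assumes Tracker/TrialComponent pass the list
    through unchanged to the CreateTrialComponent ``Tags`` field.
    """
    result = []
    for key, value in tags.items():
        # SageMaker tag keys and values are strings; reject anything else early.
        if not isinstance(key, str) or not isinstance(value, str):
            raise TypeError(f"tag key and value must be str, got {key!r}: {value!r}")
        result.append({"Key": key, "Value": value})
    return result


my_tags = to_sagemaker_tags({"team": "ml", "project": "churn"})
# my_tags == [{"Key": "team", "Value": "ml"}, {"Key": "project", "Value": "churn"}]
```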

Example solution in tracker.py:

    @classmethod
    def create(
        cls,
        base_trial_component_name="TrialComponent",
        display_name=None,
        artifact_bucket=None,
        artifact_prefix=None,
        boto3_session=None,
        tags=None,
        sagemaker_boto_client=None,
    ):
        """Create a new ``Tracker`` by creating a new trial component.
        Note that `log_metric` will _not_ work when tracker is created this way.
        Examples
            .. code-block:: python
                from smexperiments import tracker
                my_tracker = tracker.Tracker.create()
        Args:
            base_trial_component_name: (str, optional). The name of the trial component resource that
                will be appended with a timestamp. Defaults to "TrialComponent".
            display_name: (str, optional). The display name of the trial component to track.
            artifact_bucket: (str, optional) The name of the S3 bucket to store artifacts to.
            artifact_prefix: (str, optional) The prefix to write artifacts to within ``artifact_bucket``
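            boto3_session: (boto3.Session, optional) The boto3.Session to use to interact with AWS services.
                If not specified a new default boto3 session will be created.
            tags: (List[dict], optional) Tags to add to the new trial component, as a list of
                ``{"Key": ..., "Value": ...}`` dicts. Defaults to None (no tags).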
            sagemaker_boto_client: (boto3.Client, optional) The SageMaker AWS service client to use. If not
                specified a new client will be created from the specified ``boto3_session`` or default
                boto3.Session.
        Returns:
            Tracker: The tracker for the new trial component.
        """
        boto3_session = boto3_session or _utils.boto_session()
        sagemaker_boto_client = sagemaker_boto_client or _utils.sagemaker_client()

        tc = trial_component.TrialComponent.create(
            trial_component_name=_utils.name(base_trial_component_name),
            display_name=display_name,
            tags=tags,
            sagemaker_boto_client=sagemaker_boto_client,
        )

        # metrics require the metrics agent running on training job hosts, in which case the load
        # method should be used because it loads the trial component associated with the currently
        # running training job
        metrics_writer = None

        return cls(
            tc,
            metrics_writer,
            _ArtifactUploader(tc.trial_component_name, artifact_bucket, artifact_prefix, boto3_session),
            _LineageArtifactTracker(tc.trial_component_arn, sagemaker_boto_client),
        )
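The forwarding logic can be exercised without AWS credentials. The sketch below uses hypothetical stand-ins (`FakeTrialComponent`, `FakeTracker`), not the real smexperiments classes, that mirror the proposed change: `tags` defaults to `None` and is passed straight through to the trial-component factory.

```python
class FakeTrialComponent:
    """Hypothetical stand-in for smexperiments' TrialComponent; records its arguments."""

    def __init__(self, trial_component_name, display_name=None, tags=None):
        self.trial_component_name = trial_component_name
        self.display_name = display_name
        self.tags = tags

    @classmethod
    def create(cls, trial_component_name, display_name=None, tags=None, sagemaker_boto_client=None):
        # The real method calls the SageMaker API; the stand-in only builds a local object.
        return cls(trial_component_name, display_name, tags)


class FakeTracker:
    """Hypothetical stand-in mirroring the proposed Tracker.create change."""

    def __init__(self, trial_component):
        self.trial_component = trial_component

    @classmethod
    def create(cls, base_trial_component_name="TrialComponent", display_name=None,
               sagemaker_boto_client=None, tags=None):
        # Forward tags unchanged; the None default keeps existing callers' behaviour.
        tc = FakeTrialComponent.create(
            trial_component_name=base_trial_component_name,
            display_name=display_name,
            tags=tags,
            sagemaker_boto_client=sagemaker_boto_client,
        )
        return cls(tc)


tagged = FakeTracker.create(tags=[{"Key": "team", "Value": "ml"}])
untagged = FakeTracker.create()
print(tagged.trial_component.tags)    # [{'Key': 'team', 'Value': 'ml'}]
print(untagged.trial_component.tags)  # None
```

The stand-in puts `tags` last in the signature. Assuming `tags` is the only new parameter, doing the same in the real method (rather than inserting it before `sagemaker_boto_client`) would avoid breaking any existing caller that passes `sagemaker_boto_client` positionally.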

Describe alternatives you've considered
Creating a TrialComponent manually with the required tags and then loading it with the Tracker.load() method.
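The workaround can be sketched as follows. This assumes `TrialComponent.create` accepts `tags` (as stated above) and that `Tracker.load` accepts a `trial_component_name` keyword; `create_tagged_tracker` is a hypothetical helper name, and both classes are passed in as parameters so the sketch can be checked with mocks instead of AWS.

```python
import time
from unittest import mock


def create_tagged_tracker(trial_component_cls, tracker_cls, base_name, tags, display_name=None):
    """Create a tagged trial component, then attach a Tracker to it.

    In real use, pass smexperiments.trial_component.TrialComponent and
    smexperiments.tracker.Tracker (assumed signatures, see lead-in).
    """
    # Append a millisecond timestamp so repeated runs do not collide on the resource name.
    name = f"{base_name}-{int(time.time() * 1000)}"
    trial_component_cls.create(trial_component_name=name, display_name=display_name, tags=tags)
    # Load the component that was just created; its tags were set at creation time.
    return tracker_cls.load(trial_component_name=name)


# Quick check with mocks standing in for the real smexperiments classes.
tc_cls, tracker_cls = mock.Mock(), mock.Mock()
tags = [{"Key": "team", "Value": "ml"}]
result = create_tagged_tracker(tc_cls, tracker_cls, "my-tc", tags)
create_kwargs = tc_cls.create.call_args[1]
```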

Additional context
This feature is particularly useful when access permissions are restricted based on resource tags (for example, IAM policies with tag-based conditions).