aws / sagemaker-core

Apache License 2.0

transform() not handling nested objects #23

Closed: benieric closed this issue 4 months ago

benieric commented 4 months ago

The transform() method may not be handling nested objects correctly.

When I create a TrainingJob, training_job.model_artifacts should be a ModelArtifacts object, but it is a dict. As a result, I cannot access the s3_model_artifacts attribute, and trying to do so raises an error.

benieric commented 4 months ago

It seems the issue actually occurs only with attributes declared like this:

model_artifacts: Optional[ModelArtifacts] = Unassigned()

If I call TrainingJob.get() and make the attribute required, like this:

model_artifacts: ModelArtifacts

then cls(**transformed_response) creates the nested objects.

benieric commented 4 months ago

This could be an issue with pydantic. It seems pydantic does handle conversion for nested attributes: https://docs.pydantic.dev/latest/concepts/models/#nested-attributes

But for some reason, it is not handled properly when the attribute is Optional.
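A minimal sketch for checking the Optional hypothesis, assuming pydantic v2 and a hypothetical, trimmed-down ModelArtifacts with only an s3_model_artifacts field (the S3 URI is made up). None stands in for sagemaker-core's Unassigned() default. In this sketch, construction converts the nested dict whether the field is required or Optional; if that holds for the real classes too, the problem likely lies somewhere other than initial construction.

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class ModelArtifacts(BaseModel):
    # Hypothetical, simplified shape with a single field
    s3_model_artifacts: str


class RequiredJob(BaseModel):
    # protected_namespaces=() avoids a warning about field names starting with "model_"
    model_config = ConfigDict(protected_namespaces=())
    model_artifacts: ModelArtifacts  # nested field declared as required


class OptionalJob(BaseModel):
    model_config = ConfigDict(protected_namespaces=())
    model_artifacts: Optional[ModelArtifacts] = None  # None stands in for Unassigned()


data = {"model_artifacts": {"s3_model_artifacts": "s3://bucket/model.tar.gz"}}

required_job = RequiredJob(**data)
optional_job = OptionalJob(**data)

# Construction validates the nested dict into a ModelArtifacts instance in both cases
print(type(required_job.model_artifacts).__name__)  # ModelArtifacts
print(type(optional_job.model_artifacts).__name__)  # ModelArtifacts
```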

benieric commented 4 months ago

https://stackoverflow.com/questions/62267544/generate-pydantic-model-from-a-dict

benieric commented 4 months ago

Using Model.model_validate(data) instead of cls(**transformed_response) looks like it may work as a fix.

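A sketch comparing the two construction paths, assuming pydantic v2. TrainingJob here is a hypothetical, trimmed-down stand-in, not sagemaker-core's class, and the field values are made up. On a fresh dict, both paths appear to run the same validation and produce equal models with a nested ModelArtifacts.

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class ModelArtifacts(BaseModel):
    s3_model_artifacts: str


class TrainingJob(BaseModel):
    # Hypothetical stand-in with just two fields
    model_config = ConfigDict(protected_namespaces=())
    training_job_name: str
    model_artifacts: Optional[ModelArtifacts] = None


transformed_response = {
    "training_job_name": "my-job",
    "model_artifacts": {"s3_model_artifacts": "s3://bucket/model.tar.gz"},
}

# Keyword-argument construction and model_validate both validate the full dict
via_init = TrainingJob(**transformed_response)
via_validate = TrainingJob.model_validate(transformed_response)

print(via_init == via_validate)  # True
print(via_validate.model_artifacts.s3_model_artifacts)
```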
benieric commented 4 months ago

The actual issue seems to be here: https://github.com/aws/sagemaker-core/blob/633a0d43cf9a74334e4ccb9d350478b93fe4bd2b/src/sagemaker_core/code_injection/codec.py#L235

It is triggered when TrainingJob.refresh() is called during wait().

When the object is already instantiated, we assign a plain dict to nested attributes instead of a nested model object.
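A hypothetical reproduction of this flow, assuming pydantic v2 with validate_assignment left at its default. The refresh() function below is an illustrative stand-in that copies response values onto an existing instance with setattr; it is not the actual code in codec.py.

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class ModelArtifacts(BaseModel):
    s3_model_artifacts: str


class TrainingJob(BaseModel):
    # Default config: validate_assignment is False
    model_config = ConfigDict(protected_namespaces=())
    training_job_name: str
    model_artifacts: Optional[ModelArtifacts] = None


def refresh(job: TrainingJob, response: dict) -> None:
    # Hypothetical refresh: overwrite attributes on an already-instantiated object
    for key, value in response.items():
        setattr(job, key, value)


job = TrainingJob(training_job_name="my-job")
refresh(job, {"model_artifacts": {"s3_model_artifacts": "s3://bucket/model.tar.gz"}})

# Without assignment validation, the raw dict is stored as-is
print(type(job.model_artifacts).__name__)  # dict
```

In this sketch, accessing job.model_artifacts.s3_model_artifacts would raise AttributeError, since the stored value is a dict rather than a ModelArtifacts instance.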

benieric commented 4 months ago

This is a small gotcha in pydantic: it does not validate assignments by default, even though one might expect it to. So when we updated an attribute with a dict, pydantic simply stored the dict on the attribute instead of validating it and converting it to a model object, as it does on initial creation. The fix is a simple config change: https://stackoverflow.com/questions/62025723/how-to-validate-a-pydantic-object-after-editing-it

model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)
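A sketch of the same trimmed-down model with this config applied, assuming pydantic v2 (the class and field values are illustrative, not sagemaker-core's). With validate_assignment=True, assigning a dict to the nested field runs validation and stores a ModelArtifacts instance, matching what happens on initial creation.

```python
from typing import Optional

from pydantic import BaseModel, ConfigDict


class ModelArtifacts(BaseModel):
    s3_model_artifacts: str


class TrainingJob(BaseModel):
    # The config change from this thread: validate on every attribute assignment
    model_config = ConfigDict(protected_namespaces=(), validate_assignment=True)

    training_job_name: str
    model_artifacts: Optional[ModelArtifacts] = None


job = TrainingJob(training_job_name="my-job")

# Assigning a dict now goes through validation and is converted to a nested model
job.model_artifacts = {"s3_model_artifacts": "s3://bucket/model.tar.gz"}
print(type(job.model_artifacts).__name__)  # ModelArtifacts
print(job.model_artifacts.s3_model_artifacts)
```

protected_namespaces=() keeps fields starting with model_, such as model_artifacts, from triggering pydantic's protected-namespace warning in some 2.x releases. validate_assignment applies to every attribute assignment, so it adds some validation overhead on each update.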
benieric commented 4 months ago

Regarding using Model.model_validate(data) instead of cls(**transformed_response): that only appeared to work because I was calling TrainingJob.get(). The issue was happening in the refresh() flow.