Open utterances-bot opened 6 months ago

Pydantic 1 vs 2: A Speed Comparison

Pydantic v2 claims to be 5-50x faster than v1. Is this accurate? Let's put this to the test!

https://janhendrikewers.uk/pydantic-1-vs-2-a-benchmark-test.html
ayushxx7 commented:

Thanks for sharing! Easy to read and digest. I would be interested to see more comparisons with complex logic. For instance, how does it affect something like the following, where we try instantiating `HuggingFaceTrainingConfig`?
```python
import logging
import re
from enum import Enum
from typing import ClassVar, List, Optional, Union

from pydantic import BaseModel, validator

log = logging.getLogger("uvicorn")


class InstanceType(Enum):
    g5_x = "ml.g5.xlarge"
    g5_2x = "ml.g5.2xlarge"


class SagemakerBasicConfig(BaseModel):
    instance_type: str
    instance_count: int
    role: str

    @validator("instance_type")
    def validate_instance_type(cls, instance_type: str):
        return InstanceType(instance_type).value


class SagemakerTrainingConfig(SagemakerBasicConfig):
    entry_point: str
    source_dir: str
    base_job_name: Optional[str]
    use_spot: Optional[bool]
    max_wait: Optional[int]
    max_run: Optional[int]

    @validator("base_job_name")
    def validate_base_job_name(cls, base_job_name: str):
        """
        Example of a job name
        ---
        Valid:   bi-head-train-predict
        Invalid: bi_head_train_predict
        """
        if base_job_name:
            pattern = r"^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$"
            if not re.match(pattern, base_job_name):
                raise ValueError(
                    "Invalid base_job_name. It must follow the pattern "
                    "^[a-zA-Z0-9](-*[a-zA-Z0-9]){0,62}$"
                )
        return base_job_name

    def to_dict(self):
        """
        The SageMaker API cannot accept certain parameters with values set to
        None, which is why we remove them here.

        Example:
            botocore.exceptions.ParamValidationError: Parameter validation failed:
            Invalid type for parameter StoppingCondition.MaxRuntimeInSeconds,
            value: None, type: <class 'NoneType'>, valid types: <class 'int'>
        """

        def remove_none_values(data):
            if isinstance(data, dict):
                return {
                    k: remove_none_values(v) for k, v in data.items() if v is not None
                }
            elif isinstance(data, list):
                return [remove_none_values(item) for item in data if item is not None]
            else:
                return data

        return remove_none_values(self.dict())


class XGBRankerTrainingConfig(SagemakerTrainingConfig):
    framework_version: str


class HuggingFaceHyperParameters(BaseModel):
    epochs: int
    model_name: str
    fp16: bool
    warmup_steps: int
    learning_rate: float
    train_batch_size: int
    eval_batch_size: int
    eval_steps: Optional[int]
    temperature: Optional[float]
    weight_multiply_factor: Optional[int]


class BiHeadClassifierHyperParameters(BaseModel):
    epochs: int
    model_name: str
    fp16: bool
    warmup_steps: int
    eval_steps: Optional[int]
    train_batch_size_sequence_classifier: int
    eval_batch_size_sequence_classifier: int
    train_batch_size_token_classifier: int
    eval_batch_size_token_classifier: int
    learning_rate_sequence_classification: float
    learning_rate_token_classification: float


class S3Data(BaseModel):
    """
    The estimator's fit method can take many path arguments as input.

    Example:
        For the bi-head model, we need 4 paths:
        - train
        - test
        - train_covered_words
        - test_covered_words

    The exact paths look like this:
        s3://{bucket_name}/{s3_prefix}/{file_path}

    S3 file path generation is handled by the train function.
    """

    bucket_name: str
    s3_prefix: str
    file_paths: List[str]


class PyTorchVersion(Enum):
    V1_13 = "1.13"


class TransformersVersion(Enum):
    V4_26 = "4.26"


class HuggingFaceTrainingConfig(SagemakerTrainingConfig):
    transformers_version: TransformersVersion
    pytorch_version: PyTorchVersion
    py_version: str
    hyperparameters: Union[BiHeadClassifierHyperParameters, HuggingFaceHyperParameters]
    checkpoint_local_path: Optional[str]
    checkpoint_s3_uri: Optional[str]


class HuggingFaceTrainerConfig(BaseModel):
    huggingface_config: HuggingFaceTrainingConfig
    data: S3Data

    # pytorch_version: (compatible transformers_versions,)
    # ClassVar keeps this lookup table out of the model's fields. Note the
    # trailing comma: without it, ("4.26") is just the string "4.26".
    _valid_combinations: ClassVar[dict] = {"1.13": ("4.26",)}

    @validator("huggingface_config")
    def validate_transformers_pytorch_compatibility(
        cls, huggingface_config: HuggingFaceTrainingConfig
    ):
        """
        Verify that the PyTorch and Transformers versions are compatible with
        each other as per the SageMaker images. Also convert the enumerations
        to strings, because that is the format the HuggingFace library expects.
        """
        huggingface_config.pytorch_version = huggingface_config.pytorch_version.value
        huggingface_config.transformers_version = (
            huggingface_config.transformers_version.value
        )
        if huggingface_config.pytorch_version not in cls._valid_combinations:
            raise ValueError("PyTorch version not supported.")
        if (
            huggingface_config.transformers_version
            not in cls._valid_combinations[huggingface_config.pytorch_version]
        ):
            raise ValueError(
                "PyTorch Version and Transformers Version are incompatible."
            )
        return huggingface_config


class XGBRankerTrainerConfig(BaseModel):
    xgb_ranker_training_config: XGBRankerTrainingConfig
    data: S3Data
```
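For illustration, instantiating these models might look something like this. All values below are invented placeholders (role ARN, bucket, hyperparameters), and the snippet assumes the classes above are already in scope:

```python
from pydantic import ValidationError

# Made-up example values; assumes the model classes above are importable.
config = HuggingFaceTrainerConfig(
    huggingface_config=HuggingFaceTrainingConfig(
        instance_type="ml.g5.xlarge",
        instance_count=1,
        role="arn:aws:iam::123456789012:role/example-role",  # placeholder ARN
        entry_point="train.py",
        source_dir="src",
        base_job_name="bi-head-train-predict",
        transformers_version=TransformersVersion.V4_26,
        pytorch_version=PyTorchVersion.V1_13,
        py_version="py39",
        hyperparameters=HuggingFaceHyperParameters(
            epochs=3,
            model_name="distilbert-base-uncased",
            fp16=True,
            warmup_steps=100,
            learning_rate=5e-5,
            train_batch_size=16,
            eval_batch_size=32,
        ),
    ),
    data=S3Data(
        bucket_name="example-bucket",
        s3_prefix="experiments/bi-head",
        file_paths=["train", "test"],
    ),
)

# The validator has already converted the enums to plain strings.
print(config.huggingface_config.pytorch_version)  # "1.13"

# to_dict() strips the Optional fields that were left as None
# (use_spot, max_wait, max_run, checkpoint_*).
print(config.huggingface_config.to_dict())

# An invalid job name fails fast with a ValidationError.
try:
    SagemakerTrainingConfig(
        instance_type="ml.g5.xlarge",
        instance_count=1,
        role="arn:aws:iam::123456789012:role/example-role",
        entry_point="train.py",
        source_dir="src",
        base_job_name="bi_head_train_predict",  # underscores are rejected
    )
except ValidationError as exc:
    print(exc)
```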
@ayushxx7 Thank you for your comment! Given how much faster v2 is than v1 for such a simple task, I'd expect the same to hold for your more complex example.
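One rough way to check would be to time construction of your model under each major version with `timeit`. This is only a sketch, not part of the benchmark in the post: the payload is invented, and under v2 the `@validator` usage would need adapting (e.g. to `@field_validator`):

```python
import timeit

# Hypothetical benchmark sketch: run once in a pydantic v1 environment and
# once in a v2 environment, then compare the per-call timings.
payload = {
    "huggingface_config": {
        "instance_type": "ml.g5.xlarge",
        "instance_count": 1,
        "role": "arn:aws:iam::123456789012:role/example-role",  # placeholder
        "entry_point": "train.py",
        "source_dir": "src",
        "transformers_version": "4.26",
        "pytorch_version": "1.13",
        "py_version": "py39",
        "hyperparameters": {
            "epochs": 3,
            "model_name": "distilbert-base-uncased",
            "fp16": True,
            "warmup_steps": 100,
            "learning_rate": 5e-5,
            "train_batch_size": 16,
            "eval_batch_size": 32,
        },
    },
    "data": {
        "bucket_name": "example-bucket",
        "s3_prefix": "experiments/bi-head",
        "file_paths": ["train", "test"],
    },
}

n = 10_000
seconds = timeit.timeit(lambda: HuggingFaceTrainerConfig(**payload), number=n)
print(f"{seconds / n * 1e6:.1f} µs per validation")
```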