cmu-delphi / epidatpy

Delphi Epidata API Python Client
MIT License
0 stars · 1 fork · source link

improve endpoint argument validation: use Pydantic #32

Status: open. dshemetov opened this issue 2 months ago.

dshemetov commented 2 months ago

We can remove a lot of boilerplate argument-validation code by using Pydantic's `validate_call` and type hints. It is fast and produces good default error messages. See the proof-of-concept script below.

# Testing Pydantic for input and response validation.
#
# My profiling results below indicate that Pydantic works well for validating
# function input arguments and small JSON responses. Plain Pandas may be better
# for larger JSON responses, especially if we expect to eventually output a
# DataFrame anyway.
#
# So my recommendation is to use `validate_call` for endpoint function arguments.
#

import cProfile
import datetime
from typing import List, Literal, Optional, Union

import pandas as pd
import requests
from epiweeks import Week
from pydantic import (
    BaseModel,
    ConfigDict,
    PositiveFloat,
    ValidationError,
    condate,
    field_validator,
    validate_call,
)

# Accepted values for the covidcast ``geo_type`` and ``time_type`` arguments.
GeoType = Literal[
    "nation",
    "msa",
    "hrr",
    "hhs",
    "state",
    "county",
]
TimeType = Literal["day", "week"]

# Anything accepted where an epidate is expected: an int, a str, a
# ``datetime.date`` strictly after this cutoff, or an epiweeks ``Week``.
# ``Week`` is not Pydantic-native, so users of this alias must enable
# ``arbitrary_types_allowed``.
_EARLIEST_DATE = datetime.date(1990, 1, 1)
EpiDateLike = Union[int, str, condate(gt=_EARLIEST_DATE), Week]

# The default error format takes some getting used to: each validation error is
# located by the argument's positional index (e.g. ``1``) or keyword name
# (e.g. ``c``).
# https://docs.pydantic.dev/2.8/errors/errors/
@validate_call(config=ConfigDict(arbitrary_types_allowed=True))
def test_function(a: int, b: PositiveFloat, c: GeoType, d: EpiDateLike) -> str:
    """Return the validated arguments formatted as ``"<a + b> <c> <d>"``.

    ``arbitrary_types_allowed`` is required because ``EpiDateLike`` includes
    the non-Pydantic ``Week`` type.
    """
    # The return annotation must match the f-string. ``validate_call`` only
    # skips return validation by default (``validate_return=False``), so a
    # wrong annotation here would break as soon as that is enabled.
    return f"{a + b} {c} {d}"

def _show_validation_error(func, *args, **kwargs):
    """Call ``func`` and print the ``ValidationError`` it raises, if any.

    Without this wrapper the first failing demo call would abort the script,
    so none of the later demonstrations would run.
    """
    try:
        func(*args, **kwargs)
    except ValidationError as err:
        print(err)
    else:
        print(f"{func.__name__} accepted args={args!r} kwargs={kwargs!r}")


# Casts the first argument to int and errors on the next 3.
_show_validation_error(test_function, 5.0, -5, c="hey", d=datetime.date(1989, 4, 5))
# Casts the first argument to int and errors on the next 2.
_show_validation_error(test_function, 5.0, -5, c="hey", d=19890405)
# Casts the first argument to int and errors on the next 2.
_show_validation_error(test_function, 5.0, -5, c="hey", d=Week(1989, 14))

# Mutually exclusive arguments require some extra work.
# https://stackoverflow.com/a/72087084 (written for Pydantic v1; adapted here to
# the v2 ``field_validator`` API, which has no ``always=`` option and passes a
# ``ValidationInfo`` instead of a ``values`` dict).
class MyModel(BaseModel):
    """Toy model in which at most one of ``a`` and ``b`` may be set."""

    # Pydantic v2 no longer gives ``Optional`` fields an implicit ``None``
    # default, so it is spelled out to keep both fields optional.
    a: Optional[str] = None
    b: Optional[str] = None

    @field_validator("b")
    @classmethod
    def mutually_exclusive(cls, v, info):
        """Reject ``b`` when ``a`` is also set.

        Fields validate in declaration order, so ``a`` is already available on
        ``info.data`` when ``b`` is checked.

        Raises:
            ValueError: if both ``a`` and ``b`` are set. Pydantic reports this
                as a ``ValidationError``.
        """
        # ``.get``: ``a`` is missing from ``info.data`` if it failed validation.
        if info.data.get("a") is not None and v is not None:
            raise ValueError("'a' and 'b' are mutually exclusive.")

        return v

# You can create a model for a JSON row and validate it.
class Covidcast(BaseModel):
    """One row of the covidcast endpoint's ``epidata`` array.

    ``value``, ``stderr`` and ``sample_size`` each have a ``missing_*`` code
    explaining why they may be absent. They are therefore nullable, like
    ``direction``, which is null in the sample row.
    """

    # ``EpiDateLike`` includes the non-Pydantic ``Week`` type.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # NOTE: field order determines ``model_dump`` key order and hence the
    # DataFrame column order in the profiling section.
    source: str
    signal: str
    geo_type: GeoType
    geo_value: str
    time_type: TimeType
    time_value: EpiDateLike
    issue: EpiDateLike
    lag: int
    value: Optional[float]
    stderr: Optional[float]
    # The sample row sends 244046.0; Pydantic's lax mode accepts integral floats
    # for ``int`` fields.
    sample_size: Optional[int]
    direction: Optional[float]
    missing_value: int
    missing_stderr: int
    missing_sample_size: int

# A single covidcast row exactly as the API serializes it. The adjacent string
# literals below concatenate into one JSON document.
row = (
    '{"geo_value":"us","signal":"smoothed_cli","source":"fb-survey",'
    '"geo_type":"nation","time_type":"day","time_value":20210405,'
    '"direction":null,"issue":20210410,"lag":5,"missing_value":0,'
    '"missing_stderr":0,"missing_sample_size":0,"value":0.6758323,'
    '"stderr":0.0148258,"sample_size":244046.0}'
)
vrow = Covidcast.model_validate_json(row)

# You can create a model for the whole JSON response, consisting of rows above and validate them all.
class Response(BaseModel):
    """Envelope of an Epidata API response.

    Validating it also validates every element of ``epidata`` as a
    ``Covidcast`` row.
    """

    # NOTE(review): presumably a numeric status code; str is also accepted here, confirm why.
    result: Union[str, int]
    message: str
    epidata: List[Covidcast]

# Fetch a small real response (at most six daily rows) and validate it.
# ``params`` lets requests build and encode the query string. The timeout keeps
# the script from hanging on a stalled connection. ``raise_for_status`` surfaces
# HTTP errors directly instead of as a confusing validation error on an error
# page body.
data = requests.get(
    "https://api.delphi.cmu.edu/epidata/covidcast/",
    params={
        "data_source": "fb-survey",
        "signals": "smoothed_cli",
        "time_type": "day",
        "time_values": "20210405-20210410",
        "geo_type": "nation",
        "geo_values": "us",
    },
    timeout=30,
)
data.raise_for_status()
vdata = Response.model_validate_json(data.text)

# Profiling constructing Pandas DataFrames from the validated JSON data.
# ``model_copy`` gives an independent response object. Plain assignment
# (``vdata2 = vdata``) would alias ``vdata`` and inflate it too.
vdata2 = vdata.model_copy(update={"epidata": vdata.epidata * 10**5})

# ``cProfile.run`` always evaluates its statement in ``__main__``'s namespace.
# ``runctx`` with this module's namespace also works when the file is imported
# or run some other way.
cProfile.runctx(
    "pd.DataFrame([s.model_dump() for s in vdata2.epidata])", globals(), locals()
)  # around 3.4s
cProfile.runctx(
    "pd.DataFrame.from_records(vdata2.model_dump()['epidata'])", globals(), locals()
)  # 0.62s