arviz-devs / arviz-base

Base ArviZ features and converters
https://arviz-base.readthedocs.io/
Apache License 2.0

validation module #21

Open amaloney opened 2 months ago

amaloney commented 2 months ago

See https://github.com/arviz-devs/arviz-plots/issues/83 for more info. The goal is to create modules to verify inputs from users.

OriolAbril commented 2 months ago

The first step might be porting/adapting https://github.com/arviz-devs/arviz-stats/blob/main/src/arviz_stats/validate.py. There are also validator functions in https://github.com/arviz-devs/arviz-base/blob/main/src/arviz_base/rcparams.py. They are triggered every time you set the value of an rcParam, and they can probably help with some of the validation steps.
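Here is a toy sketch of a mapping that runs a validator whenever a value is assigned, to make the rcParams idea concrete. It is not arviz_base's actual `rcParams` implementation. `ValidatedParams`, `validate_probability` and the `"example.prob"` key are hypothetical names used only for illustration.

```python
from collections.abc import MutableMapping


def validate_probability(value):
    """Hypothetical validator: coerce to float and require 0 <= value <= 1."""
    value = float(value)
    if not 0 <= value <= 1:
        raise ValueError(f"{value!r} is not a probability in [0, 1].")
    return value


class ValidatedParams(MutableMapping):
    """Toy config mapping that validates every assignment (hypothetical)."""

    def __init__(self, defaults):
        # defaults maps each key to a (default value, validator function) pair
        self._validators = {key: validator for key, (_, validator) in defaults.items()}
        self._data = {}
        for key, (value, _) in defaults.items():
            self[key] = value  # defaults are validated too

    def __setitem__(self, key, value):
        if key not in self._validators:
            raise KeyError(f"{key!r} is not a known parameter.")
        # The validator either returns the (possibly coerced) value or raises
        self._data[key] = self._validators[key](value)

    def __getitem__(self, key):
        return self._data[key]

    def __delitem__(self, key):
        raise TypeError("Parameters cannot be deleted.")

    def __iter__(self):
        return iter(self._data)

    def __len__(self):
        return len(self._data)


params = ValidatedParams({"example.prob": (0.94, validate_probability)})
params["example.prob"] = "0.89"  # coerced to the float 0.89
# params["example.prob"] = 2 would raise ValueError and keep the old value
```

The benefit of this design is that every write goes through one place. Invalid values are rejected at assignment time, not when some function later reads them.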

amaloney commented 2 months ago

Do we want to use a tool for this? Perhaps something like pydantic?

OriolAbril commented 2 months ago

Probably not, unless it is something extremely lightweight and well supported on all platforms. I think pydantic would fit the bill, but I have not used it yet, so I'm not sure.

amaloney commented 2 months ago

I think I can investigate its use here. I'll report back with more info.

amaloney commented 1 month ago

I have given this a lot of thought, and we can use pydantic, but we can also "roll our own". For example:

```python
from dataclasses import dataclass


# Base class that runs `validate_<field>` methods when a dataclass is initialized.
class Validate:
    def __post_init__(self):
        for name, field in self.__dataclass_fields__.items():
            if (method := getattr(self, f"validate_{name}", None)):
                setattr(self, name, method(getattr(self, name), field=field))


# Create a "validator" for a specific input.
@dataclass
class HDIProbability(Validate):
    hdi: float

    def validate_hdi(self, value, **_) -> float:
        if not isinstance(value, (float, int)):
            value = float(value)  # raises ValueError for non-numeric strings
        if value < 0:
            raise ValueError("hdi must be greater than 0.")
        if value > 1:
            raise ValueError("hdi must be less than 1.")
        return float(value)


# Use the validation in a function.
def probability(hdi: float):
    validator = HDIProbability(hdi)
    hdi = validator.hdi
    return hdi


# Examples
>>> probability(hdi=0.5)
0.5
>>> probability(hdi="1")
1.0
>>> probability(hdi=2)
ValueError: hdi must be less than 1.
>>> probability(hdi=-1)
ValueError: hdi must be greater than 0.
>>> probability(hdi="zort")
ValueError: could not convert string to float: 'zort'
```

We can use this pattern to validate all the inputs of a function. As an example, the signature

```python
def psense(
    dt,
    group="prior",
    sample_dims=None,
    group_var_names=None,
    group_coords=None,
    var_names=None,
    coords=None,
    filter_vars=None,
    delta=0.01,
):
    ...
```

becomes

```python
from dataclasses import dataclass
from typing import Literal

import xarray as xr


@dataclass
class ValidatePSense(Validate):
    dt: xr.DataTree
    group: Literal["prior", "likelihood"] = "prior"
    sample_dims: list[str] | None = None
    group_var_names: str | None = None
    group_coords: dict | None = None
    var_names: list[str] | None = None
    coords: dict | None = None
    filter_vars: str | None = None
    delta: float = 0.01

    def validate_dt(self, value, **_) -> xr.DataTree:
        if not isinstance(value, xr.DataTree):
            raise TypeError("The given `dt` was not an xarray DataTree.")
        return value

    def validate_group(self, value, **_) -> str:
        if not isinstance(value, str):
            raise TypeError(...)
        if value not in ["prior", "likelihood"]:
            raise ValueError(...)
        return value

    ...


def psense(
    dt: xr.DataTree,
    group: Literal["prior", "likelihood"] = "prior",
    # ... remaining parameters
):
    valid = ValidatePSense(dt, group, ...)
    dt = valid.dt
    group = valid.group
    ...
```

We can construct separate validation objects, e.g. a single class for `group`. Inheritance can then combine them into a class that validates all the inputs of a single function.
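A minimal sketch of that composition follows. It assumes the `Validate` base class from the previous example, repeated here so the block runs on its own. `GroupInput`, `DeltaInput` and `PSenseInputs` are hypothetical names. Each small class validates one input, and the function-level class simply inherits from them.

```python
from dataclasses import dataclass


class Validate:
    """Run validate_<field> for every dataclass field that has one."""

    def __post_init__(self):
        for name, field in self.__dataclass_fields__.items():
            if method := getattr(self, f"validate_{name}", None):
                setattr(self, name, method(getattr(self, name), field=field))


@dataclass
class GroupInput(Validate):
    # Hypothetical reusable piece: validates only `group`.
    group: str = "prior"

    def validate_group(self, value, **_):
        if value not in ("prior", "likelihood"):
            raise ValueError("group must be 'prior' or 'likelihood'.")
        return value


@dataclass
class DeltaInput(Validate):
    # Hypothetical reusable piece: validates only `delta`.
    delta: float = 0.01

    def validate_delta(self, value, **_):
        value = float(value)
        if value >= 1:
            raise ValueError("delta must be less than 1.")
        return value


@dataclass
class PSenseInputs(GroupInput, DeltaInput):
    """Inherits both fields and both validate_* methods."""


inputs = PSenseInputs(group="likelihood", delta="0.05")  # delta coerced to 0.05
```

Dataclass inheritance collects fields in reverse MRO order, so `PSenseInputs` takes `delta` before `group` when arguments are passed positionally. Passing keywords avoids depending on that order.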

I'm not sure whether using pydantic would make this task easier or harder. I do know that pydantic is used by a lot of people, so using it would lower the barrier for other developers.

@OriolAbril @sethaxen @aloctavodia (and anyone else) add a thumbs up for using pydantic or a heart for rolling our own.

OriolAbril commented 1 month ago

Given how much overlap there is between different functions, I imagined something more along the lines of:

```python
# in the validate module
from collections.abc import Hashable, Sequence

from arviz_base.rcparams import rcParams


def sample_dims(sample_dims):
    if sample_dims is None:
        sample_dims = rcParams["data.sample_dims"]
    if isinstance(sample_dims, Hashable):
        return [sample_dims]
    if isinstance(sample_dims, Sequence):
        if not all(isinstance(dim, Hashable) for dim in sample_dims):
            raise TypeError("All sample_dims must be hashable dimension names.")
        return sample_dims
    raise TypeError("sample_dims must be a hashable or a sequence of hashables.")


# in the psense file
from collections.abc import Hashable, Sequence
from typing import Literal

import xarray as xr

from arviz_base import validate  # the proposed module


# this one is actually specific to psense
def validate_psense_group(group):
    if not isinstance(group, str):
        raise TypeError(...)
    if group not in ["prior", "likelihood"]:
        raise ValueError(...)
    return group


def psense(
    dt: xr.DataTree,
    group: Literal["likelihood", "prior"] = "prior",
    sample_dims: Hashable | Sequence[Hashable] | None = None,
    group_var_names=None,
    group_coords=None,
    var_names=None,
    coords=None,
    filter_vars=None,
    delta=0.01,
):
    dt = validate.dt(dt)
    group = validate_psense_group(group)
    sample_dims = validate.sample_dims(sample_dims)
```

I don't think I understand yet what the classes bring to the table. EDIT: I'm just trying to make sure I understand the options before voting. I'm not saying we need to do it the way shown above.
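One thing to watch in the `validate.sample_dims` sketch above: strings and tuples are both `Hashable`. A tuple such as `("chain", "draw")` would therefore be wrapped as a single dimension name, not treated as two. A possible variant checks for `str` first, assuming single dimension names are always strings. The `default` argument is a hypothetical stand-in for `rcParams["data.sample_dims"]` so the example runs without arviz_base.

```python
from collections.abc import Hashable, Sequence


def validate_sample_dims(sample_dims, default=("chain", "draw")):
    """Normalize sample_dims to a list of dimension names.

    `default` stands in for rcParams["data.sample_dims"] in this sketch.
    """
    if sample_dims is None:
        sample_dims = default
    # Check str before Sequence: a str is itself a Sequence of characters.
    if isinstance(sample_dims, str):
        return [sample_dims]
    if isinstance(sample_dims, Sequence):
        if not all(isinstance(dim, Hashable) for dim in sample_dims):
            raise TypeError("Every element of sample_dims must be hashable.")
        return list(sample_dims)
    raise TypeError(
        "sample_dims must be a str or a sequence of dimension names, "
        f"got {type(sample_dims).__name__}."
    )
```

The trade-off is that a single non-string hashable, e.g. an integer dimension name, is now rejected unless it is wrapped in a list.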

amaloney commented 1 month ago

It depends on what you want. I like what is shown above, and it definitely falls in the "roll our own" camp, with functions and no classes. It allows for reuse, since we can import the `sample_dims` function from the validate module and use it anywhere. In my opinion, it basically comes down to style. Do we want to compose validation objects (classes holding the types and validation methods for several inputs), or have a validation function for each input?

Ultimately they are going to do the same thing.

Edit: also remember that pydantic is used all over the place, and it is a class-based validation tool: https://docs.pydantic.dev/latest/#who-is-using-pydantic
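The sketch below illustrates that both styles can end up doing the same thing: the class-based validator just delegates to the same module-level functions, so each check is written only once. All names here (`validate_group`, `validate_delta`, `PSenseValidator`, `psense_functions`, `psense_classes`) are hypothetical.

```python
from dataclasses import dataclass


# Function-per-input style: one reusable validator per argument.
def validate_group(value):
    if not isinstance(value, str):
        raise TypeError("group must be a string.")
    if value not in ("prior", "likelihood"):
        raise ValueError("group must be 'prior' or 'likelihood'.")
    return value


def validate_delta(value):
    if value >= 1:
        raise ValueError("delta must be less than 1.")
    return value


# Class-based style: the dataclass only groups the inputs and delegates
# to the same functions.
@dataclass
class PSenseValidator:
    group: str = "prior"
    delta: float = 0.01

    def __post_init__(self):
        self.group = validate_group(self.group)
        self.delta = validate_delta(self.delta)


def psense_functions(group="prior", delta=0.01):
    return validate_group(group), validate_delta(delta)


def psense_classes(group="prior", delta=0.01):
    inputs = PSenseValidator(group, delta)
    return inputs.group, inputs.delta
```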

amaloney commented 1 month ago

Looks like Hypothesis allows you to use types as strategies when testing: https://stackoverflow.com/questions/70396266/how-to-generate-test-samples-with-hypothesis-directly-from-dataclasses

So a dataclass (which is a type) can be used to generate test inputs. I'm not 100% sure how to use it in this case. It seems like we can create a dataclass for the psense function and then reuse it as a strategy for testing inputs. https://hypothesis.readthedocs.io/en/latest/data.html#hypothesis.strategies.from_type
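The hypothetical helper `literal_cases` below is a stdlib-only illustration of why a dataclass's annotations are enough to drive test inputs. It enumerates every combination of a dataclass's `Literal`-annotated fields. Hypothesis's `from_type` does this kind of inference (and much more) for you; this sketch is not how Hypothesis works internally.

```python
import itertools
from dataclasses import dataclass, fields
from typing import Literal, get_args, get_origin, get_type_hints


@dataclass
class PSense:
    group: Literal["likelihood", "prior"] = "prior"
    delta: float = 0.01


def literal_cases(cls):
    """Yield one instance of `cls` per combination of its Literal-annotated fields.

    Fields without a Literal annotation keep their defaults.
    """
    hints = get_type_hints(cls)
    names, choices = [], []
    for field in fields(cls):
        hint = hints[field.name]
        if get_origin(hint) is Literal:
            names.append(field.name)
            choices.append(get_args(hint))
    for combo in itertools.product(*choices):
        yield cls(**dict(zip(names, combo)))


for case in literal_cases(PSense):
    print(case)
# PSense(group='likelihood', delta=0.01)
# PSense(group='prior', delta=0.01)
```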

amaloney commented 1 month ago

And yes, Hypothesis does make it easy to test with dataclass types.

```python
from dataclasses import dataclass
from typing import Literal

from hypothesis import given
from hypothesis.strategies import from_type


class Validate:
    def __post_init__(self):
        for name, field in self.__dataclass_fields__.items():
            if (method := getattr(self, f"validate_{name}", None)):
                setattr(self, name, method(getattr(self, name), field=field))


@dataclass
class PSense(Validate):
    group: Literal["likelihood", "prior"] = "prior"
    delta: float = 0.01

    def validate_group(self, value, **_) -> str:
        if value not in ["likelihood", "prior"]:
            raise ValueError("group must be one of ['likelihood', 'prior'].")
        return value

    def validate_delta(self, value, **_) -> float:
        if not isinstance(value, float):
            raise TypeError("delta must be a float.")
        if value >= 1:
            raise ValueError("delta must be less than 1.")
        return value


# Simple version
def psense(group: Literal["likelihood", "prior"] = "prior", delta: float = 0.01):
    pass


# Test
@given(inputs=from_type(PSense))
def test_psense(inputs):
    assert inputs.group in ["likelihood", "prior"]
    assert inputs.delta < 1
```

The above successfully found two issues. In the first, `delta=nan`, which is a float and so passes validation. In the second, a `delta` of 1 or greater is generated, which makes `PSense` raise the `ValueError` while Hypothesis is still building the inputs.

```
  | ExceptionGroup: Hypothesis found 2 distinct failures. (2 sub-exceptions)
  +-+---------------- 1 ----------------
    | AssertionError
    | Falsifying example: test_psense(
    |     inputs=PSense(group='prior', delta=nan),
    | )
    +---------------- 2 ----------------
    | ValueError: delta must be less than 1.
    | while generating 'inputs' from builds(PSense, delta=one_of(just(0.01), floats()),
    | group=one_of(just('prior'), sampled_from(['prior', 'likelihood'])))
    +------------------------------------
```

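The `nan` case slips through because every comparison involving `nan` is `False`, so `value >= 1` never triggers. Below is a sketch of a `delta` check that rejects `nan` explicitly. It keeps the original rule that `delta` must be a float less than 1. Writing it as a standalone function, rather than as a method on the dataclass above, is an assumption made for brevity.

```python
import math


def validate_delta(value):
    """Check that delta is a float, is not NaN, and is less than 1."""
    if not isinstance(value, float):
        raise TypeError("delta must be a float.")
    # nan >= 1 is False, so the range check alone would accept nan.
    if math.isnan(value):
        raise ValueError("delta must not be nan.")
    if value >= 1:
        raise ValueError("delta must be less than 1.")
    return value
```

The second failure can be handled in two ways. One is to constrain the strategy itself, e.g. `builds(PSense, delta=floats(max_value=1, exclude_max=True, allow_nan=False))` from `hypothesis.strategies`, so that only valid `delta` values are generated. Alternatively, the test can expect the `ValueError` for out-of-range inputs.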
amaloney commented 1 month ago

@OriolAbril I like your ideas a great deal, so I expanded on them. Your idea of creating reusable validation functions is great, so I kept it. I also extended it with a decorator we can apply to the ArviZ functions whose inputs we want to validate. An example decorator is shown below.


```python
import inspect
from collections.abc import Hashable, Iterable, Sequence
from functools import wraps

from arviz_base.rcparams import rcParams


# Validation function
def validate_sample_dims(value, func_name):
    if value is None:
        value = rcParams["data.sample_dims"]
    if not isinstance(value, Iterable):
        raise TypeError(f"sample_dims of {func_name!r} must be the names of dimensions.")
    if not all(isinstance(v, Hashable) for v in value):
        raise TypeError(
            f"sample_dims of {func_name!r} must be names of dimensions, or tuples of dimensions."
        )
    if isinstance(value, list):
        if not all(isinstance(v, str) for v in value):
            raise TypeError(f"sample_dims of {func_name!r} must be names of dimensions.")
    if not isinstance(value, list):
        value = [value]
    return value


# Decorator for psense (just an example)
def validate_psense(func):
    signature = inspect.signature(func)

    @wraps(func)
    def wrapper(*args, **kwargs):
        # Bind positional and keyword arguments to parameter names so that
        # sample_dims is found however it was passed, then fill in defaults.
        bound = signature.bind(*args, **kwargs)
        bound.apply_defaults()
        bound.arguments["sample_dims"] = validate_sample_dims(
            bound.arguments["sample_dims"], func.__name__
        )
        return func(*bound.args, **bound.kwargs)

    return wrapper


# Decorated function
@validate_psense
def psense(random, sample_dims: Hashable | Sequence[Hashable] | None = None):
    print(sample_dims)


# Usage
>>> psense(1, sample_dims=1)
TypeError: sample_dims of 'psense' must be the names of dimensions.

>>> psense(1, sample_dims="1")  # maybe there is a dimension named 1?
['1']

>>> psense(1, sample_dims="chain")
['chain']

>>> psense(1, sample_dims=None)
[('chain', 'draw')]

>>> psense(1)
[('chain', 'draw')]
```

I still like the idea of using dataclasses for testing purposes, since we can create the class and give it to Hypothesis in a straightforward manner. That's for a different discussion though.
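The `validate_psense` decorator above is specific to one function. Below is a possible generalization, the hypothetical `validate_inputs` decorator factory. It takes a mapping from parameter names to validator functions and can wrap any function. It follows the `(value, func_name)` validator convention used above and relies on `inspect.Signature.bind`, so arguments are validated whether they are passed positionally or by keyword.

```python
import inspect
from functools import wraps


def validate_inputs(**validators):
    """Decorator factory: run validators[name](value, func_name) on each named argument."""

    def decorator(func):
        signature = inspect.signature(func)
        # Fail early, at decoration time, if a validator names a missing parameter.
        unknown = set(validators) - set(signature.parameters)
        if unknown:
            raise TypeError(f"{func.__name__!r} has no parameters named {sorted(unknown)}.")

        @wraps(func)
        def wrapper(*args, **kwargs):
            bound = signature.bind(*args, **kwargs)
            bound.apply_defaults()  # so defaults such as None are validated too
            for name, validator in validators.items():
                bound.arguments[name] = validator(bound.arguments[name], func.__name__)
            return func(*bound.args, **bound.kwargs)

        return wrapper

    return decorator


def validate_delta(value, func_name):
    # Hypothetical validator following the (value, func_name) convention.
    value = float(value)
    if value >= 1:
        raise ValueError(f"delta of {func_name!r} must be less than 1.")
    return value


@validate_inputs(delta=validate_delta)
def psense(dt, delta=0.01):
    return delta
```

With this shape, each ArviZ function would only declare which shared validators apply to which of its parameters.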