openforcefield / openff-models

Helper classes for Pydantic compatibility in the OpenFF stack
MIT License

Support arrays with type checking #7

Open lilyminium opened 2 years ago

lilyminium commented 2 years ago

Thanks for starting this up @mattwthompson, looking forward to just using openff.models in everything!

It would be nice to be able to specify array typing too (and possibly shape checking). I implemented some array dtyping once, but never got round to doing shapes.

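from typing import Any

import numpy as np
# `unit` must exist at class-definition time because Array subclasses unit.Quantity
from openff.units import unit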

class ArrayMeta(type):
    def __getitem__(cls, T):
        return type("Array", (Array,), {"__dtype__": T})

class Array(np.ndarray, unit.Quantity, metaclass=ArrayMeta):
    """A typeable numpy array"""

    @classmethod
    def __get_validators__(cls):
        yield cls.validate_type

    @classmethod
    def validate_type(cls, val):
        from openff.units import unit
        from openff.units.units import Unit

        dtype = getattr(cls, "__dtype__", Any)
        if dtype is Any:
            dtype = None

        if isinstance(dtype, Unit):
            # assign units
            val = unit.Quantity(val, dtype)
            # coerce into np.ndarray
            val = unit.Quantity.from_list(val)
            return val
        return np.asanyarray(val, dtype=dtype)

In practice, this looks like:

In:

from pydantic import BaseModel

class Model(BaseModel):
    a: Array[unit.kelvin]
    b: Array[float]
    c: Array[int]

    class Config:
        arbitrary_types_allowed = True
        validate_assignment = True

In:

int_array = np.arange(3).astype(int)
x = Model(a=int_array, b=int_array, c=int_array)
print(x.a)
print(x.b)
print(x.c)

Out:

[0.0 1.0 2.0] kelvin
[0. 1. 2.]
[0 1 2]

And an error when assigning a scalar quantity to an array field:

In:

x.a = 3 * unit.kelvin

Out:

ValidationError                           Traceback (most recent call last)
Input In [130], in <cell line: 1>()
----> 1 x.a = 3 * unit.kelvin

File ~/anaconda3/envs/gnn-charge-models-test/lib/python3.9/site-packages/pydantic/main.py:380, in pydantic.main.BaseModel.__setattr__()

ValidationError: 1 validation error for Model
a
  object of type 'int' has no len() (type=type_error)
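
The message suggests that the scalar magnitude reaches a `len()` call somewhere in the coercion step. One possible way around this, sketched here with plain NumPy only, is to promote scalars to one-element arrays before any further coercion. The helper name `as_array_magnitude` is hypothetical and not part of openff-models, and whether a scalar should be promoted or rejected outright is a design choice this sketch does not settle.

```python
import numpy as np


def as_array_magnitude(val, dtype=None):
    """Hypothetical helper: coerce ``val`` to an array with at least one dimension."""
    # np.atleast_1d turns a scalar such as 3 into array([3]) and leaves
    # inputs that already have one or more dimensions unchanged.
    return np.atleast_1d(np.asarray(val, dtype=dtype))


scalar = as_array_magnitude(3, dtype=float)
vector = as_array_magnitude([0, 1, 2], dtype=float)
print(scalar.shape, vector.shape)
```

The units would still need to be re-attached afterwards; that step is left out here.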

Currently it strips units, if there are any, when the field has a plain dtype:

In:

x.b = 3 * unit.kelvin

Out:

/var/folders/rv/j6lbln6j0kvb5svxj8wflc400000gn/T/ipykernel_18158/1139253006.py:32: UnitStrippedWarning: The unit of the quantity is stripped when downcasting to ndarray.
  return np.asanyarray(val, dtype=dtype)
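
If silently dropping units is not wanted, the plain-dtype branch could refuse unit-bearing input instead. A minimal sketch, assuming quantities can be recognised by duck typing on the `units` and `magnitude` attributes that pint-style quantities expose. `FakeQuantity` is a stand-in used only to keep the example free of dependencies, and `validate_unitless` is a hypothetical name.

```python
from dataclasses import dataclass

import numpy as np


@dataclass
class FakeQuantity:
    """Stand-in for a pint-style quantity (illustration only)."""

    magnitude: object
    units: str


def validate_unitless(val, dtype=None):
    """Hypothetical validator: reject unit-bearing input instead of stripping it."""
    # Duck-typed check (an assumption): anything with both attributes is
    # treated as a quantity and refused.
    if hasattr(val, "units") and hasattr(val, "magnitude"):
        raise TypeError(
            f"expected a unitless array, got a quantity with units {val.units!r}"
        )
    return np.asanyarray(val, dtype=dtype)


print(validate_unitless([0, 1, 2], dtype=float))

try:
    validate_unitless(FakeQuantity(magnitude=3, units="kelvin"), dtype=float)
except TypeError as err:
    print(f"rejected: {err}")
```

Raising `TypeError` keeps this usable as a pydantic v1 validator, since pydantic v1 reports a `TypeError` or `ValueError` raised inside a validator as a `ValidationError`.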

And unit (dimensionality) checking still runs for unit-typed fields:

In:

x.a = int_array * unit.m

Out:

ValidationError: 1 validation error for Model
a
  Cannot convert from 'meter' ([length]) to 'kelvin' ([temperature]) (type=type_error.dimensionality; units1=meter; units2=kelvin; dim1=[length]; dim2=[temperature]; extra_msg=)
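
For the shape checking mentioned at the top, the same metaclass pattern could carry a shape alongside the dtype. This is only a rough sketch under a few assumptions: the names `ShapedArrayMeta` and `ShapedArray` are hypothetical, `None` in the shape is taken to mean "any length along this axis", units are left out entirely, and the class is exercised by calling the validator directly rather than through pydantic.

```python
import numpy as np


class ShapedArrayMeta(type):
    def __getitem__(cls, params):
        # Accept ShapedArray[dtype] or ShapedArray[dtype, shape].
        if not isinstance(params, tuple):
            params = (params, None)
        dtype, shape = params
        return ShapedArrayMeta(
            "ShapedArray", (cls,), {"__dtype__": dtype, "__shape__": shape}
        )


class ShapedArray(metaclass=ShapedArrayMeta):
    """Hypothetical array type carrying a dtype and an optional shape."""

    __dtype__ = None
    __shape__ = None

    @classmethod
    def __get_validators__(cls):
        yield cls.validate_type

    @classmethod
    def validate_type(cls, val):
        arr = np.asanyarray(val, dtype=cls.__dtype__)
        expected = cls.__shape__
        if expected is None:
            return arr
        # The number of dimensions must match; None matches any axis length.
        if arr.ndim != len(expected) or any(
            want is not None and want != got
            for want, got in zip(expected, arr.shape)
        ):
            raise ValueError(f"expected shape {expected}, got {arr.shape}")
        return arr


Coordinates = ShapedArray[float, (None, 3)]

good = Coordinates.validate_type([[0, 0, 0], [1, 1, 1]])
print(good.shape)  # (2, 3)

try:
    Coordinates.validate_type([0, 1, 2])
except ValueError as err:
    print(f"rejected: {err}")
```

Unlike the `Array` class above, this sketch does not subclass `np.ndarray`, so validated values come back as plain arrays.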