Open WardBrian opened 8 months ago
Right now an argument is specified as UnsizedType.autodifftype * UnsizedType.t.
I think we could define another type which looks like:
type argtype =
| Concrete of UnsizedType.t
| Predicate of (UnsizedType.t -> bool) (* think of this like a C++ concept *)
We'd need to rewrite a fair amount of stuff to handle this but I don't think it would be too bad. The argument type of size would then be something like:
Predicate (function UArray _ -> true | _ -> false)
Unfortunately this becomes very difficult to print. We could have Predicate hold both a UnsizedType.t -> bool and a string which we use as its pretty-printed name, so like Predicate ((function UArray _ -> true | _ -> false), "array[] T")
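To make the idea concrete, here is a minimal self-contained sketch of a Predicate that carries its own display name. The `unsized` type is a cut-down stand-in for UnsizedType.t, and `accepts`, `show_unsized`, `show_argtype` and `any_array` are hypothetical names used only for illustration, not anything that exists in stanc3.

```ocaml
(* Simplified stand-in for UnsizedType.t; the real type has more cases. *)
type unsized =
  | UInt
  | UReal
  | UVector
  | UArray of unsized

type argtype =
  | Concrete of unsized
  | Predicate of (unsized -> bool) * string (* check, pretty-printed name *)

(* Does an actual argument type satisfy the declared argument type? *)
let accepts (arg : argtype) (actual : unsized) : bool =
  match arg with
  | Concrete t -> t = actual
  | Predicate (check, _) -> check actual

(* Naive printer for the stand-in type (not Stan's real array syntax). *)
let rec show_unsized = function
  | UInt -> "int"
  | UReal -> "real"
  | UVector -> "vector"
  | UArray t -> "array[] " ^ show_unsized t

(* A predicate cannot be inspected, so we print the name stored next to it. *)
let show_argtype = function
  | Concrete t -> show_unsized t
  | Predicate (_, name) -> name

(* The argument type proposed for size: any array at all. *)
let any_array =
  Predicate ((function UArray _ -> true | _ -> false), "array[] T")
```

The string is only ever used for error messages and signature listings, so nothing checks that it agrees with the predicate; that is the price of being able to print it at all.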
I think we could do this with an API like
(* Function that depends on two types *)
let matching_array_types x y =
  match x, y with
  (* structural equality; == would compare physical identity *)
  | UnsizedType.UArray _, UnsizedType.UArray _ when x = y -> Valid
  | _ -> Invalid
Fun ("append_array", matching_array_types)
(* Anything is valid *)
let all_valid _ = Valid
Fun ("size", all_valid)
(* For the distributions *)
let is_vectorizable x =
  match x with
  (* arrays of real or int, plus the scalar and vector types *)
  | UnsizedType.UArray (UReal | UInt) | UVector | URowVector | UReal | UInt -> true
  | _ -> false
let is_all_vectorizable (x, y, z) =
  if is_vectorizable x && is_vectorizable y && is_vectorizable z then Valid
  else Invalid
Fun ("normal_lpdf", is_all_vectorizable)
Some further thoughts:
For Stan_math_signatures.ml, we could update the types used to be
type signature =
  UnsizedType.returntype * UnsizedType.argumentlist * Mem_pattern.t
type predicate_signature =
  | Concrete of signature list
  | Calculated of
      ( UnsizedType.argumentlist
        -> ( UnsizedType.returntype * Mem_pattern.t
           , function_mismatch ) result) * string
where function_mismatch is the type currently used in SignatureMismatch.ml
Our hashmap then stores a single predicate_signature, rather than the current signature list.
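A rough sketch of what a lookup against such a table might look like. This is a simplification under several assumptions: Mem_pattern.t and autodiff types are dropped, `function_mismatch` is replaced by a one-constructor placeholder, and `table` and `check` are hypothetical names rather than existing stanc3 functions.

```ocaml
(* Simplified stand-ins for the UnsizedType types. *)
type unsized = UInt | UReal | UVector | UArray of unsized
type returntype = Void | ReturnType of unsized

(* Placeholder for the mismatch type used in SignatureMismatch.ml. *)
type function_mismatch = ArgError of string

type signature = returntype * unsized list

type predicate_signature =
  | Concrete of signature list
  | Calculated of
      (unsized list -> (returntype, function_mismatch) result) * string

(* One entry per function name instead of a list of signatures. *)
let table : (string, predicate_signature) Hashtbl.t = Hashtbl.create 16

let () =
  Hashtbl.replace table "sqrt" (Concrete [ (ReturnType UReal, [ UReal ]) ]);
  Hashtbl.replace table "size"
    (Calculated
       ( (function
         | [ UArray _ ] -> Ok (ReturnType UInt)
         | _ -> Error (ArgError "expected a single array argument"))
       , "(array[] T) => int" ))

(* Resolve a call: enumerate for Concrete, run the function for Calculated. *)
let check (name : string) (args : unsized list) =
  match Hashtbl.find_opt table name with
  | None -> Error (ArgError ("unknown function " ^ name))
  | Some (Concrete sigs) -> (
      match List.find_opt (fun (_, formals) -> formals = args) sigs with
      | Some (rt, _) -> Ok rt
      | None -> Error (ArgError "no matching signature"))
  | Some (Calculated (f, _)) -> f args
```

Both branches return the same result type, so callers would not need to know whether a given built-in is enumerated or calculated.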
The bigger issue is actually handling typechecking, especially in Environment.ml.
Calculated function signatures have no type (in the sense that they cannot be represented by a single UnsizedType.t). So, the best option I can see would involve there being two "Environments", one for the library and one for everything user defined. I've seen these called the "static" environment, for things built into the language/library, and "dynamic" environment, for things users define.
The "static" environment would essentially just be the Stan_math_signatures mapping, and the dynamic environment could stay as currently set up in Environment.ml. Things like overloading become a bit more complicated, since it is now a two-step process: Look up in the static environment, and only if that fails (either because it is missing or had the wrong type in the static environment), look it up in the dynamic environment.
Detecting if a new function is a valid overload is also a bit trickier here: we can no longer just make sure it isn't already in the big list. The best solution I can think of is to try to pass the newly-user-declared arguments to the static environment, and if and only if it fails does that mean this overload is valid. I'm not sure if this is entirely sound, and getting a good error message out may be tricky.
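The two-step lookup and the overload check could be sketched roughly like this. Everything here is a hypothetical simplification: the static environment is an association list of check functions, the dynamic environment is a plain list of user signatures, and `lookup_static`, `lookup_dynamic`, `lookup` and `add_overload` are made-up names. Duplicate detection among user-defined functions is left out.

```ocaml
type unsized = UInt | UReal | UVector | UArray of unsized

(* Static environment: each built-in maps argument types to an optional
   return type (stand-in for the Stan_math_signatures mapping). *)
let static_env : (string * (unsized list -> unsized option)) list =
  [ ("size", function [ UArray _ ] -> Some UInt | _ -> None) ]

(* Dynamic environment: user-defined signatures as (name, (args, return)). *)
type dynamic_env = (string * (unsized list * unsized)) list

let lookup_static name args =
  match List.assoc_opt name static_env with
  | Some check -> check args
  | None -> None

let lookup_dynamic (env : dynamic_env) name args =
  List.find_map
    (fun (n, (formals, ret)) ->
      if n = name && formals = args then Some ret else None)
    env

(* Step 1: static. Step 2: dynamic, only if the static lookup failed,
   whether because the name is missing or the arguments were rejected. *)
let lookup (env : dynamic_env) name args =
  match lookup_static name args with
  | Some ret -> Some ret
  | None -> lookup_dynamic env name args

(* A user declaration is a valid overload only if the static environment
   rejects the declared argument types. *)
let add_overload (env : dynamic_env) name formals ret =
  match lookup_static name formals with
  | Some _ -> Error ("declaration of " ^ name ^ " conflicts with a built-in")
  | None -> Ok ((name, (formals, ret)) :: env)
```

With this shape the error for a rejected overload can only say that the built-in accepted the arguments, not which signature matched, which is the error-message difficulty mentioned above.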
This would allow array functions such as size to accept arrays of tuples, and would also allow vectorization of things like a 7-parameter function, whose signatures would currently be too expensive to enumerate.
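To give a feel for the enumeration cost, here is a small sketch comparing the two approaches. It assumes, purely for illustration, that each argument may be one of six vectorizable types; the real set and the return-type rules in Stan are more involved, and `enumerate` and `accepts_n` are hypothetical helpers.

```ocaml
type unsized = UInt | UReal | UVector | URowVector | UArray of unsized

(* Assumed set of vectorizable argument types for this sketch. *)
let vectorizable_types =
  [ UInt; UReal; UVector; URowVector; UArray UInt; UArray UReal ]

let is_vectorizable t = List.mem t vectorizable_types

(* Enumerated approach: every combination of argument types for n
   arguments. With six choices per argument this is 6^n lists, so
   6^7 = 279936 for a 7-parameter function. *)
let rec enumerate n =
  if n = 0 then [ [] ]
  else
    List.concat_map
      (fun rest -> List.map (fun t -> t :: rest) vectorizable_types)
      (enumerate (n - 1))

(* Calculated approach: one check, regardless of the number of arguments. *)
let accepts_n n args =
  List.length args = n && List.for_all is_vectorizable args
```

The predicate answers the same question as membership in the enumerated list without ever building it.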