yihozhang closed this 10 months ago.
@oflatt This is a great suggestion. I tweaked the interface a little, and it now lets us introduce a new primitive as follows (see `src/lib.rs`):
```rust
// A user-defined primitive computing the inner product of two i64 vectors.
struct InnerProduct {
    ele: Arc<I64Sort>,
    vec: Arc<VecSort>,
}

impl PrimitiveLike for InnerProduct {
    fn name(&self) -> symbol_table::GlobalSymbol {
        "inner-product".into()
    }

    // Signature: two vector arguments, one i64 result.
    fn get_type_constraints(&self) -> Box<dyn crate::constraint::TypeConstraint> {
        SimpleTypeConstraint::new(
            self.name(),
            vec![self.vec.clone(), self.vec.clone(), self.ele.clone()],
        )
        .into_box()
    }

    fn apply(&self, values: &[crate::Value], _egraph: &EGraph) -> Option<crate::Value> {
        let mut sum = 0;
        // Decode both arguments into vectors of element values.
        let vec1 = Vec::<Value>::load(&self.vec, &values[0]);
        let vec2 = Vec::<Value>::load(&self.vec, &values[1]);
        assert_eq!(vec1.len(), vec2.len());
        for (a, b) in vec1.iter().zip(vec2.iter()) {
            let a = i64::load(&self.ele, a);
            let b = i64::load(&self.ele, b);
            sum += a * b;
        }
        // Encode the result back as an i64 value.
        sum.store(&self.ele)
    }
}

#[test]
fn test_user_defined_primitive() {
    let mut egraph = EGraph::default();
    // Declare the container sort first so the primitive can get a handle on it.
    egraph
        .parse_and_run_program(
            "
            (sort IntVec (Vec i64))
            ",
        )
        .unwrap();
    let i64_sort: Arc<I64Sort> = egraph.get_sort().unwrap();
    let int_vec_sort: Arc<VecSort> = egraph
        .get_sort_by(|s: &Arc<VecSort>| s.element_name() == i64_sort.name())
        .unwrap();
    egraph.add_primitive(InnerProduct {
        ele: i64_sort,
        vec: int_vec_sort,
    });
    egraph
        .parse_and_run_program(
            "
            (let a (vec-of 1 2 3 4 5 6))
            (let b (vec-of 6 5 4 3 2 1))
            (check (= (inner-product a b) 56))
            ",
        )
        .unwrap();
}
```
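As a quick sanity check on the value expected by the test, the standalone sketch below computes the same inner product over plain `i64` slices. `inner_product` is a hypothetical helper that does not touch egglog's `Value` encoding. Unlike the primitive above, it returns `None` on a length mismatch instead of panicking; since `apply` already returns an `Option`, that may be worth considering as an alternative to `assert_eq!`.

```rust
/// Hypothetical plain-Rust version of the `inner-product` arithmetic,
/// independent of egglog's value encoding.
fn inner_product(a: &[i64], b: &[i64]) -> Option<i64> {
    // Mismatched lengths yield None rather than panicking.
    if a.len() != b.len() {
        return None;
    }
    Some(a.iter().zip(b).map(|(x, y)| x * y).sum())
}

fn main() {
    let a = [1, 2, 3, 4, 5, 6];
    let b = [6, 5, 4, 3, 2, 1];
    // 1*6 + 2*5 + 3*4 + 4*3 + 5*2 + 6*1 = 56
    println!("{:?}", inner_product(&a, &b)); // prints "Some(56)"
}
```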
One caveat: for container sorts, you must instantiate the container sort (here, `(sort IntVec (Vec i64))`) before adding the primitive, because the primitive needs a handle on that sort.
Currently, it is impossible to register a primitive without also registering a user-defined sort.
This PR allows users to 1) register primitives directly and 2) obtain handles on registered sorts (for use in primitive definitions).
This is a breaking change because it changes the signature of `get_sort` in `TypeInfo` (cc @saulshanabrook).
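The sort-handle lookup used in the test (`get_sort` and `get_sort_by`) can be illustrated with a self-contained toy registry. Everything below (`Registry`, the simplified `Sort` trait, `as_arc_any`, and the toy `I64Sort`/`VecSort`) is a hypothetical stand-in, not egglog's actual implementation. It only shows one common way to recover a concrete `Arc<S>` from a list of trait objects via `Arc::downcast`, and why a lookup for the container sort fails until that sort has been registered.

```rust
use std::any::Any;
use std::sync::Arc;

// Hypothetical, simplified stand-in for a sort trait.
trait Sort: Any + Send + Sync {
    fn name(&self) -> String;
    // Upcast so the registry can downcast to a concrete sort type.
    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
}

struct I64Sort;

struct VecSort {
    sort_name: String,
    element: String,
}

impl VecSort {
    fn element_name(&self) -> &str {
        &self.element
    }
}

impl Sort for I64Sort {
    fn name(&self) -> String {
        "i64".to_string()
    }
    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }
}

impl Sort for VecSort {
    fn name(&self) -> String {
        self.sort_name.clone()
    }
    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }
}

#[derive(Default)]
struct Registry {
    sorts: Vec<Arc<dyn Sort>>,
}

impl Registry {
    fn add_sort(&mut self, sort: Arc<dyn Sort>) {
        self.sorts.push(sort);
    }

    /// First registered sort of concrete type `S` that satisfies `pred`.
    fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Option<Arc<S>> {
        self.sorts
            .iter()
            .filter_map(|s| s.clone().as_arc_any().downcast::<S>().ok())
            .find(|s| pred(s))
    }

    /// First registered sort of concrete type `S`.
    fn get_sort<S: Sort>(&self) -> Option<Arc<S>> {
        self.get_sort_by(|_: &Arc<S>| true)
    }
}

fn main() {
    let mut reg = Registry::default();
    reg.add_sort(Arc::new(I64Sort));

    // Before the container sort is declared, the lookup finds nothing.
    let before = reg.get_sort_by(|s: &Arc<VecSort>| s.element_name() == "i64");
    println!("before: {}", before.is_some());

    reg.add_sort(Arc::new(VecSort {
        sort_name: "IntVec".into(),
        element: "i64".into(),
    }));
    let i64_sort: Arc<I64Sort> = reg.get_sort().unwrap();
    let int_vec: Arc<VecSort> = reg
        .get_sort_by(|s: &Arc<VecSort>| s.element_name() == i64_sort.name())
        .unwrap();
    println!("found {}", int_vec.name());
}
```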