egraphs-good / egglog

egraphs + datalog!
https://egraphs-good.github.io/egglog/
MIT License

Expose interface to add primitives #280

Closed: yihozhang closed this pull request 10 months ago.

yihozhang commented 10 months ago

Right now it is impossible to register a primitive without registering a user-defined sort.

This PR allows users to 1) register primitives directly and 2) get a handle to registered sorts (for use in primitive definitions).
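To make point 1) concrete, here is a minimal, std-only Rust sketch of direct registration. It is an illustration under simplifying assumptions, not egglog's code: PrimitiveLike, Registry, add_primitive, and Square below are hypothetical stand-ins, values are plain i64, and there is no type checking.

```rust
use std::collections::HashMap;

// Hypothetical, simplified stand-in for a primitive interface:
// values are plain i64s and no type constraints are checked.
trait PrimitiveLike {
    fn name(&self) -> &'static str;
    // Returning None signals that the primitive does not apply to these arguments.
    fn apply(&self, values: &[i64]) -> Option<i64>;
}

// Hypothetical registry keyed by primitive name.
#[derive(Default)]
struct Registry {
    primitives: HashMap<&'static str, Box<dyn PrimitiveLike>>,
}

impl Registry {
    // Register a primitive directly, without declaring any user-defined sort first.
    fn add_primitive(&mut self, prim: impl PrimitiveLike + 'static) {
        self.primitives.insert(prim.name(), Box::new(prim));
    }

    // Look up a primitive by name and apply it; unknown names yield None.
    fn call(&self, name: &str, args: &[i64]) -> Option<i64> {
        self.primitives.get(name)?.apply(args)
    }
}

// Example primitive: squares its single argument, failing on overflow or wrong arity.
struct Square;

impl PrimitiveLike for Square {
    fn name(&self) -> &'static str {
        "square"
    }

    fn apply(&self, values: &[i64]) -> Option<i64> {
        match values {
            [x] => x.checked_mul(*x),
            _ => None,
        }
    }
}
```

Keying by name keeps lookup simple; the interface in this PR additionally carries a type signature (get_type_constraints in the example further down), which this sketch leaves out.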

This is a breaking change, since it changes the signature of get_sort in TypeInfo (cc @saulshanabrook).
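A handle to a registered sort amounts to a typed lookup over type-erased sort objects. The following std-only sketch shows one way a generic get_sort / get_sort_by pair can be built with Arc downcasting. The Sort trait, I64Sort, VecSort, and TypeInfo here are simplified hypothetical stand-ins, and this is not necessarily how egglog implements it.

```rust
use std::any::Any;
use std::sync::Arc;

// Hypothetical stand-in for a sort interface.
trait Sort: Any + Send + Sync {
    fn name(&self) -> String;
    // Upcast to a type-erased Arc so the registry can downcast to a concrete sort type.
    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync>;
}

struct I64Sort;

// A vector sort remembers the name of its element sort.
struct VecSort {
    name: String,
    element: String,
}

impl VecSort {
    fn element_name(&self) -> String {
        self.element.clone()
    }
}

impl Sort for I64Sort {
    fn name(&self) -> String {
        "i64".into()
    }
    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }
}

impl Sort for VecSort {
    fn name(&self) -> String {
        self.name.clone()
    }
    fn as_arc_any(self: Arc<Self>) -> Arc<dyn Any + Send + Sync> {
        self
    }
}

// Hypothetical registry of all declared sorts, stored as trait objects.
#[derive(Default)]
struct TypeInfo {
    sorts: Vec<Arc<dyn Sort>>,
}

impl TypeInfo {
    fn add_sort(&mut self, sort: Arc<dyn Sort>) {
        self.sorts.push(sort);
    }

    // First registered sort whose concrete type is S and that satisfies `pred`.
    fn get_sort_by<S: Sort>(&self, pred: impl Fn(&Arc<S>) -> bool) -> Option<Arc<S>> {
        self.sorts
            .iter()
            // Keep only sorts whose concrete type is S; others fail the downcast.
            .filter_map(|s| s.clone().as_arc_any().downcast::<S>().ok())
            .find(|s| pred(s))
    }

    // First registered sort whose concrete type is S.
    fn get_sort<S: Sort>(&self) -> Option<Arc<S>> {
        self.get_sort_by(|_| true)
    }
}
```

Returning Arc<S> rather than a trait object is what lets a caller keep a concrete handle such as Arc<VecSort> and call type-specific methods like element_name on it.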

yihozhang commented 10 months ago

@oflatt This is a great suggestion. I tweaked the interface slightly, and it now lets us introduce a new primitive as follows (see src/lib.rs):


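    // Imports assumed for this snippet (it lives in a test module inside the egglog
    // crate; module paths may differ between egglog versions).
    use std::sync::Arc;

    use crate::{
        constraint::SimpleTypeConstraint,
        sort::{FromSort, I64Sort, IntoSort, Sort, VecSort},
        EGraph, PrimitiveLike, Value,
    };

    // A primitive computing the inner product of two i64 vectors. It holds handles
    // to the sorts it needs for type checking and for converting values.
    struct InnerProduct {
        ele: Arc<I64Sort>,
        vec: Arc<VecSort>,
    }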

    impl PrimitiveLike for InnerProduct {
        fn name(&self) -> symbol_table::GlobalSymbol {
            "inner-product".into()
        }

        fn get_type_constraints(&self) -> Box<dyn crate::constraint::TypeConstraint> {
            // Argument sorts first, output sort last: (IntVec, IntVec) -> i64.
            SimpleTypeConstraint::new(
                self.name(),
                vec![self.vec.clone(), self.vec.clone(), self.ele.clone()],
            )
            .into_box()
        }

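        fn apply(&self, values: &[crate::Value], _egraph: &EGraph) -> Option<crate::Value> {
            let mut sum = 0;
            // Unbox both vector arguments using the vector sort's handle.
            let vec1 = Vec::<Value>::load(&self.vec, &values[0]);
            let vec2 = Vec::<Value>::load(&self.vec, &values[1]);
            // Report failure with None instead of panicking on mismatched lengths.
            if vec1.len() != vec2.len() {
                return None;
            }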
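            for (a, b) in vec1.iter().zip(vec2.iter()) {
                // Unbox each element using the element sort's handle.
                let a = i64::load(&self.ele, a);
                let b = i64::load(&self.ele, b);
                sum += a * b;
            }
            // Box the result back into a Value of sort i64.
            sum.store(&self.ele)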
    }

    #[test]
    fn test_user_defined_primitive() {
        let mut egraph = EGraph::default();
        // The container sort must exist before the primitive can get a handle to it.
        egraph
            .parse_and_run_program(
                "
                (sort IntVec (Vec i64))
            ",
            )
            .unwrap();
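        // Get a handle to the built-in i64 sort, selected by its Rust type.
        let i64_sort: Arc<I64Sort> = egraph.get_sort().unwrap();
        // Among the registered Vec sorts, pick the one whose element sort is i64 (IntVec).
        let int_vec_sort: Arc<VecSort> = egraph
            .get_sort_by(|s: &Arc<VecSort>| s.element_name() == i64_sort.name())
            .unwrap();
        // Register the primitive directly; no new sort is defined for it.
        egraph.add_primitive(InnerProduct {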
            ele: i64_sort,
            vec: int_vec_sort,
        });
        egraph
            .parse_and_run_program(
                "
                (let a (vec-of 1 2 3 4 5 6))
                (let b (vec-of 6 5 4 3 2 1))
                (check (= (inner-product a b) 56))
            ",
            )
            .unwrap();
    }
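The arithmetic in apply can be checked outside egglog. Below is a minimal sketch on plain i64 slices. inner_product is a hypothetical helper, not part of the PR, and it uses checked arithmetic so that overflow also yields None; the original sum += a * b would panic on overflow in debug builds.

```rust
// Hypothetical helper mirroring the arithmetic of InnerProduct::apply on plain slices.
// Returns None when the lengths differ or when any product or partial sum overflows.
fn inner_product(a: &[i64], b: &[i64]) -> Option<i64> {
    if a.len() != b.len() {
        return None;
    }
    a.iter()
        .zip(b)
        // Accumulate pairwise products, bailing out with None on overflow.
        .try_fold(0i64, |acc, (x, y)| acc.checked_add(x.checked_mul(*y)?))
}
```

For the vectors in the test, 1*6 + 2*5 + 3*4 + 4*3 + 5*2 + 6*1 = 56, which matches the check in the egglog program.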

One caveat: for container sorts, the container sort must be instantiated first (here, with (sort IntVec (Vec i64))), because the primitive needs a handle to that sort before it can be added.
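The ordering constraint can be modeled in a few lines. In this hypothetical sketch, SortTable, declare_vec, and find_vec_of are illustrative names, not egglog API. Looking up the Vec-of-i64 sort finds nothing until the container sort has been declared, which is why the test runs (sort IntVec (Vec i64)) before calling add_primitive.

```rust
use std::collections::HashMap;

// Hypothetical, simplified model: container sorts recorded as name -> element-sort name.
#[derive(Default)]
struct SortTable {
    containers: HashMap<String, String>,
}

impl SortTable {
    // Plays the role of a declaration like `(sort IntVec (Vec i64))`.
    fn declare_vec(&mut self, name: &str, element: &str) {
        self.containers.insert(name.to_string(), element.to_string());
    }

    // Plays the role of a get_sort_by lookup: find a Vec sort with the given element sort.
    fn find_vec_of(&self, element: &str) -> Option<&str> {
        self.containers
            .iter()
            .find(|(_, e)| e.as_str() == element)
            .map(|(name, _)| name.as_str())
    }
}
```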