uwplse / ruler

Rewrite Rule Inference Using Equality Saturation
https://dl.acm.org/doi/10.1145/3485496
MIT License

Cannot enumerate terms for larger depths #226

Open boronhub opened 5 months ago

boronhub commented 5 months ago

Hello, I am trying to enumerate these terms using Enumo.


// ruler/tests/recipes/vec.rs

use ruler::{
    enumo::{Filter, Metric, Ruleset, Workload},
    recipe_utils::{recursive_rules, run_workload, Lang, iter_metric},
    Limits,
};

use crate::Pred;
fn iter_grammar(n: usize) -> Workload {
    let lang = Workload::new([
        "EXPR",
    ]);
    let EXPR : &[&str] = &[
        "(vec_d EXPR)",
        "(vec_s EXPR)",
        "vals"
    ];
    iter_metric(lang, "EXPR", Metric::Depth, n)
        .plug("EXPR", &EXPR.into())
        .plug("vals", &Workload::new(&["val_0", "val_1", "val_2"]))

}

pub fn vec_rules() -> Ruleset<Pred> {
    println!("Generating vec rules!");
    let mut all = Ruleset::default();
    let canon = iter_grammar(3);
    let get_rules = run_workload(
        canon,
        all.clone(),
        Limits::synthesis(),
        Limits::minimize(),
        true
    );

    all.extend(get_rules);
    all
}

// ruler/tests/vec.rs

use num::{ToPrimitive, Zero};
use ruler::*;
use z3::ast::Ast;

type Constant = i64;
egg::define_language!{
    pub enum Pred {
        BVLit(Constant),
        "vec_d" = vec_d(Id),
        "vec_s" = vec_s(Id),
        Var(Symbol),
    }
}

impl SynthLanguage for Pred {
    type Constant = Constant;

    fn eval<'a, F>(&'a self, cvec_len: usize, mut get_cvec: F) -> CVec<Self>
    where
        F: FnMut(&'a Id) -> &'a CVec<Self>,
        {
            match self {
                Pred::BVLit(v0) => {
                    println!("Interpreting BVlit");
                    vec![Some(v0.clone()); cvec_len]
                }
                Pred::vec_d(x) => {
                    println!("interpreting vec_d");
                    map!(get_cvec, x => 
                        Some(x.clone())
                        ) 
                }
                Pred::vec_s(x) => {
                    println!("interpreting vec_s");
                    map!(get_cvec, x => 
                        Some(x.clone())

                        ) 
                }
                Pred::Var(_) => vec![],
            }
        }

    fn initialize_vars(egraph: &mut EGraph<Self, SynthAnalysis>, vars: &[String]) {
        let consts = vec![
            Some(1.to_i64().unwrap()),
        ];
        let cvecs = self_product(&consts, vars.len());

        egraph.analysis.cvec_len = cvecs[0].len();

        for (i, v) in vars.iter().enumerate() {
            let id = egraph.add(Pred::Var(Symbol::from(v.clone())));
            let cvec = cvecs[i].clone();
            egraph[id].data.cvec = cvec;
        }
    }

    fn to_var(&self) -> Option<Symbol> {
        if let Pred::Var(sym) = self {
            Some(*sym)
        } else {
            None
        }
    }

    fn mk_var(sym: Symbol) -> Self {
        Pred::Var(sym)
    }

    fn is_constant(&self) -> bool {
        matches!(self, Pred::BVLit(_))
    }

    fn mk_constant(c: Self::Constant, _egraph: &mut EGraph<Self, SynthAnalysis>) -> Self {
        Pred::BVLit(c)
    }

    fn validate(lhs: &Pattern<Self>, rhs: &Pattern<Self>) -> ValidationResult {
        let lexpr = egg_to_external_prog(Self::instantiate(lhs).as_ref());
        let rexpr = egg_to_external_prog(Self::instantiate(rhs).as_ref());
        println!("LEFT EXPRESSION");
        println!("{}",lexpr);
        println!("RIGHT EXPRESSION");
        println!("{}",rexpr);
        ValidationResult::Invalid
    }
}

/// Renders a post-order node list (as stored in an egg RecExpr) as an
/// s-expression string for the external validator.
fn egg_to_external_prog(expr: &[Pred]) -> String {
    // buf[i] holds the rendered string for node i. Children always precede
    // their parents, so a child's string is ready when its parent is visited.
    let mut buf: Vec<String> = vec![];
    for node in expr {
        match node {
            Pred::BVLit(v0) => buf.push(format!("{}", v0)),
            Pred::Var(v0) => buf.push(format!("{}", v0)),
            Pred::vec_d(x) => buf.push(format!("(vec_d_dsl {})", &buf[usize::from(*x)])),
            Pred::vec_s(x) => buf.push(format!("(vec_s_dsl {})", &buf[usize::from(*x)])),
        }
    }
    // The last node is the root of the expression.
    buf.pop().unwrap()
}

#[cfg(test)]
#[path = "./recipes/vec.rs"]
mod vec;

mod test {
    use crate::vec::vec_rules;
    use crate::Pred;
    use std::time::{Duration, Instant};

    use ruler::{
        enumo::{Filter, Metric, Ruleset, Workload},
        logger,
        recipe_utils::{recursive_rules, run_workload, Lang},
        Limits,
    };

    #[test]
    fn run() {
        // Skip this test in github actions
        if std::env::var("CI").is_ok() && std::env::var("SKIP_RECIPES").is_ok() {
            return;
        }

        let start = Instant::now();
        // Runs the actual search
        let all_rules = vec_rules();
        let duration = start.elapsed();

    }
}

The validation function returns Invalid for now because I need to call an external program containing the semantics to validate rules, rather than encoding the semantics in Z3. However, when I run the test, the terms are not enumerated to the specified depth.

Generating vec rules!
interpreting vec_d
interpreting vec_s
LEFT EXPRESSION
(vec_d_dsl a)
RIGHT EXPRESSION
a
LEFT EXPRESSION
a
RIGHT EXPRESSION
(vec_d_dsl a)
LEFT EXPRESSION
a
RIGHT EXPRESSION
(vec_d_dsl a)
LEFT EXPRESSION
(vec_s_dsl a)
RIGHT EXPRESSION
a
LEFT EXPRESSION
a
RIGHT EXPRESSION
(vec_s_dsl a)
LEFT EXPRESSION
a
RIGHT EXPRESSION
(vec_s_dsl a)
LEFT EXPRESSION
(vec_d_dsl a)
RIGHT EXPRESSION
(vec_s_dsl a)
LEFT EXPRESSION
(vec_s_dsl a)
RIGHT EXPRESSION
(vec_d_dsl a)
LEFT EXPRESSION
(vec_s_dsl a)
RIGHT EXPRESSION
(vec_d_dsl a)

We are trying to generate the rule (vec_d_dsl (vec_s_dsl a)) = a. Do you know why that pattern is not being enumerated?

ajpal commented 5 months ago

Hi! Thanks for your question, I'm excited that you're trying to use Enumo :)

I think your issue is in how you're using iter_metric. iter_metric is semantically equivalent to repeatedly plugging into the same workload. Consider, for example:

#[test]
fn iter_metric_repeated_plug() {
    let lang = Workload::new(["(OP EXPR EXPR)", "(OP2 EXPR)", "VAL"]);
    let plugged1 = lang.clone().plug("EXPR", &lang).filter(Filter::MetricLt(Metric::Depth, 2));
    let itered1 = iter_metric(lang.clone(), "EXPR", Metric::Depth, 1);
    assert_eq!(plugged1.force(), itered1.force());

    let plugged2 = lang.clone().plug("EXPR", &plugged1).filter(Filter::MetricLt(Metric::Depth, 3));
    let itered2 = iter_metric(lang.clone(), "EXPR", Metric::Depth, 2);
    assert_eq!(plugged2.force(), itered2.force());

    let plugged3 = lang.clone().plug("EXPR", &plugged2).filter(Filter::MetricLt(Metric::Depth, 4));
    let itered3 = iter_metric(lang, "EXPR", Metric::Depth, 3);
    assert_eq!(plugged3.force(), itered3.force());
}
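
To make the repeated-plug semantics concrete without depending on Ruler, here is a minimal standalone sketch that models a workload as a list of s-expression strings. The helper names (wl, depth, plug, iter_depth) are hypothetical and are not Enumo's API, and the depth convention (a bare atom has depth 1) is an assumption chosen to match the terms listed later in this thread.

```rust
/// Hypothetical stand-in for a workload: a list of s-expression strings.
type Terms = Vec<String>;

/// Builds a term list from string literals.
fn wl(terms: &[&str]) -> Terms {
    terms.iter().map(|t| t.to_string()).collect()
}

/// Depth of a term, assuming a bare atom has depth 1 and each level of
/// parentheses adds 1.
fn depth(term: &str) -> usize {
    let (mut cur, mut max) = (0usize, 0usize);
    for ch in term.chars() {
        match ch {
            '(' => {
                cur += 1;
                max = max.max(cur);
            }
            ')' => cur = cur.saturating_sub(1),
            _ => {}
        }
    }
    max + 1
}

/// Replaces every occurrence of `atom` in every term with every peg
/// (a cartesian product when `atom` occurs more than once in a term).
fn plug(wkld: &[String], atom: &str, pegs: &[String]) -> Terms {
    let mut out = Vec::new();
    for term in wkld {
        // Put spaces around parentheses so they become separate tokens.
        let spaced = term.replace('(', " ( ").replace(')', " ) ");
        let mut partial = vec![String::new()];
        for tok in spaced.split_whitespace() {
            let choices = if tok == atom { pegs.to_vec() } else { vec![tok.to_string()] };
            let mut next = Vec::new();
            for prefix in &partial {
                for choice in &choices {
                    next.push(format!("{prefix} {choice}"));
                }
            }
            partial = next;
        }
        // Normalize whitespace so "( vec_d VAL )" becomes "(vec_d VAL)".
        for p in partial {
            let joined = p.split_whitespace().collect::<Vec<_>>().join(" ");
            out.push(joined.replace("( ", "(").replace(" )", ")"));
        }
    }
    out
}

/// Models iteration by depth: plug the workload into itself `n` times,
/// keeping only terms with depth < i + 1 after round i.
fn iter_depth(wkld: &[String], atom: &str, n: usize) -> Terms {
    let mut pegs = wkld.to_vec();
    for i in 1..=n {
        pegs = plug(wkld, atom, &pegs)
            .into_iter()
            .filter(|t| depth(t) < i + 1)
            .collect();
    }
    pegs
}
```

In this model, iter_depth(&wl(&["EXPR"]), "EXPR", 3) stays at the single term EXPR, because plugging EXPR into EXPR never adds structure. Starting from (vec_d EXPR), (vec_s EXPR), VAL instead yields seven terms at n = 3.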

So in your code, you're calling iter_metric(Workload::new(["EXPR"]), "EXPR", Metric::Depth, n), but that doesn't actually do anything: the workload will still contain only EXPR at the end.

let lang = Workload::new(["EXPR"]);
let lang3 = iter_metric(lang, "EXPR", Metric::Depth, 3);
let expected = Workload::new(["EXPR"]);
assert_eq!(expected.force(), lang3.force());

And then you plug in exprs, giving you a workload containing (vec_d EXPR), (vec_s EXPR), and vals:

let exprs: &[&str] = &["(vec_d EXPR)", "(vec_s EXPR)", "vals"];
let plug_expr = lang3.plug("EXPR", &exprs.into());
let expected2: Workload = Workload::new(["(vec_d EXPR)", "(vec_s EXPR)","vals"]);
assert_eq!(expected2.force(), plug_expr.force());

And finally you plug in vals, giving you (vec_d EXPR), (vec_s EXPR), val_0, val_1, and val_2. Note that EXPR is never expanded any further:

let plug_vals = plug_expr.plug("vals", &Workload::new(["val_0", "val_1", "val_2"]));
let expected3 = Workload::new(["(vec_d EXPR)", "(vec_s EXPR)", "val_0", "val_1", "val_2"]);
assert_eq!(plug_vals.force(), expected3.force());

Correct me if I'm wrong, but I think you're trying to enumerate these terms:

(vec_d (vec_d val_0))
(vec_d (vec_d val_1))
(vec_d (vec_d val_2))
(vec_d (vec_s val_0))
(vec_d (vec_s val_1))
(vec_d (vec_s val_2))
(vec_d val_0)
(vec_d val_1)
(vec_d val_2)
(vec_s (vec_d val_0))
(vec_s (vec_d val_1))
(vec_s (vec_d val_2))
(vec_s (vec_s val_0))
(vec_s (vec_s val_1))
(vec_s (vec_s val_2))
(vec_s val_0)
(vec_s val_1)
(vec_s val_2)
val_0
val_1
val_2

In that case, I think you'll want to make a workload that's something like this:

let lang = Workload::new(["(vec_d EXPR)", "(vec_s EXPR)", "VAL"]);
let depth3 = iter_metric(lang, "EXPR", Metric::Depth, 3)
                   .plug("VAL", &Workload::new(["val_0", "val_1", "val_2"]));
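
As a sanity check on the expected term set, here is a small standalone enumerator that builds the same shape of terms directly by recursion. It does not use Enumo; the function name terms_up_to and the depth convention (a leaf alone has depth 1) are assumptions made for this sketch.

```rust
/// Enumerates terms over the unary operators vec_d and vec_s with leaves
/// val_0..val_2, up to the given depth (a leaf alone has depth 1).
fn terms_up_to(depth: usize) -> Vec<String> {
    let leaves = ["val_0", "val_1", "val_2"];
    let ops = ["vec_d", "vec_s"];
    // Every depth includes the bare leaves.
    let mut out: Vec<String> = leaves.iter().map(|l| l.to_string()).collect();
    if depth > 1 {
        // Wrap every shallower term in each operator.
        let smaller = terms_up_to(depth - 1);
        for op in ops {
            for t in &smaller {
                out.push(format!("({op} {t})"));
            }
        }
    }
    out
}
```

Under these assumptions, terms_up_to(3) produces 21 terms, the same count as the list above.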

Please let me know if that helps, and don't hesitate to reach out with any other questions!

RafaeNoor commented 4 months ago

Hello @ajpal, thanks for your reply. The above suggestion does fix the enumeration; however, the leaves of the enumerated expressions are the term a instead of val_0, val_1, or val_2.

Replacing the enumeration with:

let lang = Workload::new(["(vec_d EXPR)", "(vec_s EXPR)", "VAL"]);
let depth3 = iter_metric(lang, "EXPR", Metric::Depth, 3)
                   .plug("VAL", &Workload::new(["val_0", "val_1", "val_2"]));

produces the following enumerated terms:

...
LEFT EXPRESSION
(vec_s_dsl a)
RIGHT EXPRESSION
(vec_s_dsl (vec_d_dsl a))
LEFT EXPRESSION
(vec_s_dsl (vec_d_dsl a))
RIGHT EXPRESSION
(vec_s_dsl a)
...

instead of something like

LEFT EXPRESSION
(vec_s_dsl val_0)
RIGHT EXPRESSION
(vec_s_dsl (vec_d_dsl val_0))
LEFT EXPRESSION
(vec_s_dsl (vec_d_dsl val_0))
RIGHT EXPRESSION
(vec_s_dsl val_1)

Could you suggest how to rectify this?
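
One possible explanation, which is an assumption and is not confirmed anywhere in this thread: the workload itself may still contain val_0, val_1, and val_2, but candidate rules may be generalized before validate is called, with each distinct variable renamed to a, b, c, and so on. Since val_0 does not parse as an integer literal, it would presumably be treated as a Var in the Pred language above. The standalone sketch below only illustrates that kind of renaming; generalize and generalize_rule are hypothetical helpers, not Ruler functions.

```rust
use std::collections::HashMap;

/// Hypothetical helper: renames every token that is not a parenthesis or a
/// known operator to a, b, c, ... in order of first appearance. The `names`
/// map is shared so both sides of a rule use the same renaming.
fn generalize(term: &str, operators: &[&str], names: &mut HashMap<String, char>) -> String {
    let spaced = term.replace('(', " ( ").replace(')', " ) ");
    let mut out: Vec<String> = Vec::new();
    for tok in spaced.split_whitespace() {
        if tok == "(" || tok == ")" || operators.contains(&tok) {
            out.push(tok.to_string());
        } else {
            // The next fresh letter; only used if this token is new.
            let fresh = (b'a' + names.len() as u8) as char;
            let name = *names.entry(tok.to_string()).or_insert(fresh);
            out.push(name.to_string());
        }
    }
    out.join(" ").replace("( ", "(").replace(" )", ")")
}

/// Generalizes both sides of a candidate rule with one shared renaming.
fn generalize_rule(lhs: &str, rhs: &str, operators: &[&str]) -> (String, String) {
    let mut names = HashMap::new();
    let l = generalize(lhs, operators, &mut names);
    let r = generalize(rhs, operators, &mut names);
    (l, r)
}
```

If something like this is happening, a candidate such as (vec_s val_0) = (vec_s (vec_d val_0)) would reach validate as (vec_s a) = (vec_s (vec_d a)), which matches the shape of the output shown above.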

RafaeNoor commented 4 months ago

Also another question:

Consider the case where the enumerated terms have different types. For example, suppose we have a DSL with 32-bit and 64-bit integers, and operations that work on 32-bit and 64-bit values respectively. One can convert a 32-bit integer to 64 bits using sign extension, and go the other way using truncation. Could you explain how we can enumerate expressions in this case? We want to enumerate all expressions up to depth 3 where the output type may be either a 64-bit value or a 32-bit value.

Your help will be greatly appreciated!

ajpal commented 4 months ago

Sorry, I missed this reply!!

I don't understand, and can't reproduce, the issue with the term a that you describe above. Running the following:

    #[test]
    fn wkld_test() {
        let lang = Workload::new(["(vec_d EXPR)", "(vec_s EXPR)", "VAL"]);
        let depth3 = iter_metric(lang, "EXPR", Metric::Depth, 3)
                        .plug("VAL", &Workload::new(["val_0", "val_1", "val_2"]));
        depth3.pretty_print();
    }

I get

---- test::bar stdout ----
(vec_d (vec_d val_0 ) )
(vec_d (vec_d val_1 ) )
(vec_d (vec_d val_2 ) )
(vec_d (vec_s val_0 ) )
(vec_d (vec_s val_1 ) )
(vec_d (vec_s val_2 ) )
(vec_d val_0 )
(vec_d val_1 )
(vec_d val_2 )
(vec_s (vec_d val_0 ) )
(vec_s (vec_d val_1 ) )
(vec_s (vec_d val_2 ) )
(vec_s (vec_s val_0 ) )
(vec_s (vec_s val_1 ) )
(vec_s (vec_s val_2 ) )
(vec_s val_0 )
(vec_s val_1 )
(vec_s val_2 )
val_0
val_1
val_2

Which is, I think, as expected.

Can you clarify and/or send me a branch I can play around with?

For the second question (about enumerating different types): Workload terms aren't typed (they're just s-expressions). What a term means is up to your implementation of eval. Can you clarify a little bit what kinds of terms you're trying to enumerate, and what problems you're running into?

RafaeNoor commented 4 months ago

Hello Anjali, thanks for your reply. I will create a fork that reproduces the issue I'm running into and share it with you here shortly. Regarding the second question: suppose my grammar consists of the following scalar operations:

<32-bit-expr> ::= sign-extend-16-to-32 <16-bit-expr>
                | sign-extend-8-to-32 <8-bit-expr>
                | add-32 <32-bit-expr>
                | mul-32 <32-bit-expr>

<16-bit-expr> ::= truncate-32-to-16 <32-bit-expr>
                | sign-extend-8-to-16 <8-bit-expr>
                | add-16 <16-bit-expr>
                | mul-16 <16-bit-expr>

<8-bit-expr>  ::= truncate-32-to-8 <32-bit-expr>
                | truncate-16-to-8 <16-bit-expr>
                | add-8 <8-bit-expr>
                | mul-8 <8-bit-expr>

As you can see, the expressions are mutually recursive, so I'm unsure how to express a workload like this. Say we wanted to enumerate all expressions up to depth 3 where the final return type is 32-bit-expr. A flattened grammar containing all of these terms would work, but many of the resulting expressions would be ill-typed. I can always fall back to that, but I would like to avoid enumerating those expressions to begin with. Is there a way in Enumo to express these constraints in the enumeration itself?
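
One way to think about this, independent of Enumo, is to enumerate by type: keep a separate term set for each width and only apply an operator to terms of its operand width. The standalone sketch below does that for the operators in the grammar above. It makes several assumptions: each operator is unary as written, operand and result widths are read off the operator's name (so the 8-bit to 16-bit sign extension is called sign-extend-8-to-16 here), and x8, x16, and x32 are hypothetical leaf variables, since the grammar above lists no terminals. Whether the same effect can be had in Enumo, for example with one placeholder atom per type, is not established in this thread.

```rust
/// Bit widths used as types in this sketch.
#[derive(Clone, Copy, PartialEq, Eq, Debug)]
enum Ty {
    B8,
    B16,
    B32,
}

/// (operator name, operand type, result type). Signatures are assumptions
/// read off the operator names.
const OPS: &[(&str, Ty, Ty)] = &[
    ("sign-extend-16-to-32", Ty::B16, Ty::B32),
    ("sign-extend-8-to-32", Ty::B8, Ty::B32),
    ("add-32", Ty::B32, Ty::B32),
    ("mul-32", Ty::B32, Ty::B32),
    ("truncate-32-to-16", Ty::B32, Ty::B16),
    ("sign-extend-8-to-16", Ty::B8, Ty::B16),
    ("add-16", Ty::B16, Ty::B16),
    ("mul-16", Ty::B16, Ty::B16),
    ("truncate-32-to-8", Ty::B32, Ty::B8),
    ("truncate-16-to-8", Ty::B16, Ty::B8),
    ("add-8", Ty::B8, Ty::B8),
    ("mul-8", Ty::B8, Ty::B8),
];

/// Hypothetical leaf variable for each type.
fn leaf(ty: Ty) -> &'static str {
    match ty {
        Ty::B8 => "x8",
        Ty::B16 => "x16",
        Ty::B32 => "x32",
    }
}

/// All well-typed terms of type `ty` up to `depth` (a leaf has depth 1).
/// Mutual recursion between types is handled by recursing on the operand
/// type of each operator whose result type is `ty`.
fn terms(ty: Ty, depth: usize) -> Vec<String> {
    let mut out = vec![leaf(ty).to_string()];
    if depth > 1 {
        for &(op, arg, res) in OPS {
            if res == ty {
                for t in terms(arg, depth - 1) {
                    out.push(format!("({op} {t})"));
                }
            }
        }
    }
    out
}
```

Calling terms(Ty::B32, 3) then gives only expressions whose result is 32-bit, such as (sign-extend-16-to-32 (truncate-32-to-16 x32)), and never builds an ill-typed term like (add-32 x16). The trade-off is that the typing lives in the enumerator rather than in a filter applied afterwards.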