prove-rs / z3.rs

Rust bindings for the Z3 solver.

Iterator for a solver #221

Open onthestairs opened 1 year ago

onthestairs commented 1 year ago

Hi, firstly, thank you very much for this library - it's great!

I'm not sure if this is a generic Rust question or something specific to this library. I am trying to create an 'iterator' of solutions from Z3, so that I can access the solutions lazily and in a composable way (for example, taking only the first n solutions with .take(n)). However, I can't figure out the ergonomics of it, and I'm not sure it is even possible to make the lifetimes work.

Here is a version of what I would like to do which doesn't use iterators, and just returns a Vec of all the solutions:

use z3::ast;
use z3::ast::Ast;
use z3::Config;
use z3::Context;
use z3::SatResult;
use z3::Solver;

fn solver(target: i64) -> Vec<(i64, i64)> {
    let cfg = Config::new();
    let ctx = Context::new(&cfg);
    let solver = Solver::new(&ctx);

    let x = ast::Int::new_const(&ctx, "x");
    let y = ast::Int::new_const(&ctx, "y");
    let n = ast::Int::from_i64(&ctx, target);
    let equals = ast::Int::add(&ctx, &[&x, &y])._eq(&n);
    solver.assert(&x.lt(&n));
    solver.assert(&y.lt(&n));
    solver.assert(&equals);

    let mut solutions = vec![];
    while solver.check() == SatResult::Sat {
        let model = solver.get_model().unwrap();
        let x_val = model.eval(&x, true).unwrap();
        let y_val = model.eval(&y, true).unwrap();
        let solution = (x_val.as_i64().unwrap(), y_val.as_i64().unwrap());
        solutions.push(solution);

        let is_the_solution = ast::Bool::and(&ctx, &[&x._eq(&x_val), &y._eq(&y_val)]);
        solver.assert(&is_the_solution.not());
    }
    return solutions;
}

fn main() {
    for (x, y) in solver(10) {
        println!("x={}, y={}", x, y)
    }
}

Here is an attempt at what it might look like, but I cannot arrange it to satisfy the borrow checker:

use z3::ast;
use z3::ast::Ast;
use z3::Config;
use z3::Context;
use z3::SatResult;
use z3::Solver;

struct AddSolver<'a> {
    ctx: Context,
    solver: Solver<'a>,
    x: ast::Int<'a>,
    y: ast::Int<'a>,
}

impl<'a> Iterator for AddSolver<'a> {
    type Item = (i64, i64);

    fn next(&mut self) -> Option<Self::Item> {
        if self.solver.check() == SatResult::Sat {
            let model = self.solver.get_model().unwrap();
            let x_val = model.eval(&self.x, true).unwrap();
            let y_val = model.eval(&self.y, true).unwrap();
            let solution = (x_val.as_i64().unwrap(), y_val.as_i64().unwrap());
            // disallow this solution
            let is_the_solution =
                ast::Bool::and(&self.ctx, &[&self.x._eq(&x_val), &self.y._eq(&y_val)]);
            self.solver.assert(&is_the_solution.not());
            // return the solution
            return Some(solution);
        } else {
            // no solution, end of iteration
            return None;
        }
    }
}

fn solver<'a>(target: i64) -> AddSolver<'a> {
    let cfg = Config::new();
    let ctx = Context::new(&cfg);
    let solver = Solver::new(&ctx);

    let x = ast::Int::new_const(&ctx, "x");
    let y = ast::Int::new_const(&ctx, "y");
    let n = ast::Int::from_i64(&ctx, target);
    let equals = ast::Int::add(&ctx, &[&x, &y])._eq(&n);
    solver.assert(&x.lt(&n));
    solver.assert(&y.lt(&n));
    solver.assert(&equals);

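    // Does not compile: solver, x and y all borrow the local ctx, so ctx
    // cannot be moved into the struct and returned alongside them.
    return AddSolver { solver, ctx, x, y };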
}

fn main() {
    for (x, y) in solver(10).take(5) {
        println!("x={}, y={}", x, y)
    }
}

Is this something that could be done using this library? If not, is there something we could add to make it possible? Thank you!

Pat-Lafon commented 1 year ago

I saw this and thought it was interesting so I gave a crack at it. Instead of creating an iterator over solutions, I created an iterator over models.

What made this tricky is that I couldn't find a way to query for all of the constants in the context. (I originally listed a second problem here: Ast is a trait, so you can't easily hold different AST types in a sized way, and I had written my own enum for that. I didn't know about Dynamic, which solves it.) To work around the missing query, I'm following a Stack Overflow solution on this topic: after each model, assert a large OR over all of your constants saying that at least one of them must differ from its current model value. Other bindings, such as OCaml's, let you get a list of the "function declarations of the constants" (https://z3prover.github.io/api/html/ml/Z3.Model.html), but I don't think that is exposed here, so I have to store a separate list of variables.
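Written out, the blocking clause described above looks like this, where v_1, ..., v_k are the tracked constants and M(v_i) is the value the current model M assigns to v_i:

$$\bigvee_{i=1}^{k} \bigl(v_i \neq M(v_i)\bigr) \;\equiv\; \neg \bigwedge_{i=1}^{k} \bigl(v_i = M(v_i)\bigr)$$

The right-hand side, equivalent by De Morgan's law, is the form the first snippet in this thread asserts for x and y. For instance, a model with x = 1 and y = 9 is blocked by $x \neq 1 \lor y \neq 9$, so the next check() must either return a different pair or report Unsat.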

The following is what I came up with:

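use z3::ast::{self, Ast, Bool, Dynamic};
use z3::{Config, Context, Model, SatResult, Solver};

// Yields models of `solver`; after each one it asserts that at least one of
// `vars` must take a different value next time.
struct Z3Iterator<'ctx> {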
    solver: Solver<'ctx>,
    vars: Vec<Dynamic<'ctx>>,
}

impl<'ctx> Iterator for Z3Iterator<'ctx> {
    type Item = Model<'ctx>;

    fn next(&mut self) -> Option<Self::Item> {
        if let SatResult::Sat = self.solver.check() {
            let model = self.solver.get_model()?;
            let negating_equations = self
                .vars
                .iter()
                .map(|ast| ast._eq(&model.eval(ast, true).unwrap()).not())
                .collect::<Vec<_>>();
            let invalidate_model = Bool::or(
                self.solver.get_context(),
                &negating_equations.iter().collect::<Vec<_>>(),
            );
            self.solver.assert(&invalidate_model);
            Some(model)
        } else {
            None
        }
    }
}

fn solver(target: i64) -> Vec<(i64, i64)> {
    let cfg = Config::new();
    let ctx = Context::new(&cfg);
    let solver = Solver::new(&ctx);

    let x = ast::Int::new_const(&ctx, "x");
    let y = ast::Int::new_const(&ctx, "y");
    let n = ast::Int::from_i64(&ctx, target);
    let equals = ast::Int::add(&ctx, &[&x, &y])._eq(&n);
    solver.assert(&x.lt(&n));
    solver.assert(&y.lt(&n));
    solver.assert(&equals);

    let iterator = Z3Iterator {
        solver,
        vars: vec![Dynamic::from_ast(&x), Dynamic::from_ast(&y)],
    };

    // Z3Iterator is already an Iterator, so it can be mapped directly.
    iterator
        .map(|model| {
            let x_val = model.eval(&x, true).unwrap();
            let y_val = model.eval(&y, true).unwrap();
            (x_val.as_i64().unwrap(), y_val.as_i64().unwrap())
        })
        .collect()
}

fn main() {
    for (x, y) in solver(10) {
        println!("x={}, y={}", x, y)
    }
}