KardinalAI / cp_sat

Google CP-SAT solver Rust bindings
Apache License 2.0
18 stars 5 forks source link

Add solution handler #31

Open cyconer opened 11 months ago

cyconer commented 11 months ago

Partially addresses #9.

This adds a solution handler that is called whenever a new solution is found (for optimization problems, whenever an improved solution is found). The search cannot yet be stopped from the handler (details below).

Parallelism

The reference documentation (https://github.com/google/or-tools/blob/stable/ortools/sat/docs/solver.md?plain=1#L472) states:

Please note that it does not work in parallel (i. e. parameter num_search_workers > 1).

As mentioned in ticket #9, it might be desirable to use `Fn` instead of `FnMut`. The choice would only matter if the handler could be invoked concurrently from several search workers. As I read the note above, solution callbacks are not supported in parallel search anyway, so it does not seem to matter whether we use `Fn` or `FnMut`. The implementation currently uses `FnMut`, but that is easy to change.

ControlFlow

This is where it gets strange. I experimented with stopping the search, following the approach in the reference documentation, but something very odd happens on the C++ side.

Diff of potential approach for Control Flow ```diff diff --git a/src/builder.rs b/src/builder.rs index 2fb392b..17d2a18 100644 --- a/src/builder.rs +++ b/src/builder.rs @@ -786,6 +786,7 @@ impl CpModelBuilder { /// # use std::rc::Rc; /// # use cp_sat::builder::CpModelBuilder; /// # use cp_sat::proto::{SatParameters, CpSolverResponse}; + /// # use std::ops::ControlFlow; /// let mut model = CpModelBuilder::default(); /// // linear constraint will only allow a = 2, a = 3 and a = 4 /// let a = model.new_int_var([(2, 7)]); @@ -793,10 +794,11 @@ impl CpModelBuilder { /// let mut params = SatParameters::default(); /// params.enumerate_all_solutions = Some(true); /// - /// let memory = Rc::new(RefCell::new(Vec::new())); + /// let memory: Rc>> = Rc::new(RefCell::new(Vec::new())); /// let memory2 = memory.clone(); /// let handler = move |response: CpSolverResponse| { /// memory2.borrow_mut().push(response); + /// ControlFlow::Continue(()) /// }; /// /// let _response = model.solve_with_parameters_and_handler(¶ms, handler); @@ -805,7 +807,7 @@ impl CpModelBuilder { pub fn solve_with_parameters_and_handler( &self, params: &proto::SatParameters, - handler: impl FnMut(proto::CpSolverResponse) + 'static, + handler: impl FnMut(proto::CpSolverResponse) -> std::ops::ControlFlow<()> + 'static, ) -> proto::CpSolverResponse { ffi::solve_with_parameters_and_handler(self.proto(), params, Box::new(handler)) } diff --git a/src/cp_sat_wrapper.cpp b/src/cp_sat_wrapper.cpp index c2a1126..eb92d38 100644 --- a/src/cp_sat_wrapper.cpp +++ b/src/cp_sat_wrapper.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace sat = operations_research::sat; @@ -58,8 +59,10 @@ cp_sat_wrapper_solve( * - serialized buffer of a CpSolverResponse * - length of the buffer * - additional data passed from the outside + * + * Returns true if the search should be aborted. 
*/ -typedef void (*solution_handler)(unsigned char*, size_t, void*); +typedef bool (*solution_handler)(unsigned char*, size_t, void*); /** * Similar to cp_sat_wrapper_solve_with_parameters, but with a callback function @@ -89,6 +92,10 @@ cp_sat_wrapper_solve_with_parameters_and_handler( extra_model.Add(sat::NewSatParameters(params)); + // Atomic Boolean that will be periodically checked by the limit. + std::atomic stopped(false); + extra_model.GetOrCreate()->RegisterExternalBooleanAsLimit(&stopped); + // local function that serializes the CpSolverResponse for the provided solution handler auto wrapped_handler = [&](const operations_research::sat::CpSolverResponse& curr_response) { // serialize CpSolverResponse @@ -97,7 +104,10 @@ cp_sat_wrapper_solve_with_parameters_and_handler( bool curr_res = curr_response.SerializeToArray(response_buf, response_size); assert(curr_res); - handler(response_buf, response_size, handler_data); + bool abort = handler(response_buf, response_size, handler_data); + if (abort) { + stopped = true; + } }; extra_model.Add(sat::NewFeasibleSolutionObserver(wrapped_handler)); diff --git a/src/ffi.rs b/src/ffi.rs index 7284c6c..27c07c8 100644 --- a/src/ffi.rs +++ b/src/ffi.rs @@ -3,6 +3,7 @@ use libc::c_char; use prost::Message; use std::ffi::CStr; use std::ffi::c_void; +use std::ops::ControlFlow; extern "C" { fn cp_sat_wrapper_solve( @@ -22,7 +23,7 @@ extern "C" { model_size: usize, params_buf: *const u8, params_size: usize, - handler_caller: extern "C" fn(*const u8, usize, *mut c_void), + handler_caller: extern "C" fn(*const u8, usize, *mut c_void) -> bool, handler: *mut c_void, out_size: &mut usize, ) -> *mut u8; @@ -83,7 +84,8 @@ pub fn solve_with_parameters( } /// User provided solution handler that is called with feasible solutions. -pub type SolutionHandler = Box; +/// The control flow can be used to abort the search. 
+pub type SolutionHandler = Box ControlFlow<()>>; /// Solves the given [CpModelProto][crate::proto::CpModelProto] with /// the given parameters, @@ -129,16 +131,23 @@ pub fn solve_with_parameters_and_handler( /// - `response_buf` and `response_size`: buffer and size of a [proto::CpSolverResponse] /// - `handler`: a user provided solution handler [SolutionHandler] that accepts a /// [proto::CpSolverResponse] -extern "C" fn solution_handler_caller(response_buf: *const u8, response_size: usize, handler: *mut c_void) { +/// +/// Returns `true` if the search should be aborted. +extern "C" fn solution_handler_caller(response_buf: *const u8, response_size: usize, handler: *mut c_void) -> bool { let response_slice = unsafe { std::slice::from_raw_parts(response_buf, response_size) }; let response = proto::CpSolverResponse::decode(response_slice).unwrap(); unsafe { libc::free(response_buf as _) }; - unsafe { + let control_flow = unsafe { let tmp = handler as *mut SolutionHandler; - (*tmp)(response); + (*tmp)(response) + }; + + match control_flow { + ControlFlow::Continue(_) => false, + ControlFlow::Break(_) => true, } } diff --git a/tests/solution_handler.rs b/tests/solution_handler.rs index 959541f..3caf620 100644 --- a/tests/solution_handler.rs +++ b/tests/solution_handler.rs @@ -14,10 +14,11 @@ fn enumeration_solution_handler() { let mut params = SatParameters::default(); params.enumerate_all_solutions = Some(true); - let memory = Rc::new(RefCell::new(Vec::new())); + let memory: Rc>> = Rc::new(RefCell::new(Vec::new())); let memory2 = memory.clone(); let handler = move |response: CpSolverResponse| { memory2.borrow_mut().push(response); + std::ops::ControlFlow::Continue(()) }; let _response = model.solve_with_parameters_and_handler(¶ms, handler); @@ -45,10 +46,11 @@ fn optimization_solution_handler() { let mut params = SatParameters::default(); params.enumerate_all_solutions = Some(true); - let memory = Rc::new(RefCell::new(Vec::new())); + let memory: Rc>> = 
Rc::new(RefCell::new(Vec::new())); let memory2 = memory.clone(); let handler = move |response: CpSolverResponse| { memory2.borrow_mut().push(response); + std::ops::ControlFlow::Continue(()) }; let response = model.solve_with_parameters_and_handler(¶ms, handler); @@ -61,3 +63,31 @@ fn optimization_solution_handler() { // improvement. assert!(memory.borrow().len() >= 1); } + +/// It should be possible to stop the search from the callback. +#[test] +fn stop_solution_handler() { + let mut model = CpModelBuilder::default(); + // linear constraint will only allow a = 2, a = 3 and a = 4 + let a = model.new_int_var([(2, 7)]); + model.add_linear_constraint([(3, a)], [(0, 13)]); + let mut params = SatParameters::default(); + params.enumerate_all_solutions = Some(true); + + let memory: Rc>> = Rc::new(RefCell::new(Vec::new())); + let memory2 = memory.clone(); + let handler = move |response: CpSolverResponse| { + memory2.borrow_mut().push(response); + + if memory2.borrow().len() < 2 { + std::ops::ControlFlow::Continue(()) + } else { + std::ops::ControlFlow::Break(()) + } + }; + + let _response = model.solve_with_parameters_and_handler(¶ms, handler); + + // Instead of the 3 feasible solution the search was aborted after 2. + assert_eq!(2, memory.borrow().len()); +} ```

Problem: completely unrelated code starts to crash with SIGSEGV as soon as `extra_model.GetOrCreate<operations_research::TimeLimit>()` is included.

Minimal failing example that completely baffles me ```diff diff --git a/src/cp_sat_wrapper.cpp b/src/cp_sat_wrapper.cpp index c2a1126..94c6661 100644 --- a/src/cp_sat_wrapper.cpp +++ b/src/cp_sat_wrapper.cpp @@ -2,6 +2,7 @@ #include #include +#include namespace sat = operations_research::sat; @@ -89,6 +90,12 @@ cp_sat_wrapper_solve_with_parameters_and_handler( extra_model.Add(sat::NewSatParameters(params)); + bool this_is_never_reached = false; + if (this_is_never_reached) { + // Including this line leads to SIGSEGV of e.g. cp_sat_wrapper_solve??? + extra_model.GetOrCreate(); + } + // local function that serializes the CpSolverResponse for the provided solution handler auto wrapped_handler = [&](const operations_research::sat::CpSolverResponse& curr_response) { // serialize CpSolverResponse ```

For example, tests/bool_cst.rs starts to crash with SIGSEGV, even though it does not use the new `cp_sat_wrapper_solve_with_parameters_and_handler` but only the unchanged `cp_sat_wrapper_solve`. The crash appears to happen inside Abseil, at line 1562 of absl/container/internal/raw_hash_set.h. It does not crash with e.g. `GetOrCreate<bool>()`, only when `TimeLimit` is used as the template argument.

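I have no idea how unrelated code can fail just because a never-executed function call is present in another function. Since the call never runs, its effect must come from compile or link time. For example, instantiating the template may pull additional OR-Tools/Abseil code into the binary, which could point to an ABI or ODR mismatch between the headers used to build the wrapper and the linked library (unconfirmed). Any ideas are appreciated.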