Skip to content

Commit

Permalink
Fighting in vain with z3 params
Browse files Browse the repository at this point in the history
  • Loading branch information
ole-thoeb committed Jan 11, 2025
1 parent 971fdfd commit 669d70c
Show file tree
Hide file tree
Showing 3 changed files with 175 additions and 41 deletions.
2 changes: 1 addition & 1 deletion src/driver.rs
Original file line number Diff line number Diff line change
Expand Up @@ -705,7 +705,7 @@ impl<'ctx> SmtVcUnit<'ctx> {
});
}

let mut slice_solver = SliceSolver::new(slice_vars.clone(), translate, prover);
let slice_solver = SliceSolver::new(slice_vars.clone(), translate, prover);
let failing_slice_options = SliceSolveOptions {
globally_optimal: !options.slice_options.slice_error_first,
continue_on_unknown: false,
Expand Down
202 changes: 174 additions & 28 deletions z3rro/src/prover.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,13 @@
//! Not a SAT solver, but a prover. There's a difference.
use std::collections::HashMap;
use std::{fmt::Display, time::Duration};

use z3::{
ast::{forall_const, Ast, Bool, Dynamic},
Context, Model, Params, SatResult, Solver,
Context, Model, Params, SatResult, Solver, Symbol,
};

use crate::{
model::InstrumentedModel,
smtlib::Smtlib,
util::{set_solver_timeout, ReasonUnknown},
};
use crate::{model::InstrumentedModel, smtlib::Smtlib, util::ReasonUnknown};

/// The result of a prove query.
#[derive(Debug)]
Expand All @@ -33,6 +29,57 @@ impl Display for ProveResult<'_> {
}
}

#[derive(Debug, Clone, PartialEq)]
pub enum ParamValue {
Bool(bool),
Symbol(Symbol),
U32(u32),
F64(f64),
}

impl ParamValue {
fn set_to_params(&self, k: impl Into<Symbol>, params: &mut Params) {
match self {
ParamValue::Bool(b) => params.set_bool(k, *b),
ParamValue::Symbol(s) => params.set_symbol(k, s.clone()),
ParamValue::U32(u) => params.set_u32(k, *u),
ParamValue::F64(f) => params.set_f64(k, *f),
}
}
}

impl From<bool> for ParamValue {
fn from(v: bool) -> Self {
ParamValue::Bool(v)
}
}

impl From<u32> for ParamValue {
fn from(v: u32) -> Self {
ParamValue::U32(v)
}
}

impl From<f64> for ParamValue {
fn from(v: f64) -> Self {
ParamValue::F64(v)
}
}

impl From<Symbol> for ParamValue {
fn from(v: Symbol) -> Self {
ParamValue::Symbol(v)
}
}

fn to_params<'ctx>(ctx: &'ctx Context, param_map: &HashMap<String, ParamValue>) -> Params<'ctx> {
let mut parms = Params::new(ctx);
for (k, v) in param_map.iter() {
v.set_to_params(&**k, &mut parms);
}
parms
}

/// A prover wraps a SAT solver, but it's used to prove validity of formulas.
/// It's a bit of a more explicit API to distinguish between assumptions for a
/// proof ([`Prover::add_assumption`]) and provables ([`Prover::add_provable`]).
Expand All @@ -45,42 +92,76 @@ impl Display for ProveResult<'_> {
///
/// In contrast to [`z3::Solver`], the [`Prover`] requires exclusive ownership
/// to do any modifications of the solver.
#[derive(Debug, Clone)]
#[derive(Debug)]
pub struct Prover<'ctx> {
/// The underlying solver.
solver: Solver<'ctx>,
/// We are tracking the params ourselves
params: HashMap<String, ParamValue>,
/// Number of times push was called minus number of times pop was called.
level: usize,
/// The minimum level where an assertion was added to the solver.
min_level_with_provables: Option<usize>,
}

impl<'ctx> Clone for Prover<'ctx> {
fn clone(&self) -> Self {
println!("cloning prover");
let solver = self.solver.clone();
// Solver::clone does not copy the params.
// Therefore, we track them separately and copy them here.
solver.set_params(&to_params(solver.get_context(), &self.params));

Prover {
solver,
params: self.params.clone(),
level: self.level,
min_level_with_provables: self.min_level_with_provables,
}
}
}

impl<'ctx> Prover<'ctx> {
/// Create a new prover with the given [`Context`].
pub fn new(ctx: &'ctx Context) -> Self {
let solver = Solver::new(ctx);
solver.set_params(&default_params(ctx));
Prover {
solver,
let mut prover = Prover {
solver: Solver::new(ctx),
params: HashMap::default(),
level: 0,
min_level_with_provables: None,
}
};
// default params
prover.set_param("smt.qi.eager_threshold", 100.0);
prover.set_param("smt.qi.lazy_threshold", 1000.0);
prover.set_param("auto-config", false);
prover
}

/// Set a solver timeout with millisecond precision.
///
/// Panics if the duration is not representable as a 32-bit unsigned integer.
pub fn set_timeout(&mut self, duration: Duration) {
set_solver_timeout(&self.solver, duration);
self.set_param(
"timeout",
ParamValue::U32(duration.as_millis().try_into().unwrap()),
);
}

pub fn enforce_ematching(&mut self) {
let mut params = Params::new(self.solver.get_context());
// params.set_bool("auto-config", false);
params.set_bool("smt.mbqi", false);
self.solver.set_params(&params);
self.set_param("smt.mbqi", false);
}

pub fn seed(&mut self, seed: u32) {
self.set_param("smt.random_seed", seed);
}

pub fn set_param(&mut self, k: impl Into<String>, v: impl Into<ParamValue>) {
let key = k.into();
println!("set_param {}", key);
let value = v.into();
let mut params = Params::new(self.solver.get_context());
params.set_u32("smt.random_seed", seed);
value.set_to_params(&*key, &mut params);
self.params.insert(key, value);
self.solver.set_params(&params);
}

Expand Down Expand Up @@ -155,6 +236,7 @@ impl<'ctx> Prover<'ctx> {

/// See [`Solver::pop`].
pub fn pop(&mut self) {
println!("solver pop");
self.solver.pop(1);
self.level = self.level.checked_sub(1).expect("cannot pop level 0");
if let Some(prev_min_level) = self.min_level_with_provables {
Expand Down Expand Up @@ -205,18 +287,13 @@ impl<'ctx> Prover<'ctx> {
}
}

fn default_params<'ctx>(ctx: &'ctx Context) -> Params<'ctx> {
let mut params = Params::new(ctx);
params.set_f64("smt.qi.eager_threshold", 100.0);
params.set_f64("smt.qi.lazy_threshold", 1000.0);
params
}

#[cfg(test)]
mod test {
use z3::{ast::Bool, Config, Context, SatResult};

use super::{ProveResult, Prover};
use crate::scope::{SmtFresh, SmtScope};
use crate::{Fuel, FuelFactory, SmtBranch, SmtEq};
use z3::ast::{forall_const, Ast, Int};
use z3::{ast::Bool, Config, Context, FuncDecl, Params, Pattern, SatResult, Solver, Sort};

#[test]
fn test_prover() {
Expand All @@ -234,4 +311,73 @@ mod test {
assert!(matches!(prover.check_proof(), ProveResult::Proof));
assert_eq!(prover.check_sat(), SatResult::Sat);
}

// Tests that disabling mbqi works. For that we use a limited version of the faculty function
// (fac) and try to prove that fac(n) != 0. This is not for the smt solver.
// - MBQI active -> the solver creates new terms and hangs.
// - MBQI disabled -> The solver options are quickly exhausted, and it returns Unknown.
#[test]
fn enforce_ematching() {
fn fac<'ctx>(decl: &FuncDecl<'ctx>, fuel: &Fuel<'ctx>, n: &Int<'ctx>) -> Int<'ctx> {
decl.apply(&[&fuel.as_dynamic() as &dyn Ast, n as &dyn Ast])
.as_int()
.unwrap()
.clone()
}

let mut config = Config::new();
// config.set_bool_param_value("proof", true);
// config.set_bool_param_value("trace", true);
let ctx = Context::new(&config);

let mut solver = Solver::new(&ctx);

let mut scope = SmtScope::new();
let fuel_factory = FuelFactory::new(&ctx);

let int = Sort::int(&ctx);
let zero = Int::from_i64(&ctx, 0);
let one = Int::from_i64(&ctx, 1);
let two = Int::from_i64(&ctx, 2);
let fuel1 = Fuel::new(fuel_factory.clone(), 1);
let fac_decl = FuncDecl::new(&ctx, "fac", &[fuel_factory.sort(), &int], &int);

let fuel = Fuel::fresh(&fuel_factory, &mut scope, "fuel");
let n = Int::fresh(&&ctx, &mut scope, "n");
let app = fac(&fac_decl, &Fuel::succ(fuel.clone()), &n);

// forall fuel: Fuel, n: Int @trigger(fac(S(fuel), n)). fac(S(fuel), n) == ite(n < 2, 1, n * fac(fuel, n-1))
solver.assert(&forall_const(
&ctx,
&[&fuel.as_dynamic() as &dyn Ast, &n as &dyn Ast],
&[&Pattern::new(&ctx, &[&app as &dyn Ast])],
&app.smt_eq(&Int::branch(
&n.lt(&two),
&one,
&(&n * fac(&fac_decl, &fuel, &(&n - 1i64))),
)),
));
// forall fuel: Fuel, n: Int @trigger(fac(S(fuel), n)). fac(S(fuel), n) == fac(fuel, n)
solver.assert(&forall_const(
&ctx,
&[&fuel.as_dynamic() as &dyn Ast, &n as &dyn Ast],
&[&Pattern::new(&ctx, &[&app as &dyn Ast])],
&app.smt_eq(&fac(&fac_decl, &fuel, &n)),
));
let n2 = Int::fresh(&&ctx, &mut scope, "n2");
// fac(n2) != 0
solver.assert(&fac(&fac_decl, &fuel1, &n2).smt_eq(&zero).not().not());

// disabling mbqi
let mut params = Params::new(&ctx);
params.set_bool("smt.mbqi", false);
params.set_bool("auto-config", false);
solver.set_params(&params);

// Uncommenting will make the test succeed
// solver.push();
// solver.pop(1);

assert_eq!(solver.check(), SatResult::Unknown);
}
}
12 changes: 0 additions & 12 deletions z3rro/src/util.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,10 @@ use std::{
collections::HashMap,
fmt::{Display, Formatter, Write},
str::FromStr,
time::Duration,
};

use num::{BigInt, BigRational, Integer, Signed, Zero};

use z3::{Params, Solver};

/// Build a conjunction of Boolean expressions.
macro_rules! z3_and {
($first:expr, $( $x:expr, )*) => {
Expand Down Expand Up @@ -146,15 +143,6 @@ impl Display for ReasonUnknown {
}
}

/// Set a solver timeout with millisecond precision.
///
/// Panics if the duration is not representable as a 32-bit unsigned integer.
pub fn set_solver_timeout(solver: &Solver, duration: Duration) {
let mut params = Params::new(solver.get_context());
params.set_u32("timeout", duration.as_millis().try_into().unwrap());
solver.set_params(&params);
}

/// Pretty-printing wrapper type for [`BigRational`] values. This type's
/// [`Display`] instance will format this value exactly as a decimal. If the
/// rational is not a terminating fraction, the repeating fraction will be
Expand Down

0 comments on commit 669d70c

Please sign in to comment.