Skip to content

Setoid rewrite #742

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions src/ecCommands.ml
Original file line number Diff line number Diff line change
Expand Up @@ -682,6 +682,14 @@ and process_addrw scope (local, base, names) =
and process_reduction scope name =
EcScope.Reduction.add_reduction scope name

(* -------------------------------------------------------------------- *)
and process_relation (scope : EcScope.scope) (rel : prelation) =
EcScope.Setoid.add_relation scope rel

(* -------------------------------------------------------------------- *)
and process_morphism (scope : EcScope.scope) (mph : pmorphism) =
EcScope.Setoid.add_morphism scope mph

(* -------------------------------------------------------------------- *)
and process_hint scope hint =
EcScope.Auto.add_hint scope hint
Expand Down Expand Up @@ -783,6 +791,8 @@ and process (ld : Loader.loader) (scope : EcScope.scope) g =
| Goption opt -> `Fct (fun scope -> process_option scope opt)
| Gaddrw hint -> `Fct (fun scope -> process_addrw scope hint)
| Greduction red -> `Fct (fun scope -> process_reduction scope red)
| Grelation rel -> `Fct (fun scope -> process_relation scope rel)
| Gmorphism mph -> `Fct (fun scope -> process_morphism scope mph)
| Ghint hint -> `Fct (fun scope -> process_hint scope hint)
| GdumpWhy3 file -> `Fct (fun scope -> process_dump_why3 scope file)
with
Expand Down
12 changes: 6 additions & 6 deletions src/ecCoreGoal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -508,7 +508,7 @@ module FApi = struct

(* ------------------------------------------------------------------ *)
let xmutate (tc : tcenv) (vx : 'a) (fp : form list) =
let (tc, hds) = List.map_fold (fun tc fp -> newgoal tc fp) tc fp in
let (tc, hds) = List.fold_left_map (fun tc fp -> newgoal tc fp) tc fp in
close tc (VExtern (vx, hds))

(* ------------------------------------------------------------------ *)
Expand All @@ -518,7 +518,7 @@ module FApi = struct
(* ------------------------------------------------------------------ *)
let xmutate_hyps (tc : tcenv) (vx : 'a) subgoals =
let (tc, hds) =
List.map_fold
List.fold_left_map
(fun tc (hyps, fp) -> newgoal tc ~hyps fp)
tc subgoals
in
Expand Down Expand Up @@ -564,11 +564,11 @@ module FApi = struct

(* ------------------------------------------------------------------ *)
let on_sub1i_goals (tt : int -> backward) (hds : handle list) (pe : proofenv) =
let do1 i pe hd =
let do1 pe i hd =
let tc = tt i (tcenv1_of_penv hd pe) in
assert (tc.tce_tcenv.tce_ctxt = []);
(tc_penv tc, tc_opened tc) in
List.mapi_fold do1 pe hds
List.fold_left_mapi do1 pe hds

(* ------------------------------------------------------------------ *)
let on_sub1_goals (tt : backward) (hds : handle list) (pe : proofenv) =
Expand All @@ -578,11 +578,11 @@ module FApi = struct
let on_sub1i_map_goals
(tt : int -> tcenv1 -> 'a * tcenv) (hds : handle list) (pe : proofenv)
=
let do1 i pe hd =
let do1 pe i hd =
let data, tc = tt i (tcenv1_of_penv hd pe) in
assert (tc.tce_tcenv.tce_ctxt = []);
(tc_penv tc, (tc_opened tc, data)) in
List.mapi_fold do1 pe hds
List.fold_left_mapi do1 pe hds

(* ------------------------------------------------------------------ *)
let on_sub1_map_goals
Expand Down
2 changes: 1 addition & 1 deletion src/ecCoreSubst.ml
Original file line number Diff line number Diff line change
Expand Up @@ -600,7 +600,7 @@ module Fsubst = struct

(* ------------------------------------------------------------------ *)
and add_bindings (s : f_subst) : bindings -> f_subst * bindings =
List.map_fold add_binding s
List.fold_left_map add_binding s

(* ------------------------------------------------------------------ *)
and add_mod_params (s : f_subst) (params : _) =
Expand Down
75 changes: 67 additions & 8 deletions src/ecEnv.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ module Mp = EcPath.Mp
module Sid = EcIdent.Sid
module Mid = EcIdent.Mid
module TC = EcTypeClass
module Sint = EcMaps.Sint
module Mint = EcMaps.Mint

(* -------------------------------------------------------------------- *)
Expand Down Expand Up @@ -183,6 +184,7 @@ type preenv = {
env_rwbase : Sp.t Mip.t;
env_atbase : atbase Msym.t;
env_redbase : mredinfo;
env_stdbase : setoid;
env_ntbase : ntbase Mop.t;
env_modlcs : Sid.t; (* declared modules *)
env_item : theory_item list; (* in reverse order *)
Expand Down Expand Up @@ -225,6 +227,13 @@ and atbase0 = path * [`Rigid | `Default]

and atbase = atbase0 list Mint.t

and setoid = setoid1 Mp.t

and setoid1 = {
spec : path;
morphisms : (path Mint.t) Mp.t;
}

(* -------------------------------------------------------------------- *)
type env = preenv

Expand Down Expand Up @@ -311,6 +320,7 @@ let empty gstate =
env_rwbase = Mip.empty;
env_atbase = Msym.empty;
env_redbase = Mrd.empty;
env_stdbase = Mp.empty;
env_ntbase = Mop.empty;
env_modlcs = Sid.empty;
env_item = [];
Expand Down Expand Up @@ -611,7 +621,7 @@ module MC = struct
let mc = lookup_mc qn env in
let objs = odfl [] (mc |> omap (fun mc -> MMsym.all x (proj mc))) in
let _, objs =
List.map_fold
List.fold_left_map
(fun ps ((p, _) as obj)->
if Sip.mem p ps
then (ps, None)
Expand Down Expand Up @@ -1019,7 +1029,7 @@ module MC = struct
in

let (mc, submcs) =
List.map_fold mc1_of_module
List.fold_left_map mc1_of_module
(empty_mc
(if p2 = None then Some me.me_params else None))
me.me_comps
Expand Down Expand Up @@ -1110,12 +1120,13 @@ module MC = struct
(add2mc _up_rwbase x (expath x) mc, None)

| Th_export _ | Th_addrw _ | Th_instance _
| Th_auto _ | Th_reduction _ ->
| Th_auto _ | Th_reduction _ | Th_relation _
| Th_morphism _ ->
(mc, None)
in

let (mc, submcs) =
List.map_fold mc1_of_theory (empty_mc None) cth.cth_items
List.fold_left_map mc1_of_theory (empty_mc None) cth.cth_items
in
((x, mc), List.rev_pmap identity submcs)

Expand Down Expand Up @@ -1582,6 +1593,35 @@ module Auto = struct
Msym.values env.env_atbase |> List.map flatten_db |> List.flatten
end

(* -------------------------------------------------------------------- *)
module Setoid = struct
type nonrec setoid1 = setoid1

let update_relation_db ((oppath, axpath) : path * path) (db : setoid) =
Mp.add oppath { spec = axpath; morphisms = Mp.empty; } db

let add_relation ((oppath, axpath) : path * path) (env : env) =
let item = mkitem import0 (Th_relation (oppath, axpath)) in
{ env with
env_stdbase = update_relation_db (oppath, axpath) env.env_stdbase;
env_item = item :: env.env_item; }

let get_relation (env : env) (oppath : path) : setoid1 option =
Mp.find_opt oppath env.env_stdbase

let update_morphism_db ((rel, op, ax, pos) : path * path * path * int) (db : setoid) =
Mp.change (fun db1 ->
Some { (oget db1) with morphisms =
Mp.change (fun m -> Some (Mint.add pos ax (odfl Mint.empty m))) op (oget db1).morphisms }
) rel db

let add_morphism ((rel, op, ax, pos) : path * path * path * int) (env : env) =
let item = mkitem import0 (Th_morphism (rel, op, ax, pos)) in
{ env with
env_stdbase = update_morphism_db (rel, op, ax, pos) env.env_stdbase;
env_item = item :: env.env_item; }
end

(* -------------------------------------------------------------------- *)
module Fun = struct
type t = EcModules.function_
Expand Down Expand Up @@ -2975,6 +3015,17 @@ module Theory = struct

in bind_base_th for1

(* ------------------------------------------------------------------ *)
let bind_std_th =
let for1 _path db = function
| Th_relation r ->
Some (Setoid.update_relation_db r db)
| Th_morphism m ->
Some (Setoid.update_morphism_db m db)
| _ -> None

in bind_base_th for1

(* ------------------------------------------------------------------ *)
let bind_nt_th =
let for1 path base = function
Expand Down Expand Up @@ -3022,12 +3073,14 @@ module Theory = struct
let env_tc = bind_tc_th thname env.env_tc items in
let env_rwbase = bind_br_th thname env.env_rwbase items in
let env_atbase = bind_at_th thname env.env_atbase items in
let env_stdbase = bind_std_th thname env.env_stdbase items in
let env_ntbase = bind_nt_th thname env.env_ntbase items in
let env_redbase = bind_rd_th thname env.env_redbase items in
let env =
{ env with
env_tci ; env_tc ; env_rwbase;
env_atbase; env_ntbase; env_redbase; }
env_atbase; env_stdbase; env_ntbase;
env_redbase; }
in
add_restr_th thname env items

Expand Down Expand Up @@ -3088,7 +3141,12 @@ module Theory = struct
| Th_baserw (x, _) ->
MC.import_rwbase (xpath x) env

| Th_addrw _ | Th_instance _ | Th_auto _ | Th_reduction _ ->
| Th_addrw _
| Th_instance _
| Th_auto _
| Th_reduction _
| Th_relation _
| Th_morphism _ ->
env

in
Expand All @@ -3105,7 +3163,7 @@ module Theory = struct
(* ------------------------------------------------------------------ *)
let rec filter clears root cleared items =
snd_map (List.pmap identity)
(List.map_fold (filter1 clears root) cleared items)
(List.fold_left_map (filter1 clears root) cleared items)

and filter_th clears root cleared items =
let mempty = List.exists (EcPath.p_equal root) clears in
Expand Down Expand Up @@ -3241,6 +3299,7 @@ module Theory = struct
env_tc = bind_tc_th thpath env.env_tc cth.cth_items;
env_rwbase = bind_br_th thpath env.env_rwbase cth.cth_items;
env_atbase = bind_at_th thpath env.env_atbase cth.cth_items;
env_stdbase = bind_std_th thpath env.env_stdbase cth.cth_items;
env_ntbase = bind_nt_th thpath env.env_ntbase cth.cth_items;
env_redbase = bind_rd_th thpath env.env_redbase cth.cth_items;
env_thenvs = Mp.set_union env.env_thenvs compiled.compiled; }
Expand Down Expand Up @@ -3444,7 +3503,7 @@ module LDecl = struct
let do1 hyps s =
let id = fresh_id hyps s in
(add_local id (LD_var (tbool, None)) hyps, id)
in List.map_fold do1 hyps names
in List.fold_left_map do1 hyps names

(* ------------------------------------------------------------------ *)
type hyps = {
Expand Down
15 changes: 15 additions & 0 deletions src/ecEnv.mli
Original file line number Diff line number Diff line change
Expand Up @@ -413,6 +413,21 @@ module Reduction : sig
val get : topsym -> env -> rule list
end

(* -------------------------------------------------------------------- *)
type setoid1 = {
spec : path;
morphisms : (path EcMaps.Mint.t) Mp.t;
}

module Setoid : sig
type nonrec setoid1 = setoid1

val add_relation : path * path -> env -> env
val get_relation : env -> path -> setoid1 option

val add_morphism : path * path * path * int -> env -> env
end

(* -------------------------------------------------------------------- *)
module Auto : sig
type base0 = path * [`Rigid | `Default]
Expand Down
50 changes: 39 additions & 11 deletions src/ecHiGoal.ml
Original file line number Diff line number Diff line change
Expand Up @@ -249,10 +249,11 @@ module LowRewrite = struct
| LRW_IdRewriting
| LRW_RPatternNoMatch
| LRW_RPatternNoRuleMatch
| LRW_InvalidSetoidContext

exception RewriteError of error

let rec find_rewrite_patterns ~inpred (dir : rwside) pt =
let rec find_rewrite_patterns ~inpred (dir : rwside) pt : (pt_ev * rwmode * (form * form)) list =
let hyps = pt.PT.ptev_env.PT.pte_hy in
let env = LDecl.toenv hyps in
let pt = { pt with ptev_ax = snd (PT.concretize pt) } in
Expand All @@ -270,7 +271,13 @@ module LowRewrite = struct
let pt' = apply_pterm_to_arg_r pt' (PVASub pt) in
[(pt', `Eq, (f, f_false))]

| _ -> []
| _ -> begin
try
EcSetoid.as_instance env (destr_op_app ax)
|> Option.map (fun (instance, (f1, f2)) -> (pt, `Setoid instance, (f1, f2)))
|> Option.to_list
with DestrError _ -> []
end

and split ax =
match EcFol.sform_of_form ax with
Expand Down Expand Up @@ -329,7 +336,7 @@ module LowRewrite = struct
type rwinfos = rwside * EcFol.form option * EcMatching.occ option

let t_rewrite_r ?(mode = `Full) ?target ((s, prw, o) : rwinfos) pt tc =
let hyps, tgfp = FApi.tc1_flat ?target tc in
let env, hyps, tgfp = FApi.tc1_eflat ?target tc in

let modes =
match mode with
Expand Down Expand Up @@ -366,7 +373,7 @@ module LowRewrite = struct
| PT.FindOccFailure `MatchFailure ->
raise (RewriteError LRW_RPatternNoRuleMatch)
| PT.FindOccFailure `IncompleteMatch ->
raise (RewriteError LRW_CannotInfer)
raise (RewriteError LRW_CannotInfer)
end in

if not occmode.k_keyed then begin
Expand All @@ -376,23 +383,42 @@ module LowRewrite = struct
end;

let pt = fst (PT.concretize pt) in

let exception InvalidSetoidContext in

let cpos =
let postcheck (instance : EcSetoid.instance) (lazy ctxt) =
if not (EcSetoid.valid_setoid_context env instance ctxt) then
raise InvalidSetoidContext in

let postcheck =
match mode with
| `Setoid instance -> postcheck instance
| _ -> fun _ -> () in

try FPosition.select_form
~postcheck:(fun _ ctxt _ -> postcheck ctxt; true)
~xconv:`AlphaEq ~keyed:occmode.k_keyed
hyps o subf tgfp
with InvalidOccurence -> raise (RewriteError (LRW_InvalidOccurence))
with
| InvalidOccurence ->
raise (RewriteError (LRW_InvalidOccurence))
| InvalidSetoidContext ->
raise (RewriteError (LRW_InvalidSetoidContext))
in

EcLowGoal.t_rewrite
~keyed:occmode.k_keyed ?target ~mode pt (s, Some cpos) tc in

let rec do_first = function
| [] -> raise (RewriteError LRW_NothingToRewrite)
| [] ->
raise (RewriteError LRW_NothingToRewrite)

| [pt, mode, (f1, f2)] ->
for1 (pt, mode, (f1, f2))

| (pt, mode, (f1, f2)) :: pts ->
try for1 (pt, mode, (f1, f2))
with RewriteError _ ->
do_first pts
| pt :: pts ->
try do_first [pt] with RewriteError _ -> do_first pts
in

let pts = find_rewrite_patterns s pt in
Expand Down Expand Up @@ -590,6 +616,8 @@ let process_rewrite1_core ?mode ?(close = true) ?target (s, p, o) pt tc =
tc_error !!tc "r-pattern does not match the goal"
| LowRewrite.LRW_RPatternNoRuleMatch ->
tc_error !!tc "r-pattern does not match the rewriting rule"
| LowRewrite.LRW_InvalidSetoidContext ->
tc_error !!tc "invalid setoid-rewrite position"

(* -------------------------------------------------------------------- *)
let process_delta ~und_delta ?target (s, o, p) tc =
Expand Down Expand Up @@ -2118,7 +2146,7 @@ let process_exists args (tc : tcenv1) =
PT.check_pterm_arg pte (x, xty) f arg.ptea_arg
in

let _concl, args = List.map_fold for1 (FApi.tc1_goal tc) args in
let _concl, args = List.fold_left_map for1 (FApi.tc1_goal tc) args in

if not (PT.can_concretize pte) then
tc_error !!tc "cannot infer all placeholders";
Expand Down
Loading