(* Library mathcomp.ssreflect.ssrAC *)
Require Import BinPos BinNat.
From mathcomp Require Import ssreflect ssrbool ssrfun ssrnat eqtype seq bigop.
Set Implicit Arguments.
From mathcomp Require Import ssreflect ssrbool ssrfun ssrnat eqtype seq bigop.
Set Implicit Arguments.
(* Small Scale Rewriting using Associativity and Commutativity

   Rewriting with AC (not modulo AC), using a small scale command.
   Replaces opA, opC, opAC, opCA, ... and any combinations of them. *)
(* Usage:
     rewrite [pattern](AC patternshape reordering)
     rewrite [pattern](ACl reordering)
     rewrite [pattern](ACof reordering reordering)
     rewrite [pattern]op.[AC patternshape reordering]
     rewrite [pattern]op.[ACl reordering]
     rewrite [pattern]op.[ACof reordering reordering]
   Write op.[AC ...] with no space after the dot: a dot followed by a space
   ends the Coq sentence. *)
(* If the ACl variant is used, the patternshape defaults to n, where n is
   the number of positions in the reordering, i.e. the pattern fully
   associated to the left (x * y * ...).
   Examples of reorderings are given at the end of these notes. *)
(* - If op is specified, the rule is specialized to op; otherwise, the head
     symbol is a generic Monoid.com_law and the rewrite might be less
     efficient.
     NOTE: because of a bug in Coq's notations (coq/coq#8190), op must not
     contain any hole. For instance, *%R.[AC p s] currently does not work
     because of that; (@GRing.mul R).[AC p s] must be used instead. *)
(* - pattern is optional, as usual, but must be used to select the
     appropriate operator in case of ambiguity. Such an operator must have
     a canonical Monoid.com_law structure (additions, multiplications,
     conjunction and disjunction do). *)
(* - patternshape is expressed using the syntax
       p := n | p * p'
     where "*" is purely formal and n > 0 is the number of left-associated
     operands.
   - Examples of pattern shapes:
     + 4 represents (n * m * p * q)
     + (1*2) represents (n * (m * p)) *)
(* - reordering is expressed using the syntax
       s := n | s * s'
     where "*" is purely formal and n > 0 is the position of an operand in
     the LHS; positions start at 1! *)
(* - ACl ((1*2)*3) is the identity (and will fail with an error message)
   - opAC == op.[ACl (1*3)*2] == op.[AC 3 ((1*3)*2)]
   - opCA == op.[AC (1*2) (2*(1*3))]
   - opACA == op.[AC (2*2) ((1*3)*(2*4))]
   - rewrite opAC -opA == rewrite op.[ACl 1*(3*2)] *)
(* Arguments annotated with %AC (pattern shapes and reorderings in the
   notations of this file) are parsed in AC_scope. *)
Delimit Scope AC_scope with AC.
(* change_type x strategy: the value x, originally of type ty, seen at type
   ty'. The type equality ty = ty' is supplied by strategy, typically a
   proof term produced by a conversion tactic, and x is transported along it
   by dependent pattern matching. *)
Definition change_type ty ty' (x : ty) (strategy : ty = ty') : ty' :=
  match strategy in _ = U return U with eq_refl => x end.
(* Proof terms generated by Ltac: each notation elaborates to a proof of the
   expected equality obtained by first reducing the goal (with simpl, cbv or
   vm_compute respectively) and then calling reflexivity. *)
Notation simplrefl := (ltac: (simpl; reflexivity)).
Notation cbvrefl := (ltac: (cbv; reflexivity)).
Notation vmrefl := (ltac: (vm_compute; reflexivity)).
Module AC.

(* NOTE(review): in this copy the Lemma statements below are not followed
   by proof scripts, and the symbols ⇒, → and ∀ appear where Coq source
   would use =>, -> and forall. Both look like artifacts of a documentation
   renderer; confirm against the original .v file before compiling. *)

(* Canonical eqType structure on positive, obtained by reflecting the
   boolean test Pos.eqb through Pos.eqb_eq. *)
Canonical positive_eqType := EqType positive
(EqMixin (fun _ _ ⇒ equivP idP (Pos.eqb_eq _ _))).
(* Should be replaced by (EqMixin Pos.eqb_spec) for Coq >= 8.7. *)

(* Binary trees with positive leaves. In a pattern or a reordering, Leaf n
   denotes the n-th operand (it is looked up at index n by eval) and Op
   joins two subterms with the operator. *)
Inductive syntax := Leaf of positive | Op of syntax & syntax.

(* serial s: the leaves of s from left to right. The loop threads an
   accumulator and visits the right subtree before the left one, so each
   leaf is consed in front of the leaves that follow it. Declared as a
   coercion, so a syntax tree can be used where a seq positive is
   expected. *)
Coercion serial := (fix loop (acc : seq positive) (s : syntax) :=
match s with
| Leaf n ⇒ n :: acc
| Op s s' ⇒ (loop^~ s (loop^~ s' acc))
end) [::].

(* Serialization of an Op node is the concatenation of the serializations
   of its children. *)
Lemma serial_Op s1 s2 : Op s1 s2 = s1 ++ s2 :> seq _.

(* Leaf built from a natural-number literal.
   NOTE(review): assumes pos_of_nat n n (from ssrnat) encodes n + 1 as a
   positive, so that subtracting 1 yields n for n > 0; positive subtraction
   truncates at 1, so 0 would presumably also give Leaf 1. Confirm. *)
Definition Leaf_of_nat n := Leaf ((pos_of_nat n n) - 1)%positive.

(* Surface syntax for syntax trees in AC_scope: literals become leaves
   through the Leaf and Leaf_of_nat coercions (1 is read as 1%positive),
   and x * y builds Op x y. *)
Module Import Syntax.
Coercion Leaf : positive >-> syntax.
Coercion Leaf_of_nat : nat >-> syntax.
Notation "1" := 1%positive : AC_scope.
Notation "x * y" := (Op x%AC y%AC) : AC_scope.
End Syntax.

(* pattern s: expands a pattern shape into the left-hand-side tree,
   numbering leaves consecutively from 1. A leaf m stands for m operands
   combined left-associatively (Pos.iter adds m - 1 operands to the first
   one), and Op s s' numbers s first, then s'. The inner loop returns the
   tree together with the next free position. Leaf 1 is handled apart
   because, with truncated positive subtraction, (m - 1)%positive is 1, not
   0, when m = 1. *)
Definition pattern (s : syntax) := ((fix loop n s :=
match s with
| Leaf 1%positive ⇒ (Leaf n, Pos.succ n)
| Leaf m ⇒ Pos.iter (fun oi ⇒ (Op oi.1 (Leaf oi.2), Pos.succ oi.2))
(Leaf n, Pos.succ n) (m - 1)%positive
| Op s s' ⇒ let: (p, n') := loop n s in
let: (p', n'') := loop n' s' in
(Op p p', n'')
end) 1%positive s).1.

(* Environments and evaluation of syntax trees over a carrier T, with a
   default value idx and a binary operator op. *)
Section eval.
Variables (T : Type) (idx : T) (op : T → T → T).

(* Environments: binary trees of values addressed by positive indices.
   Index 1 is the root; an index p~0 continues with p in the left subtree
   and p~1 continues with p in the right subtree. *)
Inductive env := Empty | ENode of T & env & env.

(* pos e p: the value stored at index p in e, or idx when the path runs
   into Empty. *)
Definition pos := fix loop (e : env) p {struct e} :=
match e, p with
| ENode t _ _, 1%positive ⇒ t
| ENode t e _, (p~0)%positive ⇒ loop e p
| ENode t _ e, (p~1)%positive ⇒ loop e p
| _, _ ⇒ idx
end.

(* set_pos f e p: e with f applied to the value at index p. Missing nodes
   on the path are created and filled with idx (the target node receives
   f idx). *)
Definition set_pos (f : T → T) := fix loop e p {struct p} :=
match e, p with
| ENode t e e', 1%positive ⇒ ENode (f t) e e'
| ENode t e e', (p~0)%positive ⇒ ENode t (loop e p) e'
| ENode t e e', (p~1)%positive ⇒ ENode t e (loop e' p)
| Empty, 1%positive ⇒ ENode (f idx) Empty Empty
| Empty, (p~0)%positive ⇒ ENode idx (loop Empty p) Empty
| Empty, (p~1)%positive ⇒ ENode idx Empty (loop Empty p)
end.

(* Reading an index of an environment updated with set_pos. *)
Lemma pos_set_pos (f : T → T) e (p p' : positive) :
pos (set_pos f e p) p' = if p == p' then f (pos e p) else pos e p'.

(* unzip z e: plugs e back into the zipper z. Each entry of z pairs a
   parent value with its other child: (x, inl l) rebuilds ENode x l e and
   (x, inr r) rebuilds ENode x e r. *)
Fixpoint unzip z (e : env) : env := match z with
| [::] ⇒ e
| (x, inl e') :: z' ⇒ unzip z' (ENode x e' e)
| (x, inr e') :: z' ⇒ unzip z' (ENode x e e')
end.

(* set_pos_trec f z e p: variant of set_pos whose recursive calls are all
   in tail position. It pushes the visited path onto the zipper z and
   rebuilds the tree with unzip once index p is reached. *)
Definition set_pos_trec (f : T → T) := fix loop z e p {struct p} :=
match e, p with
| ENode t e e', 1%positive ⇒ unzip z (ENode (f t) e e')
| ENode t e e', (p~0)%positive ⇒ loop ((t, inr e') :: z) e p
| ENode t e e', (p~1)%positive ⇒ loop ((t, inl e) :: z) e' p
| Empty, 1%positive ⇒ unzip z (ENode (f idx) Empty Empty)
| Empty, (p~0)%positive ⇒ loop ((idx, (inr Empty)) :: z) Empty p
| Empty, (p~1)%positive ⇒ loop ((idx, (inl Empty)) :: z) Empty p
end.

(* set_pos_trec agrees with set_pos followed by unzip. *)
Lemma set_pos_trecE f z e p : set_pos_trec f z e p = unzip z (set_pos f e p).

(* eval e s: value of s in e; leaves are looked up with pos and Op nodes
   are combined with op. *)
Definition eval (e : env) := fix loop (s : syntax) :=
match s with
| Leaf n ⇒ pos e n
| Op s s' ⇒ op (loop s) (loop s')
end.
End eval.

(* content s: an env of N counters recording, for each leaf index, how many
   times it occurs in s (see count_memE). It is built from Empty by
   incrementing the counter of every leaf with set_pos_trec. *)
Definition content := (fix loop (acc : env N) s :=
match s with
| Leaf n ⇒ set_pos_trec 0%num N.succ [::] acc n
| Op s s' ⇒ loop (loop acc s') s
end) Empty.

(* The number of occurrences of x among the leaves of t can be read from
   content t. *)
Lemma count_memE x (t : syntax) : count_mem x t = pos 0%num (content t) x.

(* cforall e R: quantifies over one value x : T per node of e (the N values
   stored in e are ignored) and applies R to the env of T with the same
   shape, built from those values. *)
Definition cforall N T : env N → (env T → Type) → Type := env_rect (@^~ Empty)
(fun _ e IHe e' IHe' R ⇒ ∀ x, IHe (fun xe ⇒ IHe' (R \o ENode x xe))).

(* A property of all environments can be stated in cforall form. *)
Lemma cforallP N T R : (∀ e : env T, R e) → ∀ (e : env N), cforall e R.

(* Soundness for a commutative monoid law: trees with the same content
   evaluate to the same value in every environment (proof). direct restates
   this in cforall form, quantifying over the shape of content p. *)
Section eq_eval.
Variables (T : Type) (idx : T) (op : Monoid.com_law idx).
Lemma proof (p s : syntax) : content p = content s →
∀ env, eval idx op env p = eval idx op env s.
Definition direct p s ps := cforallP (@proof p s ps) (content p).
End eq_eval.

(* Only the Syntax coercions and notations are exported out of AC. *)
Module Exports.
Export AC.Syntax.
End Exports.
End AC.
(* Bring the AC syntax (its coercions and its AC_scope notations) into
   scope for users of this file. *)
Export AC.Exports.
(* Ltac-generated proof of the side condition AC.content pat = AC.content ord
   required by AC.direct, where pat is the left-hand-side tree and ord the
   reordering:
   - pat is fully computed first; if it unifies with ord, the requested
     rewrite would be the identity, so the tactic fails and says so;
   - otherwise the equality of contents (same leaves, each occurring the
     same number of times) is checked by vm_compute; reflexivity, and a
     failure is reported as a mismatch between shape and reordering;
   - any other goal is reported as having no pattern to check.
   The failure levels (1, 2, 3) and the messages are those of the original
   script. *)
Notation AC_check_pattern :=
  (ltac: (match goal with
    |- AC.content ?pat = AC.content ?ord =>
      let pat' := eval compute in pat in
      tryif unify pat' ord then
        fail 1 "AC: equality between" pat
             "and" ord "is trivial, cannot progress"
      else tryif vm_compute; reflexivity then idtac
      else fail 2 "AC: mismatch between shape" pat "=" pat' "and reordering" ord
    | |- ?G => fail 3 "AC: no pattern to check" G
    end)).
(* opACof law p s: AC rewrite rule specialized to the operator law.
   The anonymous function rebuilds a Monoid.com_law from the operator and
   the monoid lemmas of law (Monoid.mulmA, Monoid.mul1m, Monoid.mulm1,
   Monoid.mulmC). It then applies AC.direct to the tree p and the reordering
   s, with the side condition discharged by AC_check_pattern. Finally it
   casts the result with change_type, the type equality being proved by
   cbvrefl. *)
Notation opACof law p s :=
((fun T idx op assoc lid rid comm ⇒ (change_type (@AC.direct T idx
(@Monoid.ComLaw _ _ (@Monoid.Law _ idx op assoc lid rid) comm)
p%AC s%AC AC_check_pattern) cbvrefl)) _ _ law
(Monoid.mulmA _) (Monoid.mul1m _) (Monoid.mulm1 _) (Monoid.mulmC _)).
(* opAC op p s: opACof with the left-hand-side tree computed from the
   pattern shape p by AC.pattern. *)
Notation opAC op p s := (opACof op (AC.pattern p%AC) s%AC).
(* opACl op s: opAC with the fully left-associated shape whose size is the
   number of leaves of s. *)
Notation opACl op s := (opAC op (AC.Leaf_of_nat (size (AC.serial s%AC))) s%AC).
(* Postfix forms op.[ACof p s], op.[AC p s] and op.[ACl s] for opACof,
   opAC and opACl respectively. *)
Notation "op .[ 'ACof' p s ]" := (opACof op p s)
(at level 2, p at level 1, left associativity).
Notation "op .[ 'AC' p s ]" := (opAC op p s)
(at level 2, p at level 1, left associativity).
Notation "op .[ 'ACl' s ]" := (opACl op s)
(at level 2, left associativity).
(* Conversion used by the generic rules: cbv everything except the
   Monoid.com_operator and Monoid.operator projections, then reflexivity.
   NOTE(review): presumably this keeps the operator folded so the rule
   matches the goal syntactically; confirm. *)
Notation AC_strategy :=
(ltac: (cbv -[Monoid.com_operator Monoid.operator]; reflexivity)).
(* Generic rules: the type, unit and law passed to AC.direct are left as
   holes to be inferred when rewriting.
   - ACof takes the left-hand-side tree p directly.
   - AC expands the pattern shape p with AC.pattern.
   - ACl uses the fully left-associated shape whose size is the number of
     leaves of s. *)
Notation ACof p s := (change_type
(@AC.direct _ _ _ p%AC s%AC AC_check_pattern) AC_strategy).
Notation AC p s := (ACof (AC.pattern p%AC) s%AC).
Notation ACl s := (AC (AC.Leaf_of_nat (size (AC.serial s%AC))) s%AC).