Library mathcomp.ssreflect.ssrAC

Require Import BinPos BinNat.
From mathcomp Require Import ssreflect ssrbool ssrfun ssrnat eqtype seq bigop.
Set Implicit Arguments.

Small-Scale Rewriting using Associativity and Commutativity
Rewriting with AC (not modulo AC), using a small-scale command. It replaces opA, opC, opAC, opCA, ... and any combination of them.
Usage:
  rewrite [pattern](AC patternshape reordering)
  rewrite [pattern](ACl reordering)
  rewrite [pattern](ACof reordering reordering)
  rewrite [pattern]op.[AC patternshape reordering]
  rewrite [pattern]op.[ACl reordering]
  rewrite [pattern]op.[ACof reordering reordering]
  • if op is specified, the rule is specialized to op; otherwise, the head symbol is a generic Monoid.com_law and the rewrite might be less efficient. NOTE: because of a bug in Coq's notations (coq/coq#8190), op must not contain any hole. For instance, *%R.[AC p s] currently does not work because of that; (@GRing.mul R).[AC p s] must be used instead.
  • pattern is optional, as usual, but must be used to select the appropriate operator in case of ambiguity. Such an operator must have a canonical Monoid.com_law structure (additions, multiplications, conjunction and disjunction do).
  • patternshape is expressed using the syntax p := n | p * p', where "*" is purely formal and n > 0 is the number of left-associated operands. Examples of pattern shapes:
      – 4 represents (n * m * p * q)
      – (1*2) represents (n * (m * p))
  • reordering is expressed using the syntax s := n | s * s', where "*" is purely formal and n > 0 is a position in the LHS. Positions start at 1!
If the ACl variant is used, the patternshape defaults to the pattern fully associated to the left, i.e. the shape n (as in (x * y * ...)), where n is the number of positions occurring in the reordering.
Examples of reorderings:
  • ACl ((1*2)*3) is the identity (and will fail with an error message)
  • opAC == op.[ACl (1*3)*2] == op.[AC 3 ((1*3)*2)]
  • opCA == op.[AC (1*2) (2*(1*3))]
  • opACA == op.[AC (2*2) ((1*3)*(2*4))]
  • rewrite opAC -opA == rewrite op.[ACl 1*(3*2)]

Delimit Scope AC_scope with AC.

(* change_type x strategy transports x : ty along the type equality
   strategy : ty = ty', yielding a term of type ty'.  The rewrite
   notations at the end of this file use it, together with a tactic proof
   of the equality, to restate the lemma they build. *)
Definition change_type ty ty' (x : ty) (strategy : ty = ty') : ty' :=
  match strategy in _ = U return U with eq_refl => x end.

(* Tactic-built proof terms that reduce the goal and then close it by
   reflexivity: with simpl, with full cbv, or with vm_compute. *)
Notation simplrefl := (ltac: (simpl; reflexivity)).
Notation cbvrefl := (ltac: (cbv; reflexivity)).
Notation vmrefl := (ltac: (vm_compute; reflexivity)).

Module AC.

(* Equality structure on binary positive numbers, obtained by reflecting
   the boolean test Pos.eqb through Pos.eqb_eq. *)
Canonical positive_eqType := EqType positive
   (EqMixin (fun _ _ => equivP idP (Pos.eqb_eq _ _))).
(* TODO: should be replaced by (EqMixin Pos.eqb_spec) for coq >= 8.7 *)

(* Syntax trees for expressions built with a single binary operator:
   Leaf n refers to the operand stored at position n of an environment
   (see eval below) and Op combines two subtrees. *)
Inductive syntax := Leaf of positive | Op of syntax & syntax.

(* serial s lists the leaf positions of s from left to right.  It is
   declared as a coercion, so a syntax tree can be used wherever a
   seq positive is expected.  The accumulator-passing loop processes the
   right subtree first, so that each position is consed in front of the
   positions that follow it. *)
Coercion serial := (fix loop (acc : seq positive) (s : syntax) :=
   match s with
   | Leaf n => n :: acc
   | Op s s' => (loop^~ s (loop^~ s' acc))
   end) [::].

(* Flattening an Op node concatenates the flattenings of its two
   subtrees.
   NOTE(review): the proof script is not present in this copy of the
   file; it must be restored for the lemma to be accepted. *)
Lemma serial_Op s1 s2 : Op s1 s2 = s1 ++ s2 :> seq _.

(* Leaf_of_nat n is the leaf whose position is n read as a positive
   number.  Presumably pos_of_nat n n encodes n + 1 (ssrnat convention,
   to be confirmed), hence the subtraction of one; Pos.sub saturates at
   1, so n = 0 also yields position 1. *)
Definition Leaf_of_nat n := Leaf (Pos.sub (pos_of_nat n n) 1%positive).

(* Coercions and notations used to write pattern shapes and reorderings;
   re-exported to users of the library through AC.Exports. *)
Module Import Syntax.
(* A bare positive or nat literal denotes a leaf. *)
Coercion Leaf : positive >-> syntax.
Coercion Leaf_of_nat : nat >-> syntax.
(* In AC_scope, 1 is the positive one and x * y builds an Op node whose
   arguments are themselves read in AC_scope. *)
Notation "1" := 1%positive : AC_scope.
Notation "x * y" := (Op x%AC y%AC) : AC_scope.
End Syntax.

(* pattern s expands a pattern shape into a syntax tree whose leaves are
   numbered 1, 2, 3, ... from left to right.  A leaf 1 becomes a single
   leaf; a leaf m > 1 becomes m leaves associated to the left; Op s s'
   expands s, then expands s' starting from the next free index.  For
   instance the shape (2*2) yields ((1*2)*(3*4)).  The inner loop returns
   the expanded tree paired with the next unused index, and only the tree
   is kept. *)
Definition pattern (s : syntax) := ((fix loop n s :=
  match s with
  | Leaf 1%positive => (Leaf n, Pos.succ n)
  | Leaf m => Pos.iter (fun oi => (Op oi.1 (Leaf oi.2), Pos.succ oi.2))
                       (Leaf n, Pos.succ n) (m - 1)%positive
  | Op s s' => let: (p, n') := loop n s in
               let: (p', n'') := loop n' s' in
               (Op p p', n'')
  end) 1%positive s).1.

Section eval.
(* Environments are binary trees of values of type T addressed by
   positive numbers; idx is returned for absent entries and op is the
   binary operator used by eval. *)
Variables (T : Type) (idx : T) (op : T -> T -> T).
Inductive env := Empty | ENode of T & env & env.

(* pos e p reads the value stored at address p in e.  Address 1 is the
   current node; otherwise the low-order bit of p selects the subtree
   (p~0: left, p~1: right) and the remaining bits address a node inside
   it.  Any address that falls outside the tree reads as idx. *)
Definition pos := fix loop (e : env) p {struct e} :=
  match e, p with
  | ENode t _ _, 1%positive => t
  | ENode t e _, (p~0)%positive => loop e p
  | ENode t _ e, (p~1)%positive => loop e p
  | _, _ => idx
  end.

(* set_pos f e p returns e with the value at address p replaced by f of
   its current value.  Missing nodes along the path are created and
   filled with idx, and a missing target node receives f idx. *)
Definition set_pos (f : T -> T) := fix loop e p {struct p} :=
  match e, p with
  | ENode t e e', 1%positive => ENode (f t) e e'
  | ENode t e e', (p~0)%positive => ENode t (loop e p) e'
  | ENode t e e', (p~1)%positive => ENode t e (loop e' p)
  | Empty, 1%positive => ENode (f idx) Empty Empty
  | Empty, (p~0)%positive => ENode idx (loop Empty p) Empty
  | Empty, (p~1)%positive => ENode idx Empty (loop Empty p)
  end.

(* Reading back after set_pos: address p now holds f applied to its old
   value, and every other address is unchanged.
   NOTE(review): the proof script is not present in this copy of the
   file; it must be restored for the lemma to be accepted. *)
Lemma pos_set_pos (f : T -> T) e (p p' : positive) :
  pos (set_pos f e p) p' = if p == p' then f (pos e p) else pos e p'.

(* unzip z e rebuilds a whole tree from the subtree e and the path
   context z.  z is a list of (node value, sibling) pairs, innermost
   first: inl e' means e was the right child (e' is the left sibling),
   inr e' means e was the left child (e' is the right sibling). *)
Fixpoint unzip z (e : env) : env := match z with
 | [::] => e
 | (x, inl e') :: z' => unzip z' (ENode x e' e)
 | (x, inr e') :: z' => unzip z' (ENode x e e')
 end.

(* Tail-recursive variant of set_pos: the path walked from the root is
   pushed onto the context z, and the tree is rebuilt with unzip once
   address p is reached (see set_pos_trecE).  Missing nodes are created
   with idx exactly as in set_pos. *)
Definition set_pos_trec (f : T -> T) := fix loop z e p {struct p} :=
  match e, p with
  | ENode t e e', 1%positive => unzip z (ENode (f t) e e')
  | ENode t e e', (p~0)%positive => loop ((t, inr e') :: z) e p
  | ENode t e e', (p~1)%positive => loop ((t, inl e) :: z) e' p
  | Empty, 1%positive => unzip z (ENode (f idx) Empty Empty)
  | Empty, (p~0)%positive => loop ((idx, (inr Empty)) :: z) Empty p
  | Empty, (p~1)%positive => loop ((idx, (inl Empty)) :: z) Empty p
  end.

(* Specification of set_pos_trec in terms of set_pos and unzip.
   NOTE(review): the proof script is not present in this copy of the
   file; it must be restored for the lemma to be accepted. *)
Lemma set_pos_trecE f z e p : set_pos_trec f z e p = unzip z (set_pos f e p).

(* eval e s interprets the syntax tree s: each leaf n reads address n of
   the environment e with pos, and each Op node is mapped to op. *)
Definition eval (e : env) := fix loop (s : syntax) :=
match s with
  | Leaf n => pos e n
  | Op s s' => op (loop s) (loop s')
end.
End eval.

(* content s counts leaf occurrences: address n of the resulting
   environment holds the number (of type N) of leaves Leaf n in s.
   Counts default to 0 and each occurrence is added with N.succ through
   set_pos_trec; the right subtree is processed before the left one. *)
Definition content := (fix loop (acc : env N) s :=
  match s with
  | Leaf n => set_pos_trec 0%num N.succ [::] acc n
  | Op s s' => loop (loop acc s') s
  end) Empty.

(* content agrees with counting occurrences of x in the flattened leaf
   sequence of t (t is coerced to seq positive by serial).
   NOTE(review): the proof script is not present in this copy of the
   file; it must be restored for the lemma to be accepted. *)
Lemma count_memE x (t : syntax) : count_mem x t = pos 0%num (content t) x.

(* cforall e R universally quantifies over the environments of type env T
   that have the same shape as e: one value x is quantified for each
   ENode of e, Empty stays Empty, and R is applied to the tree so built. *)
Definition cforall N T : env N -> (env T -> Type) -> Type := env_rect (@^~ Empty)
  (fun _ e IHe e' IHe' R => forall x, IHe (fun xe => IHe' (R \o ENode x xe))).

(* A property R that holds for every environment of type env T also holds
   in the shape-indexed form cforall e R, for any e.
   NOTE(review): the proof script is not present in this copy of the
   file; it must be restored for the lemma to be accepted. *)
Lemma cforallP N T R : (forall e : env T, R e) -> forall (e : env N), cforall e R.

Section eq_eval.
(* op is a commutative monoid law with unit idx. *)
Variables (T : Type) (idx : T) (op : Monoid.com_law idx).

(* Soundness of reordering: two syntax trees with the same leaf contents
   (same positions, same multiplicities) evaluate to the same value under
   op in every environment.  This relies on op being a Monoid.com_law.
   NOTE(review): the proof script is not present in this copy of the
   file; it must be restored for the lemma to be accepted. *)
Lemma proof (p s : syntax) : content p = content s ->
   forall env, eval idx op env p = eval idx op env s.

(* direct p s ps restates proof through cforallP: the equation between
   eval of p and eval of s is quantified over one value per node of the
   content environment of p.  The rewrite notations at the end of the
   file instantiate it and restate its type. *)
Definition direct (p s : syntax) (ps : content p = content s) :=
  cforallP (@proof p s ps) (content p).

End eq_eval.

(* What users of this library get: the coercions and notations of
   AC.Syntax, exported again at top level just below. *)
Module Exports.
Export AC.Syntax.
End Exports.
End AC.
Export AC.Exports.

(* AC_check_pattern proves the side condition
   AC.content pat = AC.content ord expected by AC.direct.  It fails with a
   dedicated message when the computed pattern is syntactically the
   reordering (the rewrite would make no progress), when the two
   contents differ (vm_compute; reflexivity fails), or when the goal does
   not have the expected shape.  The fail levels make these messages
   escape the enclosing tryif and match. *)
Notation AC_check_pattern :=
  (ltac: (match goal with
    |- AC.content ?pat = AC.content ?ord =>
      let pat' := eval compute in pat in
      tryif unify pat' ord then
           fail 1 "AC: equality between" pat
                  "and" ord "is trivial, cannot progress"
      else tryif vm_compute; reflexivity then idtac
      else fail 2 "AC: mismatch between shape" pat "=" pat' "and reordering" ord
    | |- ?G => fail 3 "AC: no pattern to check" G
  end)).

(* opACof law p s builds the rewrite rule for the explicit operator law.
   It instantiates AC.direct with a Monoid.com_law rebuilt from law's own
   associativity, unit and commutativity lemmas, discharges the content
   side condition with AC_check_pattern, and restates the resulting type
   using cbvrefl through change_type. *)
Notation opACof law p s :=
((fun T idx op assoc lid rid comm => (change_type (@AC.direct T idx
   (@Monoid.ComLaw _ _ (@Monoid.Law _ idx op assoc lid rid) comm)
   p%AC s%AC AC_check_pattern) cbvrefl)) _ _ law
(Monoid.mulmA _) (Monoid.mul1m _) (Monoid.mulm1 _) (Monoid.mulmC _)).

(* opAC op p s: the pattern is given as a shape p, expanded by
   AC.pattern.  opACl op s: the shape is the left-associated one whose
   number of leaves is the number of positions occurring in s. *)
Notation opAC op p s := (opACof op (AC.pattern p%AC) s%AC).
Notation opACl op s := (opAC op (AC.Leaf_of_nat (size (AC.serial s%AC))) s%AC).

(* Postfix syntax op.[ACof p s], op.[AC p s] and op.[ACl s] for the
   operator-specific rules defined above. *)
Notation "op .[ 'ACof' p s ]" := (opACof op p s)
  (at level 2, p at level 1, left associativity).
Notation "op .[ 'AC' p s ]" := (opAC op p s)
  (at level 2, p at level 1, left associativity).
Notation "op .[ 'ACl' s ]" := (opACl op s)
  (at level 2, left associativity).

(* Generic variants, where the operator and its type are left as holes
   to be inferred.  AC_strategy restates the rule type by cbv while
   keeping Monoid.com_operator and Monoid.operator folded, so the generic
   operator is not unfolded. *)
Notation AC_strategy :=
  (ltac: (cbv -[Monoid.com_operator Monoid.operator]; reflexivity)).
Notation ACof p s := (change_type
  (@AC.direct _ _ _ p%AC s%AC AC_check_pattern) AC_strategy).
Notation AC p s := (ACof (AC.pattern p%AC) s%AC).
Notation ACl s := (AC (AC.Leaf_of_nat (size (AC.serial s%AC))) s%AC).