Library mathcomp.ssreflect.ssrAC

From HB Require Import structures.
Require Import BinPos BinNat.
(* TODO: use the #[warning="-hiding-delimiting-key"] attribute once we require Coq 8.18 *)
From mathcomp Require Import ssreflect ssrbool ssrfun ssrnat eqtype seq bigop.
Set Implicit Arguments.

Small Scale Rewriting using Associativity and Commutativity

Rewriting with AC (not modulo AC), using a small scale command. Replaces opA, opC, opAC, opCA, ... and any combination of them.

Usage:
  rewrite [pattern](AC patternshape reordering)
  rewrite [pattern](ACl reordering)
  rewrite [pattern](ACof reordering reordering)
  rewrite [pattern]op.[AC patternshape reordering]
  rewrite [pattern]op.[ACl reordering]
  rewrite [pattern]op.[ACof reordering reordering]
  • if op is specified, the rule is specialized to op; otherwise, the head symbol is a generic comm_law and the rewrite might be less efficient. NOTE: because of a bug in Coq's notations (coq/coq#8190), op must not contain any hole: *%R.[AC p s] currently does not work because of that, and (@GRing.mul R).[AC p s] must be used instead.
  • pattern is optional, as usual, but must be used to select the appropriate operator in case of ambiguity; such an operator must have a canonical Monoid.com_law structure (additions, multiplications, conjunction and disjunction do).
  • patternshape is expressed using the syntax p := n | p * p', where "*" is purely formal and n > 0 is the number of left-associated symbols. Examples of pattern shapes:
      + 4 represents (n * m * p * q)
      + (1*2) represents (n * (m * p))
  • reordering is expressed using the syntax s := n | s * s', where "*" is purely formal and n > 0 is the position in the LHS. Positions start at 1!
If the ACl variant is used, the patternshape defaults to the pattern fully associated to the left, i.e. n, i.e. (x * y * ...).

Examples of reorderings:
  • ACl ((1*2)*3) is the identity (and will fail with error message)
  • opAC == op.[ACl (1*3)*2] == op.[AC 3 ((1*3)*2)]
  • opCA == op.[AC (1*2) (2*(1*3))]
  • opACA == op.[AC (2*2) ((1*3)*(2*4))]
  • rewrite opAC -opA == rewrite op.[ACl 1*(3*2)]
...

(* Notation scope in which the formal shape/reordering syntax of this file
   is parsed. *)
Declare Scope AC_scope.

(* Make the scope selectable locally with the %AC delimiter. *)
Delimit Scope AC_scope with AC.

(* [change_type x strategy] re-types [x : ty] as a term of type [ty'] by
   transporting it along the equality proof [strategy : ty = ty'].  The
   computational content of [x] is untouched: only its type changes. *)
Definition change_type ty ty' (x : ty) (strategy : ty = ty') : ty' :=
  match strategy in _ = target return target with
  | erefl => x
  end.
(* Parsing-only shorthands for terms built by tactics: each one solves the
   expected goal by first normalizing it (with [simpl], [cbv] or
   [vm_compute] respectively) and then calling [reflexivity]. *)
Notation simplrefl := (ltac: (simpl; reflexivity)) (only parsing).
Notation cbvrefl := (ltac: (cbv; reflexivity)) (only parsing).
Notation vmrefl := (ltac: (vm_compute; reflexivity)) (only parsing).

Module AC.


(* Formal binary trees: a [Leaf] carries a [positive] label and an [Op] node
   has exactly two subtrees. *)
Inductive syntax := Leaf of positive | Op of syntax & syntax.
(* [serial s] lists the leaf labels of [s] from left to right.  The local
   fixpoint prepends labels onto the accumulator [acc]; the right subtree is
   traversed first so that the labels of the left subtree end up in front.
   Declared as a coercion, so a [syntax] can be used where a [seq positive]
   is expected. *)
Coercion serial := (fix loop (acc : seq positive) (s : syntax) :=
   match s with
   | Leaf n => n :: acc
   | Op s s' => (loop^~ s (loop^~ s' acc))
   end) [::].

(* Serializing an [Op] node yields the concatenation of the serializations
   of its two subtrees; both sides are read as sequences through the
   [serial] coercion.
   NOTE(review): no proof script follows the statement in this listing. *)
Lemma serial_Op s1 s2 : Op s1 s2 = s1 ++ s2 :> seq _.

(* Leaf labelled by the positive number [pos_of_nat n n - 1], the
   subtraction being the one on [positive].
   NOTE(review): presumably this is the [positive] counterpart of [n] when
   [n > 0] -- confirm against the definition of [pos_of_nat]. *)
Definition Leaf_of_nat (n : nat) : syntax :=
  Leaf (Pos.sub (pos_of_nat n n) xH).

(* Concrete syntax for [syntax] trees, available in [AC_scope]. *)
Module Import Syntax.
(* Terms expected at type [syntax] are parsed in [AC_scope]. *)
Bind Scope AC_scope with syntax.
(* Numerals in [AC_scope] are read and printed as [positive] numbers. *)
Number Notation positive Pos.of_num_int Pos.to_num_uint : AC_scope.
(* Both [positive] and [nat] values may be used where a [syntax] is
   expected; they are turned into leaves. *)
Coercion Leaf : positive >-> syntax.
Coercion Leaf_of_nat : nat >-> syntax.
(* In [AC_scope], [x * y] builds the formal node [Op x y]. *)
Notation "x * y" := (Op x%AC y%AC) : AC_scope.
End Syntax.

(* [pattern s] expands a pattern shape into a [syntax] tree whose leaves are
   numbered consecutively, left to right, starting from 1.  The local
   fixpoint threads the next free label [n] and returns the tree it built
   paired with the updated counter:
   - [Leaf 1] consumes a single label;
   - [Leaf m] starts from a single leaf and iterates [(m - 1)%positive]
     times, each step adding one fresh leaf to the right of the tree built
     so far, which yields a left-nested chain of [Op] nodes;
   - [Op s s'] expands [s], then [s'] from the counter left by [s].
   Only the tree component of the final pair is returned. *)
Definition pattern (s : syntax) := ((fix loop n s :=
  match s with
  | Leaf 1%positive => (Leaf n, Pos.succ n)
  | Leaf m => Pos.iter (fun oi => (Op oi.1 (Leaf oi.2), Pos.succ oi.2))
                       (Leaf n, Pos.succ n) (m - 1)%positive
  | Op s s' => let: (p, n') := loop n s in
               let: (p', n'') := loop n' s' in
               (Op p p', n'')
  end) 1%positive s).1.

Section eval.
(* [T] is the carrier, [idx] a default element of [T] and [op] a binary
   operation on [T]. *)
Variables (T : Type) (idx : T) (op : T -> T -> T).
(* Binary trees holding a value of type [T] at every node. *)
Inductive env := Empty | ENode of T & env & env.
(* [pos e p] reads the value stored in [e] at path [p]: [1] designates the
   current node, [p~0] continues with [p] in the first subtree and [p~1] in
   the second one.  [idx] is returned when the path leaves the tree. *)
Definition pos := fix loop (e : env) p {struct e} :=
  match e, p with
 | ENode t _ _, 1%positive => t
 | ENode t e _, (p~0)%positive => loop e p
 | ENode t _ e, (p~1)%positive => loop e p
 | _, _ => idx
end.

(* [set_pos f e p] returns [e] in which the value at path [p] has been
   replaced by its image under [f].  Nodes missing along the path are
   created on the way: new intermediate nodes hold [idx], and a freshly
   created target node holds [f idx]. *)
Definition set_pos (f : T -> T) := fix loop e p {struct p} :=
  match e, p with
 | ENode t e e', 1%positive => ENode (f t) e e'
 | ENode t e e', (p~0)%positive => ENode t (loop e p) e'
 | ENode t e e', (p~1)%positive => ENode t e (loop e' p)
 | Empty, 1%positive => ENode (f idx) Empty Empty
 | Empty, (p~0)%positive => ENode idx (loop Empty p) Empty
 | Empty, (p~1)%positive => ENode idx Empty (loop Empty p)
 end.

(* Reading position [p'] after updating position [p] with [f]: the result is
   [f] applied to the old value when the two positions coincide, and the
   old value at [p'] otherwise.
   NOTE(review): no proof script follows the statement in this listing. *)
Lemma pos_set_pos (f : T -> T) e (p p' : positive) :
  pos (set_pos f e p) p' = if p == p' then f (pos e p) else pos e p'.

(* [unzip z e] rebuilds a tree around [e] from the zipper [z], innermost
   entry first.  An entry [(x, inl e')] wraps the current tree as the second
   child of a node holding [x] whose first child is [e']; an entry
   [(x, inr e')] wraps it as the first child, with [e'] as second child. *)
Fixpoint unzip z (e : env) : env := match z with
 | [::] => e
 | (x, inl e') :: z' => unzip z' (ENode x e' e)
 | (x, inr e') :: z' => unzip z' (ENode x e e')
end.

(* Accumulator-passing variant of [set_pos].  While descending along [p],
   each node left behind is pushed on the zipper [z] as its value together
   with the sibling subtree that is not visited ([inr] when going into the
   first subtree, [inl] when going into the second).  Once the target is
   reached, [unzip] rebuilds the whole tree around the updated node. *)
Definition set_pos_trec (f : T -> T) := fix loop z e p {struct p} :=
  match e, p with
 | ENode t e e', 1%positive => unzip z (ENode (f t) e e')
 | ENode t e e', (p~0)%positive => loop ((t, inr e') :: z) e p
 | ENode t e e', (p~1)%positive => loop ((t, inl e) :: z) e' p
 | Empty, 1%positive => unzip z (ENode (f idx) Empty Empty)
 | Empty, (p~0)%positive => loop ((idx, (inr Empty)) :: z) Empty p
 | Empty, (p~1)%positive => loop ((idx, (inl Empty)) :: z) Empty p
 end.

(* Specification of [set_pos_trec] in terms of [set_pos]: running it with
   zipper [z] amounts to performing the update with [set_pos] and then
   rebuilding the surrounding context with [unzip z].
   NOTE(review): no proof script follows the statement in this listing. *)
Lemma set_pos_trecE f z e p : set_pos_trec f z e p = unzip z (set_pos f e p).

(* [eval e s] interprets the formal tree [s] in the environment [e]: a leaf
   [n] denotes the value found at position [n] of [e] (as read by [pos]) and
   an [Op] node is interpreted by [op]. *)
Definition eval (e : env) := fix loop (s : syntax) :=
match s with
  | Leaf n => pos e n
  | Op s s' => op (loop s) (loop s')
end.
End eval.
(* Make the type argument of [Empty] implicit and maximally inserted. *)
Arguments Empty {T}.

(* [content s] tallies the leaves of [s] in an environment of binary
   naturals: starting from [Empty], every occurrence of [Leaf n] applies
   [N.succ] to the counter stored at position [n], absent counters being
   read as [0].  The right subtree of an [Op] node is folded first. *)
Definition content := (fix loop (acc : env N) s :=
  match s with
  | Leaf n => set_pos_trec 0%num N.succ [::] acc n
  | Op s s' => loop (loop acc s') s
  end) Empty.

(* The number of occurrences of the label [x] among the leaves of [t] (with
   [t] read as a sequence through the [serial] coercion) is the counter
   stored at position [x] of [content t].
   NOTE(review): no proof script follows the statement in this listing. *)
Lemma count_memE x (t : syntax) : count_mem x t = pos 0%num (content t) x.

(* [cforall e R] is a curried universal quantification shaped by [e]: by
   recursion on [e], it introduces one variable [x : T] per node of [e] and
   finally states [R] on the environment over [T] assembled from these
   variables, which has the same shape as [e].  On [Empty] it is simply
   [R Empty]; the values stored in [e] itself are ignored. *)
Definition cforall N T : env N -> (env T -> Type) -> Type := env_rect (@^~ Empty)
  (fun _ e IHe e' IHe' R => forall x, IHe (fun xe => IHe' (R \o ENode x xe))).

(* A predicate that holds for every environment over [T] yields the curried
   statement [cforall e R] for any shape [e].
   NOTE(review): no proof script follows the statement in this listing. *)
Lemma cforallP N T R : (forall e : env T, R e) -> forall (e : env N), cforall e R.

Section eq_eval.
(* [op] is a commutative monoid law on [T] with unit [idx]. *)
Variables (T : Type) (idx : T) (op : Monoid.com_law idx).

(* Two formal trees with the same leaf tally (as computed by [content])
   evaluate to the same value in every environment.
   NOTE(review): no proof script follows the statement in this listing. *)
Lemma proof (p s : syntax) : content p = content s ->
  forall env, eval idx op env p = eval idx op env s.

(* [direct p s ps] restates [proof p s ps] through [cforallP], using
   [content p] as the shape: the quantification over a whole environment
   becomes one quantified variable per node of [content p]. *)
Definition direct p s ps := cforallP (@proof p s ps) (content p).

End eq_eval.

(* Names made available to clients of this file: re-exports the concrete
   syntax (scope binding, numerals, coercions and the formal [*]) declared
   in [AC.Syntax]. *)
Module Exports.
Export AC.Syntax.
End Exports.
End AC.
Export AC.Exports.

(* Tactic-in-term meant to prove a goal of the form
   [AC.content ?pat = AC.content ?ord]:
   - [pat] is first fully computed; if the result unifies with [ord], the
     requested rewrite would be the identity and the tactic fails with the
     corresponding message;
   - otherwise the goal is solved by [vm_compute; reflexivity] when the two
     leaf tallies agree;
   - otherwise the tactic fails, reporting the mismatch between the shape
     and the reordering.
   On a goal of any other form it fails, printing that goal. *)
Notation AC_check_pattern :=
  (ltac: (match goal with
    |- AC.content ?pat = AC.content ?ord =>
      let pat' := fresh "pat" in let pat' := eval compute in pat in
      tryif unify pat' ord then
           fail 1 "AC: equality between" pat
                  "and" ord "is trivial, cannot progress"
      else tryif vm_compute; reflexivity then idtac
      else fail 2 "AC: mismatch between shape" pat "=" pat' "and reordering" ord
    | |- ?G => fail 3 "AC: no pattern to check" G
  end))
  (only parsing).

(* [opACof law p s] builds the AC rewrite rule relating the formal trees [p]
   and [s] for the explicitly given operator [law].  The commutative-law
   structure passed to [AC.direct] is assembled by hand from [law] and the
   generic associativity, unit and commutativity lemmas, whose holes are
   resolved against [law]; [AC_check_pattern] discharges the side condition
   on the leaf tallies, and [change_type ... cbvrefl] re-types the result
   after full [cbv] normalization of its statement. *)
Notation opACof law p s :=
((fun T idx op assoc lid rid comm => (change_type (@AC.direct T idx
   (Monoid.ComLaw.Pack (* FIXME: find a way to make this robust to hierarchy evolutions *)
      (Monoid.ComLaw.Class
         (SemiGroup.isLaw.Axioms_ op assoc)
         (Monoid.isMonoidLaw.Axioms_ idx op lid rid)
         (SemiGroup.isCommutativeLaw.Axioms_ op comm)))
   p%AC s%AC AC_check_pattern) cbvrefl)) _ _ law
(Monoid.mulmA _) (Monoid.mul1m _) (Monoid.mulm1 _) (Monoid.mulmC _))
(only parsing).

(* [opAC op p s]: same as [opACof], except that [p] is a pattern shape which
   is first expanded into a numbered tree by [AC.pattern]. *)
Notation opAC op p s := (opACof op (AC.pattern p%AC) s%AC) (only parsing).
(* [opACl op s]: same as [opAC], with the pattern shape taken to be the
   single leaf whose number is the count of leaves of [s]. *)
Notation opACl op s := (opAC op (AC.Leaf_of_nat (size (AC.serial s%AC))) s%AC)
  (only parsing).

(* Postfix surface syntax for the operator-specific rules:
   [op.[ACof p s]], [op.[AC p s]] and [op.[ACl s]] stand for
   [opACof op p s], [opAC op p s] and [opACl op s] respectively. *)
Notation "op .[ 'ACof' p s ]" := (opACof op p%AC s%AC)
  (at level 2, p at level 1, left associativity, only parsing).
Notation "op .[ 'AC' p s ]" := (opAC op p%AC s%AC)
  (at level 2, p at level 1, left associativity, only parsing).
Notation "op .[ 'ACl' s ]" := (opACl op s%AC)
  (at level 2, left associativity, only parsing).

(* Tactic-in-term used to justify the re-typing of the generic rule: the
   goal is normalized with [cbv], keeping [Monoid.ComLaw.sort] and
   [Monoid.Law.sort] folded, and then closed by [reflexivity]. *)
Notation AC_strategy :=
  (ltac: (cbv -[Monoid.ComLaw.sort Monoid.Law.sort]; reflexivity))
  (only parsing).
(* Generic rules, where the operator is left to be inferred (the three
   leading arguments of [AC.direct] are holes):
   - [ACof p s] relates the formal trees [p] and [s]; [AC_check_pattern]
     discharges the side condition and [AC_strategy] justifies the
     re-typing performed by [change_type];
   - [AC p s] first expands the pattern shape [p] with [AC.pattern];
   - [ACl s] uses as pattern shape the single leaf whose number is the
     count of leaves of [s]. *)
Notation ACof p s := (change_type
  (@AC.direct _ _ _ p%AC s%AC AC_check_pattern) AC_strategy)
  (only parsing).
Notation AC p s := (ACof (AC.pattern p%AC) s%AC) (only parsing).
Notation ACl s := (AC (AC.Leaf_of_nat (size (AC.serial s%AC))) s%AC)
  (only parsing).