Library mathcomp.algebra.matrix

(* (c) Copyright 2006-2016 Microsoft Corporation and Inria.                  
 Distributed under the terms of CeCILL-B.                                  *)

From mathcomp Require Import ssreflect ssrbool ssrfun eqtype ssrnat seq choice.
From mathcomp Require Import fintype finfun bigop finset fingroup perm.
From mathcomp Require Import div prime binomial ssralg finalg zmodp countalg.

Basic concrete linear algebra: definition of the type of matrices, and all basic matrix operations including determinant, trace and support for block decomposition. Matrices are represented by a row-major list of their coefficients, but this implementation is hidden by three levels of wrappers (Matrix/Finfun/Tuple), so the matrix type should be treated as abstract and handled using only the operations described below:
  • 'M[R]_(m, n), 'M_(m, n) == the type of m rows by n columns matrices with coefficients in R; the [R] is optional and is usually omitted.
  • 'M[R]_n, 'M_n == the type of n x n square matrices.
  • 'rV[R]_n, 'rV_n == the type of 1 x n row vectors.
  • 'cV[R]_n, 'cV_n == the type of n x 1 column vectors.
  • \matrix_(i < m, j < n) Expr(i, j) == the m x n matrix with general coefficient Expr(i, j), with i : 'I_m and j : 'I_n. The < m bound can be omitted if it is equal to n, though usually both bounds are omitted as they can be inferred from the context.
  • \row_(j < n) Expr(j), \col_(i < m) Expr(i) == the row / column vectors with general term Expr; the parentheses can be omitted along with the bound.
  • \matrix_(i < m) RowExpr(i) == the m x n matrix with row i given by RowExpr(i) : 'rV_n.
  • A i j == the coefficient of matrix A : 'M_(m, n) in column j of row i, where i : 'I_m and j : 'I_n (via the coercion fun_of_matrix : matrix >-> Funclass).
  • const_mx a == the constant matrix whose entries are all a (dimensions should be determined by context).
  • map_mx f A == the pointwise image of A by f, i.e., the matrix Af congruent to A with Af i j = f (A i j) for all i and j.
  • A^T == the matrix transpose of A.
  • row i A == the i'th row of A (this is a row vector).
  • col j A == the j'th column of A (a column vector).
  • row' i A == A with the i'th row spliced out.
  • col' j A == A with the j'th column spliced out.
  • xrow i1 i2 A == A with rows i1 and i2 interchanged.
  • xcol j1 j2 A == A with columns j1 and j2 interchanged.
  • row_perm s A == A : 'M_(m, n) with rows permuted by s : 'S_m.
  • col_perm s A == A : 'M_(m, n) with columns permuted by s : 'S_n.
  • row_mx Al Ar == the row block matrix <Al Ar> obtained by concatenating two matrices Al and Ar of the same height.
  • col_mx Au Ad == the column block matrix with Au stacked above Ad (Au and Ad must have the same width).
  • block_mx Aul Aur Adl Adr == the 2 x 2 block matrix with Aul, Aur in its upper row and Adl, Adr in its lower row.
  • [l|r]submx A == the left/right submatrices of a row block matrix A. Note that the type of A, 'M_(m, n1 + n2), indicates how A should be decomposed.
  • [u|d]submx A == the up/down submatrices of a column block matrix A.
  • [u|d][l|r]submx A == the upper left, etc., submatrices of a block matrix A.
  • castmx eq_mn A == A : 'M_(m, n) cast to 'M_(m', n') using the equation pair eq_mn : (m = m') * (n = n'). This is the usual workaround for the syntactic limitations of dependent types in Coq, and can be used to introduce a block decomposition. It simplifies to A when eq_mn is the pair (erefl m, erefl n) (using rewrite /castmx /=).
  • conform_mx B A == A if A and B have the same dimensions, else B.
  • mxvec A == a row vector of width m * n holding all the entries of the m x n matrix A.
  • mxvec_index i j == the index of A i j in mxvec A.
  • vec_mx v == the inverse of mxvec, reshaping a vector of width m * n back into an m x n rectangular matrix.
In 'M[R]_(m, n), R can be any type, but 'M[R]_(m, n) inherits the eqType, choiceType, countType, finType and zmodType structures of R (whenever R has them); 'M[R]_(m, n) also has a natural lmodType R structure when R has a ringType structure. Because the type of matrices specifies their dimension, only non-trivial square matrices (of type 'M[R]_n.+1) can inherit the ring structure of R; indeed they then have an algebra structure (lalgType R, or algType R if R is a comRingType, or even unitAlgType if R is a comUnitRingType). We thus provide separate syntax for the general matrix multiplication, and other operations for matrices over a ringType R:
  • A *m B == the matrix product of A and B; the width of A must be equal to the height of B.
  • a%:M == the scalar matrix with a's on the main diagonal; in particular 1%:M denotes the identity matrix, and is equal to 1%R when n is of the form n'.+1 (i.e., n >= 1).
  • is_scalar_mx A <=> A is a scalar matrix (A = a%:M for some a).
  • diag_mx d == the diagonal matrix whose main diagonal is d : 'rV_n.
  • delta_mx i j == the matrix with a 1 in row i, column j and 0 elsewhere.
  • pid_mx r == the partial identity matrix with 1s only on the r first coefficients of the main diagonal; the dimensions of pid_mx r are determined by the context, and pid_mx r can be rectangular.
  • copid_mx r == the complement to 1%:M of pid_mx r: a square diagonal matrix with 1s on all but the first r coefficients on its main diagonal.
  • perm_mx s == the n x n permutation matrix for s : 'S_n.
  • tperm_mx i1 i2 == the permutation matrix that exchanges i1 i2 : 'I_n.
  • is_perm_mx A == A is a permutation matrix.
  • lift0_mx A == the 1 + n square matrix block_mx 1 0 0 A when A : 'M_n.
  • \tr A == the trace of a square matrix A.
  • \det A == the determinant of A, using the Leibniz formula.
  • cofactor i j A == the i, j cofactor of A (the signed i, j minor of A).
  • \adj A == the adjugate matrix of A (\adj A i j = cofactor j i A).
  • A \in unitmx == A is invertible (R must be a comUnitRingType).
  • invmx A == the inverse matrix of A if A \in unitmx, otherwise A.
The following operations provide a correspondence between linear functions and matrices:
  • lin1_mx f == the m x n matrix that emulates via right product a (linear) function f : 'rV_m -> 'rV_n on ROW VECTORS.
  • lin_mx f == the (m1 * n1) x (m2 * n2) matrix that emulates, via the right multiplication on the mxvec encodings, a linear function f : 'M_(m1, n1) -> 'M_(m2, n2).
  • lin_mul_row u := lin1_mx (mulmx u \o vec_mx) (applies a row-encoded function to the row vector u).
  • mulmx A == partially applied matrix multiplication (mulmx A B is displayed as A *m B), with, for A : 'M_(m, n), a canonical {linear 'M_(n, p) -> 'M_(m, p)} structure.
  • mulmxr A == self-simplifying right-hand matrix multiplication, i.e., mulmxr A B simplifies to B *m A, with, for A : 'M_(n, p), a canonical {linear 'M_(m, n) -> 'M_(m, p)} structure.
  • lin_mulmx A := lin_mx (mulmx A).
  • lin_mulmxr A := lin_mx (mulmxr A).
We also extend any finType structure of R to 'M[R]_(m, n), and define:
  • {'GL_n[R]} == the finGroupType of units of 'M[R]_n.-1.+1.
  • 'GL_n[R] == the general linear group of all matrices in {'GL_n[R]}.
  • 'GL_n(p) == 'GL_n['F_p], the general linear group of a prime field.
  • GLval u == the coercion of u : {'GL_n[R]} to a matrix.
In addition to the lemmas relevant to these definitions, this file also proves several classic results, including:
  • The determinant is a multilinear alternate form.
  • The Laplace determinant expansion formulas: expand_det [row|col].
  • The Cramer rule : mul_mx_adj & mul_adj_mx.
Finally, as an example of the use of block products, we program and prove the correctness of a classical linear algebra algorithm:
  • cormenLUP A == the triangular decomposition (L, U, P) of a nontrivial square matrix A into a lower triangular matrix L with 1s on the main diagonal, an upper triangular matrix U, and a permutation matrix P, such that P * A = L * U.
This is only an example; we use a different, more precise algorithm to develop the theory of matrix ranks and row spaces in mxalgebra.v.
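To make the dimension discipline described above concrete, here is a minimal sketch that is not part of this file. It assumes a MathComp release providing the all_ssreflect and all_algebra umbrella modules; the section name ShapeSketch and its variables are hypothetical. Each Check only asks Coq to confirm a type, and Fail succeeds precisely because the last product is ill-typed.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical section: dimensions are part of the matrix types. *)
Section ShapeSketch.
Variable R : ringType.
Variables (A : 'M[R]_(2, 3)) (B : 'M[R]_(3, 4)) (u : 'rV[R]_3).

Check (A *m B : 'M[R]_(2, 4)).          (* 2 x 3 times 3 x 4 is 2 x 4 *)
Check (A^T : 'M[R]_(3, 2)).             (* transposition swaps dimensions *)
Check (u *m A^T : 'rV[R]_2).            (* 1 x 3 times 3 x 2 is a row vector *)
Check (row_mx A A : 'M[R]_(2, 3 + 3)).  (* concatenation adds the widths *)
Fail Check (B *m A).                    (* rejected: width 4 <> height 2 *)

End ShapeSketch.
```

Since dimension mismatches are type errors, no run-time size checks are needed; the price is the castmx machinery whenever two dimension expressions are only propositionally equal.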

Set Implicit Arguments.

Import GroupScope.
Import GRing.Theory.
Local Open Scope ring_scope.

Reserved Notation "''M_' n" (at level 8, n at level 2, format "''M_' n").
Reserved Notation "''rV_' n" (at level 8, n at level 2, format "''rV_' n").
Reserved Notation "''cV_' n" (at level 8, n at level 2, format "''cV_' n").
Reserved Notation "''M_' ( n )" (at level 8, only parsing).
Reserved Notation "''M_' ( m , n )" (at level 8, format "''M_' ( m , n )").
Reserved Notation "''M[' R ]_ n" (at level 8, n at level 2, only parsing).
Reserved Notation "''rV[' R ]_ n" (at level 8, n at level 2, only parsing).
Reserved Notation "''cV[' R ]_ n" (at level 8, n at level 2, only parsing).
Reserved Notation "''M[' R ]_ ( n )" (at level 8, only parsing).
Reserved Notation "''M[' R ]_ ( m , n )" (at level 8, only parsing).

Reserved Notation "\matrix_ i E"
  (at level 36, E at level 36, i at level 2,
   format "\matrix_ i E").
Reserved Notation "\matrix_ ( i < n ) E"
  (at level 36, E at level 36, i, n at level 50, only parsing).
Reserved Notation "\matrix_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\matrix_ ( i , j ) E").
Reserved Notation "\matrix[ k ]_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\matrix[ k ]_ ( i , j ) E").
Reserved Notation "\matrix_ ( i < m , j < n ) E"
  (at level 36, E at level 36, i, m, j, n at level 50, only parsing).
Reserved Notation "\matrix_ ( i , j < n ) E"
  (at level 36, E at level 36, i, j, n at level 50, only parsing).
Reserved Notation "\row_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\row_ j E").
Reserved Notation "\row_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50, only parsing).
Reserved Notation "\col_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\col_ j E").
Reserved Notation "\col_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50, only parsing).

Reserved Notation "x %:M" (at level 8, format "x %:M").
Reserved Notation "A *m B" (at level 40, left associativity, format "A *m B").
Reserved Notation "A ^T" (at level 8, format "A ^T").
Reserved Notation "\tr A" (at level 10, A at level 8, format "\tr A").
Reserved Notation "\det A" (at level 10, A at level 8, format "\det A").
Reserved Notation "\adj A" (at level 10, A at level 8, format "\adj A").


Type Definition

Section MatrixDef.

Variable R : Type.
Variables m n : nat.

Basic linear algebra (matrices). We use dependent types (ordinals) for the indices so that ranges are mostly inferred automatically.

Inductive matrix : predArgType := Matrix of {ffun 'I_m × 'I_n → R}.

Definition mx_val A := let: Matrix g := A in g.

Canonical matrix_subType := Eval hnf in [newType for mx_val].

Fact matrix_key : unit.
Definition matrix_of_fun_def F := Matrix [ffun ij F ij.1 ij.2].
Definition matrix_of_fun k := locked_with k matrix_of_fun_def.
Canonical matrix_unlockable k := [unlockable fun matrix_of_fun k].

Definition fun_of_matrix A (i : 'I_m) (j : 'I_n) := mx_val A (i, j).

Coercion fun_of_matrix : matrix >-> Funclass.

Lemma mxE k F : matrix_of_fun k F =2 F.

Lemma matrixP (A B : matrix) : A =2 B ↔ A = B.

Lemma eq_mx k F1 F2 : (F1 =2 F2) → matrix_of_fun k F1 = matrix_of_fun k F2.

End MatrixDef.



Notation "''M[' R ]_ ( m , n )" := (matrix R m n) (only parsing): type_scope.
Notation "''rV[' R ]_ n" := 'M[R]_(1, n) (only parsing) : type_scope.
Notation "''cV[' R ]_ n" := 'M[R]_(n, 1) (only parsing) : type_scope.
Notation "''M[' R ]_ n" := 'M[R]_(n, n) (only parsing) : type_scope.
Notation "''M[' R ]_ ( n )" := 'M[R]_n (only parsing) : type_scope.
Notation "''M_' ( m , n )" := 'M[_]_(m, n) : type_scope.
Notation "''rV_' n" := 'M_(1, n) : type_scope.
Notation "''cV_' n" := 'M_(n, 1) : type_scope.
Notation "''M_' n" := 'M_(n, n) : type_scope.
Notation "''M_' ( n )" := 'M_n (only parsing) : type_scope.

Notation "\matrix[ k ]_ ( i , j ) E" := (matrix_of_fun k (fun i j ⇒ E))
  (at level 36, E at level 36, i, j at level 50): ring_scope.

Notation "\matrix_ ( i < m , j < n ) E" :=
  (@matrix_of_fun _ m n matrix_key (fun i j ⇒ E)) (only parsing) : ring_scope.

Notation "\matrix_ ( i , j < n ) E" :=
  (\matrix_(i < n, j < n) E) (only parsing) : ring_scope.

Notation "\matrix_ ( i , j ) E" := (\matrix_(i < _, j < _) E) : ring_scope.

Notation "\matrix_ ( i < m ) E" :=
  (\matrix_(i < m, j < _) @fun_of_matrix _ 1 _ E 0 j)
  (only parsing) : ring_scope.
Notation "\matrix_ i E" := (\matrix_(i < _) E) : ring_scope.

Notation "\col_ ( i < n ) E" := (@matrix_of_fun _ n 1 matrix_key (fun i _ ⇒ E))
  (only parsing) : ring_scope.
Notation "\col_ i E" := (\col_(i < _) E) : ring_scope.

Notation "\row_ ( j < n ) E" := (@matrix_of_fun _ 1 n matrix_key (fun _ j ⇒ E))
  (only parsing) : ring_scope.
Notation "\row_ j E" := (\row_(j < _) E) : ring_scope.
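A short sketch of the \matrix_ notation used together with mxE and matrixP, under the same assumption (a MathComp installation providing all_ssreflect and all_algebra). The names idx_sum_mx, idx_sum_mxE and matrix_eta are hypothetical and are not lemmas of this file.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

Section NotationSketch.
Variable R : ringType.

(* Hypothetical 2 x 3 matrix whose (i, j) entry is the ring image of i + j;
   the argument of %:R is a nat, so i and j are coerced to nat here. *)
Definition idx_sum_mx : 'M[R]_(2, 3) := \matrix_(i, j) (i + j)%:R.

(* mxE unfolds \matrix_ to its general coefficient. *)
Example idx_sum_mxE i j : idx_sum_mx i j = (i + j)%:R.
Proof. by rewrite /idx_sum_mx mxE. Qed.

(* matrixP reduces matrix equality to pointwise equality. *)
Example matrix_eta m n (A : 'M[R]_(m, n)) : \matrix_(i, j) A i j = A.
Proof. by apply/matrixP => i j; rewrite mxE. Qed.

End NotationSketch.
```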

Definition matrix_eqMixin (R : eqType) m n :=
  Eval hnf in [eqMixin of 'M[R]_(m, n) by <:].
Canonical matrix_eqType (R : eqType) m n:=
  Eval hnf in EqType 'M[R]_(m, n) (matrix_eqMixin R m n).
Definition matrix_choiceMixin (R : choiceType) m n :=
  [choiceMixin of 'M[R]_(m, n) by <:].
Canonical matrix_choiceType (R : choiceType) m n :=
  Eval hnf in ChoiceType 'M[R]_(m, n) (matrix_choiceMixin R m n).
Definition matrix_countMixin (R : countType) m n :=
  [countMixin of 'M[R]_(m, n) by <:].
Canonical matrix_countType (R : countType) m n :=
  Eval hnf in CountType 'M[R]_(m, n) (matrix_countMixin R m n).
Canonical matrix_subCountType (R : countType) m n :=
  Eval hnf in [subCountType of 'M[R]_(m, n)].
Definition matrix_finMixin (R : finType) m n :=
  [finMixin of 'M[R]_(m, n) by <:].
Canonical matrix_finType (R : finType) m n :=
  Eval hnf in FinType 'M[R]_(m, n) (matrix_finMixin R m n).
Canonical matrix_subFinType (R : finType) m n :=
  Eval hnf in [subFinType of 'M[R]_(m, n)].

Lemma card_matrix (F : finType) m n : (#|{: 'M[F]_(m, n)}| = #|F| ^ (m × n))%N.

Matrix structural operations (transpose, permutation, blocks)

Section MatrixStructural.

Variable R : Type.

Constant matrix
Fact const_mx_key : unit.
Definition const_mx m n a : 'M[R]_(m, n) := \matrix[const_mx_key]_(i, j) a.

Section FixedDim.
Definitions and properties for which we can work with fixed dimensions.

Variables m n : nat.
Implicit Type A : 'M[R]_(m, n).

Reshape a matrix, for instance to accommodate the block functions.
Definition castmx m' n' (eq_mn : (m = m') × (n = n')) A : 'M_(m', n') :=
  let: erefl in _ = m' := eq_mn.1 return 'M_(m', n') in
  let: erefl in _ = n' := eq_mn.2 return 'M_(m, n') in A.

Definition conform_mx m' n' B A :=
  match m =P m', n =P n' with
  | ReflectT eq_m, ReflectT eq_n ⇒ castmx (eq_m, eq_n) A
  | _, _ ⇒ B
  end.

Transpose a matrix
Fact trmx_key : unit.
Definition trmx A := \matrix[trmx_key]_(i, j) A j i.

Permute a matrix vertically (rows) or horizontally (columns)
Fact row_perm_key : unit.
Definition row_perm (s : 'S_m) A := \matrix[row_perm_key]_(i, j) A (s i) j.
Fact col_perm_key : unit.
Definition col_perm (s : 'S_n) A := \matrix[col_perm_key]_(i, j) A i (s j).

Exchange two rows/columns of a matrix
Definition xrow i1 i2 := row_perm (tperm i1 i2).
Definition xcol j1 j2 := col_perm (tperm j1 j2).

Row/column submatrices of a matrix
Definition row i0 A := \row_j A i0 j.
Definition col j0 A := \col_i A i j0.

Removing a row/column from a matrix
Definition row' i0 A := \matrix_(i, j) A (lift i0 i) j.
Definition col' j0 A := \matrix_(i, j) A i (lift j0 j).

Lemma castmx_const m' n' (eq_mn : (m = m') × (n = n')) a :
  castmx eq_mn (const_mx a) = const_mx a.

Lemma trmx_const a : trmx (const_mx a) = const_mx a.

Lemma row_perm_const s a : row_perm s (const_mx a) = const_mx a.

Lemma col_perm_const s a : col_perm s (const_mx a) = const_mx a.

Lemma xrow_const i1 i2 a : xrow i1 i2 (const_mx a) = const_mx a.

Lemma xcol_const j1 j2 a : xcol j1 j2 (const_mx a) = const_mx a.

Lemma rowP (u v : 'rV[R]_n) : u 0 =1 v 0 ↔ u = v.

Lemma rowK u_ i0 : row i0 (\matrix_i u_ i) = u_ i0.

Lemma row_matrixP A B : (∀ i, row i A = row i B) → A = B.

Lemma colP (u v : 'cV[R]_m) : u^~ 0 =1 v^~ 0 ↔ u = v.

Lemma row_const i0 a : row i0 (const_mx a) = const_mx a.

Lemma col_const j0 a : col j0 (const_mx a) = const_mx a.

Lemma row'_const i0 a : row' i0 (const_mx a) = const_mx a.

Lemma col'_const j0 a : col' j0 (const_mx a) = const_mx a.

Lemma col_perm1 A : col_perm 1 A = A.

Lemma row_perm1 A : row_perm 1 A = A.

Lemma col_permM s t A : col_perm (s × t) A = col_perm s (col_perm t A).

Lemma row_permM s t A : row_perm (s × t) A = row_perm s (row_perm t A).

Lemma col_row_permC s t A :
  col_perm s (row_perm t A) = row_perm t (col_perm s A).

End FixedDim.
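Because xrow i1 i2 is defined as row_perm (tperm i1 i2), exchanging the same two rows twice gives back the original matrix. The sketch below derives this from row_permM and row_perm1 above together with tperm2 (tperm x y * tperm x y = 1) from MathComp's perm library, which is assumed to be reachable through all_fingroup; the example name is hypothetical.

```coq
From mathcomp Require Import all_ssreflect all_fingroup all_algebra.

(* Hypothetical example: xrow with a fixed pair of rows is an involution. *)
Example xrow_twice (R : Type) m n (i1 i2 : 'I_m) (A : 'M[R]_(m, n)) :
  xrow i1 i2 (xrow i1 i2 A) = A.
Proof.
(* Unfold xrow, merge the two row permutations, then cancel the transposition. *)
by rewrite /xrow -row_permM tperm2 row_perm1.
Qed.
```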


Lemma castmx_id m n erefl_mn (A : 'M_(m, n)) : castmx erefl_mn A = A.

Lemma castmx_comp m1 n1 m2 n2 m3 n3 (eq_m1 : m1 = m2) (eq_n1 : n1 = n2)
                                    (eq_m2 : m2 = m3) (eq_n2 : n2 = n3) A :
  castmx (eq_m2, eq_n2) (castmx (eq_m1, eq_n1) A)
    = castmx (etrans eq_m1 eq_m2, etrans eq_n1 eq_n2) A.

Lemma castmxK m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) :
  cancel (castmx (eq_m, eq_n)) (castmx (esym eq_m, esym eq_n)).

Lemma castmxKV m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) :
  cancel (castmx (esym eq_m, esym eq_n)) (castmx (eq_m, eq_n)).
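castmxK is the usual way to undo a cast. In this sketch (hypothetical example name), a matrix whose dimensions are written m + 0 and 0 + n is cast to 'M_(m, n) using the ssrnat lemmas addn0 (m + 0 = m) and add0n (0 + n = n), and then cast back.

```coq
From mathcomp Require Import all_ssreflect all_algebra.

(* Hypothetical example: casting along (eq_m, eq_n) and then along the
   symmetric equations is the identity. *)
Example castmx_round_trip (R : Type) m n (A : 'M[R]_(m + 0, 0 + n)) :
  castmx (esym (addn0 m), esym (add0n n)) (castmx (addn0 m, add0n n) A) = A.
Proof. by rewrite castmxK. Qed.
```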

This can be used to reverse an equation that involves a cast.
Lemma castmx_sym m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) A1 A2 :
  A1 = castmx (eq_m, eq_n) A2 → A2 = castmx (esym eq_m, esym eq_n) A1.

Lemma castmxE m1 n1 m2 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A i j :
  castmx eq_mn A i j =
     A (cast_ord (esym eq_mn.1) i) (cast_ord (esym eq_mn.2) j).

Lemma conform_mx_id m n (B A : 'M_(m, n)) : conform_mx B A = A.

Lemma nonconform_mx m m' n n' (B : 'M_(m', n')) (A : 'M_(m, n)) :
  (m != m') || (n != n') → conform_mx B A = B.

Lemma conform_castmx m1 n1 m2 n2 m3 n3
                     (e_mn : (m2 = m3) × (n2 = n3)) (B : 'M_(m1, n1)) A :
  conform_mx B (castmx e_mn A) = conform_mx B A.

Lemma trmxK m n : cancel (@trmx m n) (@trmx n m).

Lemma trmx_inj m n : injective (@trmx m n).

Lemma trmx_cast m1 n1 m2 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A :
  (castmx eq_mn A)^T = castmx (eq_mn.2, eq_mn.1) A^T.

Lemma tr_row_perm m n s (A : 'M_(m, n)) : (row_perm s A)^T = col_perm s A^T.

Lemma tr_col_perm m n s (A : 'M_(m, n)) : (col_perm s A)^T = row_perm s A^T.

Lemma tr_xrow m n i1 i2 (A : 'M_(m, n)) : (xrow i1 i2 A)^T = xcol i1 i2 A^T.

Lemma tr_xcol m n j1 j2 (A : 'M_(m, n)) : (xcol j1 j2 A)^T = xrow j1 j2 A^T.

Lemma row_id n i (V : 'rV_n) : row i V = V.

Lemma col_id n j (V : 'cV_n) : col j V = V.

Lemma row_eq m1 m2 n i1 i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row i1 A1 = row i2 A2 ↔ A1 i1 =1 A2 i2.

Lemma col_eq m n1 n2 j1 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col j1 A1 = col j2 A2 ↔ A1^~ j1 =1 A2^~ j2.

Lemma row'_eq m n i0 (A B : 'M_(m, n)) :
  row' i0 A = row' i0 B ↔ {in predC1 i0, A =2 B}.

Lemma col'_eq m n j0 (A B : 'M_(m, n)) :
  col' j0 A = col' j0 B ↔ ∀ i, {in predC1 j0, A i =1 B i}.

Lemma tr_row m n i0 (A : 'M_(m, n)) : (row i0 A)^T = col i0 A^T.

Lemma tr_row' m n i0 (A : 'M_(m, n)) : (row' i0 A)^T = col' i0 A^T.

Lemma tr_col m n j0 (A : 'M_(m, n)) : (col j0 A)^T = row j0 A^T.

Lemma tr_col' m n j0 (A : 'M_(m, n)) : (col' j0 A)^T = row' j0 A^T.

Ltac split_mxE := apply/matrixP ⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Section CutPaste.

Variables m m1 m2 n n1 n2 : nat.

Concatenating two matrices, in either direction.

Fact row_mx_key : unit.
Definition row_mx (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) : 'M[R]_(m, n1 + n2) :=
  \matrix[row_mx_key]_(i, j)
     match split j with inl j1 ⇒ A1 i j1 | inr j2 ⇒ A2 i j2 end.

Fact col_mx_key : unit.
Definition col_mx (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) : 'M[R]_(m1 + m2, n) :=
  \matrix[col_mx_key]_(i, j)
     match split i with inl i1 ⇒ A1 i1 j | inr i2 ⇒ A2 i2 j end.

Left/right | up/down submatrices of a row | column block matrix. The shape of the (dependent) dimension parameters in the type of A determines which submatrix is selected.

Fact lsubmx_key : unit.
Definition lsubmx (A : 'M[R]_(m, n1 + n2)) :=
  \matrix[lsubmx_key]_(i, j) A i (lshift n2 j).

Fact rsubmx_key : unit.
Definition rsubmx (A : 'M[R]_(m, n1 + n2)) :=
  \matrix[rsubmx_key]_(i, j) A i (rshift n1 j).

Fact usubmx_key : unit.
Definition usubmx (A : 'M[R]_(m1 + m2, n)) :=
  \matrix[usubmx_key]_(i, j) A (lshift m2 i) j.

Fact dsubmx_key : unit.
Definition dsubmx (A : 'M[R]_(m1 + m2, n)) :=
  \matrix[dsubmx_key]_(i, j) A (rshift m1 i) j.

Lemma row_mxEl A1 A2 i j : row_mx A1 A2 i (lshift n2 j) = A1 i j.

Lemma row_mxKl A1 A2 : lsubmx (row_mx A1 A2) = A1.

Lemma row_mxEr A1 A2 i j : row_mx A1 A2 i (rshift n1 j) = A2 i j.

Lemma row_mxKr A1 A2 : rsubmx (row_mx A1 A2) = A2.

Lemma hsubmxK A : row_mx (lsubmx A) (rsubmx A) = A.

Lemma col_mxEu A1 A2 i j : col_mx A1 A2 (lshift m2 i) j = A1 i j.

Lemma col_mxKu A1 A2 : usubmx (col_mx A1 A2) = A1.

Lemma col_mxEd A1 A2 i j : col_mx A1 A2 (rshift m1 i) j = A2 i j.

Lemma col_mxKd A1 A2 : dsubmx (col_mx A1 A2) = A2.

Lemma eq_row_mx A1 A2 B1 B2 : row_mx A1 A2 = row_mx B1 B2 → A1 = B1 ∧ A2 = B2.

Lemma eq_col_mx A1 A2 B1 B2 : col_mx A1 A2 = col_mx B1 B2 → A1 = B1 ∧ A2 = B2.

Lemma row_mx_const a : row_mx (const_mx a) (const_mx a) = const_mx a.

Lemma col_mx_const a : col_mx (const_mx a) (const_mx a) = const_mx a.

End CutPaste.
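The cut and paste operations are mutually inverse, and the type of the argument decides where the cut happens. A sketch using row_mxKl, row_mxKr and hsubmxK; the example names are hypothetical.

```coq
From mathcomp Require Import all_ssreflect all_algebra.

(* Hypothetical example: cutting a concatenation returns its two pieces. *)
Example row_mx_cut (R : Type) m n1 n2 (A1 : 'M[R]_(m, n1)) (A2 : 'M[R]_(m, n2)) :
  lsubmx (row_mx A1 A2) = A1 /\ rsubmx (row_mx A1 A2) = A2.
Proof. by rewrite row_mxKl row_mxKr. Qed.

(* Hypothetical example: the width n1 + n2 in the type of B fixes the cut. *)
Example row_mx_glue (R : Type) m n1 n2 (B : 'M[R]_(m, n1 + n2)) :
  row_mx (lsubmx B) (rsubmx B) = B.
Proof. exact: hsubmxK. Qed.
```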

Lemma trmx_lsub m n1 n2 (A : 'M_(m, n1 + n2)) : (lsubmx A)^T = usubmx A^T.

Lemma trmx_rsub m n1 n2 (A : 'M_(m, n1 + n2)) : (rsubmx A)^T = dsubmx A^T.

Lemma tr_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  (row_mx A1 A2)^T = col_mx A1^T A2^T.

Lemma tr_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  (col_mx A1 A2)^T = row_mx A1^T A2^T.

Lemma trmx_usub m1 m2 n (A : 'M_(m1 + m2, n)) : (usubmx A)^T = lsubmx A^T.

Lemma trmx_dsub m1 m2 n (A : 'M_(m1 + m2, n)) : (dsubmx A)^T = rsubmx A^T.

Lemma vsubmxK m1 m2 n (A : 'M_(m1 + m2, n)) : col_mx (usubmx A) (dsubmx A) = A.

Lemma cast_row_mx m m' n1 n2 (eq_m : m = m') A1 A2 :
  castmx (eq_m, erefl _) (row_mx A1 A2)
    = row_mx (castmx (eq_m, erefl n1) A1) (castmx (eq_m, erefl n2) A2).

Lemma cast_col_mx m1 m2 n n' (eq_n : n = n') A1 A2 :
  castmx (erefl _, eq_n) (col_mx A1 A2)
    = col_mx (castmx (erefl m1, eq_n) A1) (castmx (erefl m2, eq_n) A2).

This lemma has Prenex Implicits to help right-to-left rewriting with castmx_sym.
Lemma row_mxA m n1 n2 n3 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) (A3 : 'M_(m, n3)) :
  let cast := (erefl m, esym (addnA n1 n2 n3)) in
  row_mx A1 (row_mx A2 A3) = castmx cast (row_mx (row_mx A1 A2) A3).
Definition row_mxAx := row_mxA. (* bypass Prenex Implicits. *)

This lemma has Prenex Implicits to help right-to-left rewriting with castmx_sym.
Lemma col_mxA m1 m2 m3 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) (A3 : 'M_(m3, n)) :
  let cast := (esym (addnA m1 m2 m3), erefl n) in
  col_mx A1 (col_mx A2 A3) = castmx cast (col_mx (col_mx A1 A2) A3).
Definition col_mxAx := col_mxA. (* bypass Prenex Implicits. *)

Lemma row_row_mx m n1 n2 i0 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  row i0 (row_mx A1 A2) = row_mx (row i0 A1) (row i0 A2).

Lemma col_col_mx m1 m2 n j0 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  col j0 (col_mx A1 A2) = col_mx (col j0 A1) (col j0 A2).

Lemma row'_row_mx m n1 n2 i0 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  row' i0 (row_mx A1 A2) = row_mx (row' i0 A1) (row' i0 A2).

Lemma col'_col_mx m1 m2 n j0 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  col' j0 (col_mx A1 A2) = col_mx (col' j0 A1) (col' j0 A2).

Lemma colKl m n1 n2 j1 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col (lshift n2 j1) (row_mx A1 A2) = col j1 A1.

Lemma colKr m n1 n2 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col (rshift n1 j2) (row_mx A1 A2) = col j2 A2.

Lemma rowKu m1 m2 n i1 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row (lshift m2 i1) (col_mx A1 A2) = row i1 A1.

Lemma rowKd m1 m2 n i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row (rshift m1 i2) (col_mx A1 A2) = row i2 A2.

Lemma col'Kl m n1 n2 j1 (A1 : 'M_(m, n1.+1)) (A2 : 'M_(m, n2)) :
  col' (lshift n2 j1) (row_mx A1 A2) = row_mx (col' j1 A1) A2.

Lemma row'Ku m1 m2 n i1 (A1 : 'M_(m1.+1, n)) (A2 : 'M_(m2, n)) :
  row' (lshift m2 i1) (@col_mx m1.+1 m2 n A1 A2) = col_mx (row' i1 A1) A2.

Lemma mx'_cast m n : 'I_n → (m + n.-1)%N = (m + n).-1.

Lemma col'Kr m n1 n2 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col' (rshift n1 j2) (@row_mx m n1 n2 A1 A2)
    = castmx (erefl m, mx'_cast n1 j2) (row_mx A1 (col' j2 A2)).

Lemma row'Kd m1 m2 n i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row' (rshift m1 i2) (col_mx A1 A2)
    = castmx (mx'_cast m1 i2, erefl n) (col_mx A1 (row' i2 A2)).

Section Block.

Variables m1 m2 n1 n2 : nat.

Building a block matrix from four matrices: the up-left, up-right, down-left and down-right components.

Definition block_mx Aul Aur Adl Adr : 'M_(m1 + m2, n1 + n2) :=
  col_mx (row_mx Aul Aur) (row_mx Adl Adr).

Lemma eq_block_mx Aul Aur Adl Adr Bul Bur Bdl Bdr :
 block_mx Aul Aur Adl Adr = block_mx Bul Bur Bdl Bdr →
  [/\ Aul = Bul, Aur = Bur, Adl = Bdl & Adr = Bdr].

Lemma block_mx_const a :
  block_mx (const_mx a) (const_mx a) (const_mx a) (const_mx a) = const_mx a.

Section CutBlock.

Variable A : matrix R (m1 + m2) (n1 + n2).

Definition ulsubmx := lsubmx (usubmx A).
Definition ursubmx := rsubmx (usubmx A).
Definition dlsubmx := lsubmx (dsubmx A).
Definition drsubmx := rsubmx (dsubmx A).

Lemma submxK : block_mx ulsubmx ursubmx dlsubmx drsubmx = A.

End CutBlock.

Section CatBlock.

Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

Let A := block_mx Aul Aur Adl Adr.

Lemma block_mxEul i j : A (lshift m2 i) (lshift n2 j) = Aul i j.
Lemma block_mxKul : ulsubmx A = Aul.

Lemma block_mxEur i j : A (lshift m2 i) (rshift n1 j) = Aur i j.
Lemma block_mxKur : ursubmx A = Aur.

Lemma block_mxEdl i j : A (rshift m1 i) (lshift n2 j) = Adl i j.
Lemma block_mxKdl : dlsubmx A = Adl.

Lemma block_mxEdr i j : A (rshift m1 i) (rshift n1 j) = Adr i j.
Lemma block_mxKdr : drsubmx A = Adr.

Lemma block_mxEv : A = col_mx (row_mx Aul Aur) (row_mx Adl Adr).

End CatBlock.

End Block.
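Dually to block_mx, submxK reassembles any matrix whose type has the shape 'M_(m1 + m2, n1 + n2) from its four quadrants. A sketch with a hypothetical example name:

```coq
From mathcomp Require Import all_ssreflect all_algebra.

(* Hypothetical example: the sums m1 + m2 and n1 + n2 in the type of B
   determine where ulsubmx, ursubmx, dlsubmx and drsubmx cut. *)
Example split_quadrants (R : Type) m1 m2 n1 n2 (B : 'M[R]_(m1 + m2, n1 + n2)) :
  block_mx (ulsubmx B) (ursubmx B) (dlsubmx B) (drsubmx B) = B.
Proof. exact: submxK. Qed.
```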

Section TrCutBlock.

Variables m1 m2 n1 n2 : nat.
Variable A : 'M[R]_(m1 + m2, n1 + n2).

Lemma trmx_ulsub : (ulsubmx A)^T = ulsubmx A^T.

Lemma trmx_ursub : (ursubmx A)^T = dlsubmx A^T.

Lemma trmx_dlsub : (dlsubmx A)^T = ursubmx A^T.

Lemma trmx_drsub : (drsubmx A)^T = drsubmx A^T.

End TrCutBlock.

Section TrBlock.
Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

Lemma tr_block_mx :
 (block_mx Aul Aur Adl Adr)^T = block_mx Aul^T Adl^T Aur^T Adr^T.

Lemma block_mxEh :
  block_mx Aul Aur Adl Adr = row_mx (col_mx Aul Adl) (col_mx Aur Adr).
End TrBlock.

This lemma has Prenex Implicits to help right-to-left rewriting with castmx_sym.
Lemma block_mxA m1 m2 m3 n1 n2 n3
   (A11 : 'M_(m1, n1)) (A12 : 'M_(m1, n2)) (A13 : 'M_(m1, n3))
   (A21 : 'M_(m2, n1)) (A22 : 'M_(m2, n2)) (A23 : 'M_(m2, n3))
   (A31 : 'M_(m3, n1)) (A32 : 'M_(m3, n2)) (A33 : 'M_(m3, n3)) :
  let cast := (esym (addnA m1 m2 m3), esym (addnA n1 n2 n3)) in
  let row1 := row_mx A12 A13 in let col1 := col_mx A21 A31 in
  let row3 := row_mx A31 A32 in let col3 := col_mx A13 A23 in
  block_mx A11 row1 col1 (block_mx A22 A23 A32 A33)
    = castmx cast (block_mx (block_mx A11 A12 A21 A22) col3 row3 A33).
Definition block_mxAx := block_mxA. (* Bypass Prenex Implicits *)

Bijections mxvec : 'M_(m, n) <----> 'rV_(m * n) : vec_mx
Section VecMatrix.

Variables m n : nat.

Lemma mxvec_cast : #|{:'I_m × 'I_n}| = (m × n)%N.

Definition mxvec_index (i : 'I_m) (j : 'I_n) :=
  cast_ord mxvec_cast (enum_rank (i, j)).

Variant is_mxvec_index : 'I_(m × n) → Type :=
  IsMxvecIndex i j : is_mxvec_index (mxvec_index i j).

Lemma mxvec_indexP k : is_mxvec_index k.

Coercion pair_of_mxvec_index k (i_k : is_mxvec_index k) :=
  let: IsMxvecIndex i j := i_k in (i, j).

Definition mxvec (A : 'M[R]_(m, n)) :=
  castmx (erefl _, mxvec_cast) (\row_k A (enum_val k).1 (enum_val k).2).

Fact vec_mx_key : unit.
Definition vec_mx (u : 'rV[R]_(m × n)) :=
  \matrix[vec_mx_key]_(i, j) u 0 (mxvec_index i j).

Lemma mxvecE A i j : mxvec A 0 (mxvec_index i j) = A i j.

Lemma mxvecK : cancel mxvec vec_mx.

Lemma vec_mxK : cancel vec_mx mxvec.

Lemma curry_mxvec_bij : {on 'I_(m × n), bijective (prod_curry mxvec_index)}.

End VecMatrix.
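mxvec and vec_mx are inverse bijections between 'M_(m, n) and 'rV_(m * n). A sketch checking both round trips with mxvecK and vec_mxK; the example name is hypothetical.

```coq
From mathcomp Require Import all_ssreflect all_algebra.

(* Hypothetical example: flattening then reshaping (and the converse)
   returns the original object. *)
Example vec_round_trip (R : Type) m n (A : 'M[R]_(m, n)) (v : 'rV[R]_(m * n)) :
  vec_mx (mxvec A) = A /\ mxvec (vec_mx v) = v.
Proof. by rewrite mxvecK vec_mxK. Qed.
```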

End MatrixStructural.


Notation "A ^T" := (trmx A) : ring_scope.

Matrix parametricity.
Section MapMatrix.

Variables (aT rT : Type) (f : aT → rT).

Fact map_mx_key : unit.
Definition map_mx m n (A : 'M_(m, n)) := \matrix[map_mx_key]_(i, j) f (A i j).

Notation "A ^f" := (map_mx A) : ring_scope.

Section OneMatrix.

Variables (m n : nat) (A : 'M[aT]_(m, n)).

Lemma map_trmx : A^f^T = A^T^f.

Lemma map_const_mx a : (const_mx a)^f = const_mx (f a) :> 'M_(m, n).

Lemma map_row i : (row i A)^f = row i A^f.

Lemma map_col j : (col j A)^f = col j A^f.

Lemma map_row' i0 : (row' i0 A)^f = row' i0 A^f.

Lemma map_col' j0 : (col' j0 A)^f = col' j0 A^f.

Lemma map_row_perm s : (row_perm s A)^f = row_perm s A^f.

Lemma map_col_perm s : (col_perm s A)^f = col_perm s A^f.

Lemma map_xrow i1 i2 : (xrow i1 i2 A)^f = xrow i1 i2 A^f.

Lemma map_xcol j1 j2 : (xcol j1 j2 A)^f = xcol j1 j2 A^f.

Lemma map_castmx m' n' c : (castmx c A)^f = castmx c A^f :> 'M_(m', n').

Lemma map_conform_mx m' n' (B : 'M_(m', n')) :
  (conform_mx B A)^f = conform_mx B^f A^f.

Lemma map_mxvec : (mxvec A)^f = mxvec A^f.

Lemma map_vec_mx (v : 'rV_(m × n)) : (vec_mx v)^f = vec_mx v^f.

End OneMatrix.

Section Block.

Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[aT]_(m1, n1)) (Aur : 'M[aT]_(m1, n2)).
Variables (Adl : 'M[aT]_(m2, n1)) (Adr : 'M[aT]_(m2, n2)).
Variables (Bh : 'M[aT]_(m1, n1 + n2)) (Bv : 'M[aT]_(m1 + m2, n1)).
Variable B : 'M[aT]_(m1 + m2, n1 + n2).

Lemma map_row_mx : (row_mx Aul Aur)^f = row_mx Aul^f Aur^f.

Lemma map_col_mx : (col_mx Aul Adl)^f = col_mx Aul^f Adl^f.

Lemma map_block_mx :
  (block_mx Aul Aur Adl Adr)^f = block_mx Aul^f Aur^f Adl^f Adr^f.

Lemma map_lsubmx : (lsubmx Bh)^f = lsubmx Bh^f.

Lemma map_rsubmx : (rsubmx Bh)^f = rsubmx Bh^f.

Lemma map_usubmx : (usubmx Bv)^f = usubmx Bv^f.

Lemma map_dsubmx : (dsubmx Bv)^f = dsubmx Bv^f.

Lemma map_ulsubmx : (ulsubmx B)^f = ulsubmx B^f.

Lemma map_ursubmx : (ursubmx B)^f = ursubmx B^f.

Lemma map_dlsubmx : (dlsubmx B)^f = dlsubmx B^f.

Lemma map_drsubmx : (drsubmx B)^f = drsubmx B^f.

End Block.

End MapMatrix.
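The map_* lemmas state that map_mx commutes with every structural operation. A one-line sketch with map_trmx, where f is an arbitrary function parameter and the example name is hypothetical:

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical example: mapping f over a transpose equals transposing the
   mapped matrix (map_mx f A^T parses as map_mx f (A^T)). *)
Example map_tr (aT rT : Type) (f : aT -> rT) m n (A : 'M[aT]_(m, n)) :
  map_mx f A^T = (map_mx f A)^T.
Proof. by rewrite map_trmx. Qed.
```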


Matrix Zmodule (additive) structure

Section MatrixZmodule.

Variable V : zmodType.

Section FixedDim.

Variables m n : nat.
Implicit Types A B : 'M[V]_(m, n).

Fact oppmx_key : unit.
Fact addmx_key : unit.
Definition oppmx A := \matrix[oppmx_key]_(i, j) (- A i j).
Definition addmx A B := \matrix[addmx_key]_(i, j) (A i j + B i j).
In principle, diag_mx and scalar_mx could be defined here, but since they only make sense with the graded ring operations, we defer them to the next section.

Lemma addmxA : associative addmx.

Lemma addmxC : commutative addmx.

Lemma add0mx : left_id (const_mx 0) addmx.

Lemma addNmx : left_inverse (const_mx 0) oppmx addmx.

Definition matrix_zmodMixin := ZmodMixin addmxA addmxC add0mx addNmx.

Canonical matrix_zmodType := Eval hnf in ZmodType 'M[V]_(m, n) matrix_zmodMixin.

Lemma mulmxnE A d i j : (A *+ d) i j = A i j *+ d.

Lemma summxE I r (P : pred I) (E : I → 'M_(m, n)) i j :
  (\sum_(k <- r | P k) E k) i j = \sum_(k <- r | P k) E k i j.

Lemma const_mx_is_additive : additive const_mx.
Canonical const_mx_additive := Additive const_mx_is_additive.

End FixedDim.

Section Additive.

Variables (m n p q : nat) (f : 'I_p → 'I_q → 'I_m) (g : 'I_p → 'I_q → 'I_n).

Definition swizzle_mx k (A : 'M[V]_(m, n)) :=
  \matrix[k]_(i, j) A (f i j) (g i j).

Lemma swizzle_mx_is_additive k : additive (swizzle_mx k).
Canonical swizzle_mx_additive k := Additive (swizzle_mx_is_additive k).

End Additive.


Canonical trmx_additive m n := SwizzleAdd (@trmx V m n).
Canonical row_additive m n i := SwizzleAdd (@row V m n i).
Canonical col_additive m n j := SwizzleAdd (@col V m n j).
Canonical row'_additive m n i := SwizzleAdd (@row' V m n i).
Canonical col'_additive m n j := SwizzleAdd (@col' V m n j).
Canonical row_perm_additive m n s := SwizzleAdd (@row_perm V m n s).
Canonical col_perm_additive m n s := SwizzleAdd (@col_perm V m n s).
Canonical xrow_additive m n i1 i2 := SwizzleAdd (@xrow V m n i1 i2).
Canonical xcol_additive m n j1 j2 := SwizzleAdd (@xcol V m n j1 j2).
Canonical lsubmx_additive m n1 n2 := SwizzleAdd (@lsubmx V m n1 n2).
Canonical rsubmx_additive m n1 n2 := SwizzleAdd (@rsubmx V m n1 n2).
Canonical usubmx_additive m1 m2 n := SwizzleAdd (@usubmx V m1 m2 n).
Canonical dsubmx_additive m1 m2 n := SwizzleAdd (@dsubmx V m1 m2 n).
Canonical vec_mx_additive m n := SwizzleAdd (@vec_mx V m n).
Canonical mxvec_additive m n :=
  Additive (can2_additive (@vec_mxK V m n) mxvecK).
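Each Canonical declaration above registers a structural operation as additive by viewing it as an instance of swizzle_mx, a pure reindexing of entries (SwizzleAdd appears to be a local notation whose definition is not shown in this listing). The plain-Python sketch below is not MathComp code. The hypothetical helpers swizzle, tr and row express the transpose and a row extraction as index maps f and g and check additivity on one example.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows.
def swizzle(A, p, q, f, g):
    # p x q result whose (i, j) entry is read from A at (f(i, j), g(i, j))
    return [[A[f(i, j)][g(i, j)] for j in range(q)] for i in range(p)]

def addmx(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

m, n = 2, 3

def tr(M):
    # transpose: an n x m swizzle reading M at (j, i)
    return swizzle(M, n, m, lambda i, j: j, lambda i, j: i)

def row(i0, M):
    # row i0: a 1 x n swizzle reading M at (i0, j)
    return swizzle(M, 1, n, lambda i, j: i0, lambda i, j: j)

A = [[1, 2, 3], [4, 5, 6]]
B = [[1, 0, 0], [0, 1, 0]]
assert tr(A) == [[1, 4], [2, 5], [3, 6]]
assert row(1, A) == [[4, 5, 6]]
# A swizzle only moves entries around, so it commutes with addition.
assert tr(addmx(A, B)) == addmx(tr(A), tr(B))
assert row(0, addmx(A, B)) == addmx(row(0, A), row(0, B))
```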

Lemma flatmx0 n : all_equal_to (0 : 'M_(0, n)).

Lemma thinmx0 n : all_equal_to (0 : 'M_(n, 0)).

Lemma trmx0 m n : (0 : 'M_(m, n))^T = 0.

Lemma row0 m n i0 : row i0 (0 : 'M_(m, n)) = 0.

Lemma col0 m n j0 : col j0 (0 : 'M_(m, n)) = 0.

Lemma mxvec_eq0 m n (A : 'M_(m, n)) : (mxvec A == 0) = (A == 0).

Lemma vec_mx_eq0 m n (v : 'rV_(m × n)) : (vec_mx v == 0) = (v == 0).

Lemma row_mx0 m n1 n2 : row_mx 0 0 = 0 :> 'M_(m, n1 + n2).

Lemma col_mx0 m1 m2 n : col_mx 0 0 = 0 :> 'M_(m1 + m2, n).

Lemma block_mx0 m1 m2 n1 n2 : block_mx 0 0 0 0 = 0 :> 'M_(m1 + m2, n1 + n2).

Ltac split_mxE := apply/matrixP ⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Lemma opp_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  - row_mx A1 A2 = row_mx (- A1) (- A2).

Lemma opp_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  - col_mx A1 A2 = col_mx (- A1) (- A2).

Lemma opp_block_mx m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2)) :
  - block_mx Aul Aur Adl Adr = block_mx (- Aul) (- Aur) (- Adl) (- Adr).

Lemma add_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) B1 B2 :
  row_mx A1 A2 + row_mx B1 B2 = row_mx (A1 + B1) (A2 + B2).

Lemma add_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) B1 B2 :
  col_mx A1 A2 + col_mx B1 B2 = col_mx (A1 + B1) (A2 + B2).

Lemma add_block_mx m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2))
                   Bul Bur Bdl Bdr :
  let A := block_mx Aul Aur Adl Adr in let B := block_mx Bul Bur Bdl Bdr in
  A + B = block_mx (Aul + Bul) (Aur + Bur) (Adl + Bdl) (Adr + Bdr).
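The block lemmas say that the zmodule operations can be carried out block by block. Below is a plain-Python sketch, assuming matrices are lists of rows. The hypothetical stand-ins block_mx and addmx check add_block_mx on one 3 x 3 example with m1 = n1 = 1 and m2 = n2 = 2.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows.
def block_mx(Aul, Aur, Adl, Adr):
    # [Aul Aur] stacked on top of [Adl Adr]
    top = [ru + rr for ru, rr in zip(Aul, Aur)]
    bottom = [rl + rr for rl, rr in zip(Adl, Adr)]
    return top + bottom

def addmx(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

# m1 = 1, m2 = 2, n1 = 1, n2 = 2
Aul, Aur, Adl, Adr = [[1]], [[2, 3]], [[4], [5]], [[6, 7], [8, 9]]
Bul, Bur, Bdl, Bdr = [[1]], [[1, 1]], [[0], [2]], [[3, 0], [0, 3]]

A = block_mx(Aul, Aur, Adl, Adr)
B = block_mx(Bul, Bur, Bdl, Bdr)
assert A == [[1, 2, 3], [4, 6, 7], [5, 8, 9]]
# add_block_mx: addition can be done block by block
assert addmx(A, B) == block_mx(addmx(Aul, Bul), addmx(Aur, Bur),
                               addmx(Adl, Bdl), addmx(Adr, Bdr))
```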

Lemma row_mx_eq0 (m n1 n2 : nat) (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)):
  (row_mx A1 A2 == 0) = (A1 == 0) && (A2 == 0).

Lemma col_mx_eq0 (m1 m2 n : nat) (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)):
  (col_mx A1 A2 == 0) = (A1 == 0) && (A2 == 0).

Lemma block_mx_eq0 m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2)) :
  (block_mx Aul Aur Adl Adr == 0) =
  [&& Aul == 0, Aur == 0, Adl == 0 & Adr == 0].

Definition nz_row m n (A : 'M_(m, n)) :=
  oapp (fun i ⇒ row i A) 0 [pick i | row i A != 0].

Lemma nz_row_eq0 m n (A : 'M_(m, n)) : (nz_row A == 0) = (A == 0).
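nz_row A picks a nonzero row of A when one exists and otherwise returns the zero row vector, which is why nz_row_eq0 holds. In this plain-Python sketch (helper names hypothetical), the first nonzero row is returned. That choice is an assumption about [pick]; the lemma holds whichever nonzero row is chosen.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows.
def nz_row(A, n):
    # Some nonzero row of A as a 1 x n row vector, or the zero row vector.
    # Taking the first such row is an assumption about [pick].
    for row in A:
        if any(x != 0 for x in row):
            return [row]
    return [[0] * n]

def is_zero(M):
    return all(x == 0 for row in M for x in row)

A = [[0, 0], [0, 3], [1, 0]]
Z = [[0, 0], [0, 0]]
assert nz_row(A, 2) == [[0, 3]]
# nz_row_eq0: nz_row A is zero exactly when A is zero
for M in (A, Z):
    assert is_zero(nz_row(M, 2)) == is_zero(M)
```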

End MatrixZmodule.

Section FinZmodMatrix.
Variables (V : finZmodType) (m n : nat).
Local Notation MV := 'M[V]_(m, n).

Canonical matrix_finZmodType := Eval hnf in [finZmodType of MV].
Canonical matrix_baseFinGroupType :=
  Eval hnf in [baseFinGroupType of MV for +%R].
Canonical matrix_finGroupType := Eval hnf in [finGroupType of MV for +%R].
End FinZmodMatrix.

Parametricity over the additive structure.
Section MapZmodMatrix.

Variables (aR rR : zmodType) (f : {additive aR rR}) (m n : nat).
Implicit Type A : 'M[aR]_(m, n).

Lemma map_mx0 : 0^f = 0 :> 'M_(m, n).

Lemma map_mxN A : (- A)^f = - A^f.

Lemma map_mxD A B : (A + B)^f = A^f + B^f.

Lemma map_mx_sub A B : (A - B)^f = A^f - B^f.

Definition map_mx_sum := big_morph _ map_mxD map_mx0.

Canonical map_mx_additive := Additive map_mx_sub.
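Parametricity here means that mapping an additive function over the entries commutes with the zmodule operations. The plain-Python sketch below uses hypothetical helpers. Its second assertion shows that the statement fails for a map such as x + 1, which is not additive.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows.
def map_mx(f, A):
    # apply f to every entry, like A^f
    return [[f(x) for x in row] for row in A]

def submx(A, B):
    return [[a - b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def f(x):
    return 3 * x      # additive: f(x - y) = f(x) - f(y)

def g(x):
    return x + 1      # not additive

A = [[1, -2], [0, 4]]
B = [[5, 5], [-1, 2]]
assert map_mx(f, submx(A, B)) == submx(map_mx(f, A), map_mx(f, B))  # map_mx_sub
assert map_mx(g, submx(A, B)) != submx(map_mx(g, A), map_mx(g, B))
```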

End MapZmodMatrix.

Matrix ring module, graded ring, and ring structures

Section MatrixAlgebra.

Variable R : ringType.

Section RingModule.

The ring module/vector space structure

Variables m n : nat.
Implicit Types A B : 'M[R]_(m, n).

Fact scalemx_key : unit.
Definition scalemx x A := \matrix[scalemx_key]_(i, j) (x × A i j).

Basis
Fact delta_mx_key : unit.
Definition delta_mx i0 j0 : 'M[R]_(m, n) :=
  \matrix[delta_mx_key]_(i, j) ((i == i0) && (j == j0))%:R.


Lemma scale1mx A : 1 ×m: A = A.

Lemma scalemxDl A x y : (x + y) ×m: A = x ×m: A + y ×m: A.

Lemma scalemxDr x A B : x ×m: (A + B) = x ×m: A + x ×m: B.

Lemma scalemxA x y A : x ×m: (y ×m: A) = (x × y) ×m: A.

Definition matrix_lmodMixin :=
  LmodMixin scalemxA scale1mx scalemxDr scalemxDl.

Canonical matrix_lmodType :=
  Eval hnf in LmodType R 'M[R]_(m, n) matrix_lmodMixin.

Lemma scalemx_const a b : a *: const_mx b = const_mx (a × b).

Lemma matrix_sum_delta A :
  A = \sum_(i < m) \sum_(j < n) A i j *: delta_mx i j.
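delta_mx i0 j0 is the matrix unit with a single 1 at (i0, j0), and matrix_sum_delta expresses every matrix as a linear combination of these units. The plain-Python sketch below works over integers and is not MathComp code; delta_mx, scalemx and addmx are illustrative helpers. It rebuilds a 2 x 3 matrix from its entries and checks two of the module laws.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def delta_mx(m, n, i0, j0):
    # m x n matrix with a single 1 at (i0, j0)
    return [[1 if (i, j) == (i0, j0) else 0 for j in range(n)] for i in range(m)]

def scalemx(x, A):
    # x *: A, scaling every entry on the left
    return [[x * a for a in row] for row in A]

def addmx(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

A = [[2, -1, 0], [7, 3, 5]]
m, n = 2, 3

# matrix_sum_delta: A = \sum_i \sum_j A i j *: delta_mx i j
S = [[0] * n for _ in range(m)]
for i in range(m):
    for j in range(n):
        S = addmx(S, scalemx(A[i][j], delta_mx(m, n, i, j)))
assert S == A

# scalemxA and scale1mx on the same example
assert scalemx(2, scalemx(3, A)) == scalemx(6, A)
assert scalemx(1, A) == A
```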

End RingModule.

Section StructuralLinear.

Lemma swizzle_mx_is_scalable m n p q f g k :
  scalable (@swizzle_mx R m n p q f g k).
Canonical swizzle_mx_scalable m n p q f g k :=
  AddLinear (@swizzle_mx_is_scalable m n p q f g k).


Canonical trmx_linear m n := SwizzleLin (@trmx R m n).
Canonical row_linear m n i := SwizzleLin (@row R m n i).
Canonical col_linear m n j := SwizzleLin (@col R m n j).
Canonical row'_linear m n i := SwizzleLin (@row' R m n i).
Canonical col'_linear m n j := SwizzleLin (@col' R m n j).
Canonical row_perm_linear m n s := SwizzleLin (@row_perm R m n s).
Canonical col_perm_linear m n s := SwizzleLin (@col_perm R m n s).
Canonical xrow_linear m n i1 i2 := SwizzleLin (@xrow R m n i1 i2).
Canonical xcol_linear m n j1 j2 := SwizzleLin (@xcol R m n j1 j2).
Canonical lsubmx_linear m n1 n2 := SwizzleLin (@lsubmx R m n1 n2).
Canonical rsubmx_linear m n1 n2 := SwizzleLin (@rsubmx R m n1 n2).
Canonical usubmx_linear m1 m2 n := SwizzleLin (@usubmx R m1 m2 n).
Canonical dsubmx_linear m1 m2 n := SwizzleLin (@dsubmx R m1 m2 n).
Canonical vec_mx_linear m n := SwizzleLin (@vec_mx R m n).
Definition mxvec_is_linear m n := can2_linear (@vec_mxK R m n) mxvecK.
Canonical mxvec_linear m n := AddLinear (@mxvec_is_linear m n).

End StructuralLinear.

Lemma trmx_delta m n i j : (delta_mx i j)^T = delta_mx j i :> 'M[R]_(n, m).

Lemma row_sum_delta n (u : 'rV_n) : u = \sum_(j < n) u 0 j *: delta_mx 0 j.

Lemma delta_mx_lshift m n1 n2 i j :
  delta_mx i (lshift n2 j) = row_mx (delta_mx i j) 0 :> 'M_(m, n1 + n2).

Lemma delta_mx_rshift m n1 n2 i j :
  delta_mx i (rshift n1 j) = row_mx 0 (delta_mx i j) :> 'M_(m, n1 + n2).

Lemma delta_mx_ushift m1 m2 n i j :
  delta_mx (lshift m2 i) j = col_mx (delta_mx i j) 0 :> 'M_(m1 + m2, n).

Lemma delta_mx_dshift m1 m2 n i j :
  delta_mx (rshift m1 i) j = col_mx 0 (delta_mx i j) :> 'M_(m1 + m2, n).

Lemma vec_mx_delta m n i j :
  vec_mx (delta_mx 0 (mxvec_index i j)) = delta_mx i j :> 'M_(m, n).

Lemma mxvec_delta m n i j :
  mxvec (delta_mx i j) = delta_mx 0 (mxvec_index i j) :> 'rV_(m × n).

Ltac split_mxE := apply/matrixP ⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Lemma scale_row_mx m n1 n2 a (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  a *: row_mx A1 A2 = row_mx (a *: A1) (a *: A2).

Lemma scale_col_mx m1 m2 n a (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  a *: col_mx A1 A2 = col_mx (a *: A1) (a *: A2).

Lemma scale_block_mx m1 m2 n1 n2 a (Aul : 'M_(m1, n1)) (Aur : 'M_(m1, n2))
                                   (Adl : 'M_(m2, n1)) (Adr : 'M_(m2, n2)) :
  a *: block_mx Aul Aur Adl Adr
     = block_mx (a *: Aul) (a *: Aur) (a *: Adl) (a *: Adr).

Diagonal matrices

Fact diag_mx_key : unit.
Definition diag_mx n (d : 'rV[R]_n) :=
  \matrix[diag_mx_key]_(i, j) (d 0 i *+ (i == j)).

Lemma tr_diag_mx n (d : 'rV_n) : (diag_mx d)^T = diag_mx d.

Lemma diag_mx_is_linear n : linear (@diag_mx n).
Canonical diag_mx_additive n := Additive (@diag_mx_is_linear n).
Canonical diag_mx_linear n := Linear (@diag_mx_is_linear n).

Lemma diag_mx_sum_delta n (d : 'rV_n) :
  diag_mx d = \sum_i d 0 i *: delta_mx i i.
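Below is a plain-Python sketch of diag_mx, assuming the row vector d is stored as a flat list with d[i] standing for d 0 i. It checks tr_diag_mx and diag_mx_sum_delta on one 3 x 3 example. All helper names are hypothetical.

```python
# Plain-Python sketch, not MathComp: d is a flat list, d[i] = d 0 i.
def diag_mx(d):
    n = len(d)
    return [[d[i] if i == j else 0 for j in range(n)] for i in range(n)]

def delta_mx(n, i0, j0):
    return [[1 if (i, j) == (i0, j0) else 0 for j in range(n)] for i in range(n)]

def trmx(A):
    return [list(col) for col in zip(*A)]

d = [3, -1, 4]
D = diag_mx(d)
assert D == [[3, 0, 0], [0, -1, 0], [0, 0, 4]]
assert trmx(D) == D                        # tr_diag_mx

# diag_mx_sum_delta: diag_mx d = \sum_i d 0 i *: delta_mx i i
S = [[0] * 3 for _ in range(3)]
for i in range(3):
    E = delta_mx(3, i, i)
    S = [[s + d[i] * e for s, e in zip(rs, rE)] for rs, rE in zip(S, E)]
assert S == D
```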

Scalar matrix: a diagonal matrix with a constant on the diagonal.
Section ScalarMx.

Variable n : nat.

Fact scalar_mx_key : unit.
Definition scalar_mx x : 'M[R]_n :=
  \matrix[scalar_mx_key]_(i , j) (x *+ (i == j)).
Notation "x %:M" := (scalar_mx x) : ring_scope.

Lemma diag_const_mx a : diag_mx (const_mx a) = a%:M :> 'M_n.

Lemma tr_scalar_mx a : (a%:M)^T = a%:M.

Lemma trmx1 : (1%:M)^T = 1%:M.

Lemma scalar_mx_is_additive : additive scalar_mx.
Canonical scalar_mx_additive := Additive scalar_mx_is_additive.

Lemma scale_scalar_mx a1 a2 : a1 *: a2%:M = (a1 × a2)%:M :> 'M_n.

Lemma scalemx1 a : a *: 1%:M = a%:M.

Lemma scalar_mx_sum_delta a : a%:M = \sum_i a *: delta_mx i i.

Lemma mx1_sum_delta : 1%:M = \sum_i delta_mx i i.

Lemma row1 i : row i 1%:M = delta_mx 0 i.

Definition is_scalar_mx (A : 'M[R]_n) :=
  if insub 0%N is Some i then A == (A i i)%:M else true.

Lemma is_scalar_mxP A : reflect (∃ a, A = a%:M) (is_scalar_mx A).

Lemma scalar_mx_is_scalar a : is_scalar_mx a%:M.

Lemma mx0_is_scalar : is_scalar_mx 0.
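In is_scalar_mx, insub 0%N returns Some only when 0 < n. A nonempty matrix is therefore compared against the scalar matrix built from its (0, 0) entry, while the 0 x 0 matrix is scalar by convention. A plain-Python sketch of that case split (not MathComp code; helper names hypothetical):

```python
# Plain-Python sketch, not MathComp: an n x n matrix is a list of n rows.
def scalar_mx(n, x):
    # x%:M: x on the diagonal, 0 elsewhere
    return [[x if i == j else 0 for j in range(n)] for i in range(n)]

def is_scalar_mx(A):
    # insub 0%N is Some i only when 0 < n; then compare A with (A 0 0)%:M,
    # otherwise the empty matrix is trivially scalar.
    n = len(A)
    if n == 0:
        return True
    return A == scalar_mx(n, A[0][0])

assert is_scalar_mx(scalar_mx(3, 5))           # scalar_mx_is_scalar
assert is_scalar_mx(scalar_mx(3, 0))           # mx0_is_scalar
assert not is_scalar_mx([[1, 0], [0, 2]])
assert is_scalar_mx([])                        # the 0 x 0 case
```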

End ScalarMx.

Notation "x %:M" := (scalar_mx _ x) : ring_scope.

Lemma mx11_scalar (A : 'M_1) : A = (A 0 0)%:M.

Lemma scalar_mx_block n1 n2 a : a%:M = block_mx a%:M 0 0 a%:M :> 'M_(n1 + n2).

Matrix multiplication using bigops.
Fact mulmx_key : unit.
Definition mulmx {m n p} (A : 'M_(m, n)) (B : 'M_(n, p)) : 'M[R]_(m, p) :=
  \matrix[mulmx_key]_(i, k) \sum_j (A i j × B j k).


Lemma mulmxA m n p q (A : 'M_(m, n)) (B : 'M_(n, p)) (C : 'M_(p, q)) :
  A ×m (B ×m C) = A ×m B ×m C.
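mulmx is the usual sum-of-products definition. The plain-Python sketch below (helper names hypothetical, integers only) implements it with zip over rows and columns and checks mulmxA, mul1mx and mulmx1 on small rectangular matrices.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def mulmx(A, B):
    # (A *m B) i k = \sum_j A i j * B j k
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def scalar_mx(n, x):
    return [[x if i == j else 0 for j in range(n)] for i in range(n)]

A = [[1, 2, 3], [4, 5, 6]]    # 2 x 3
B = [[1, 0], [0, 1], [1, 1]]  # 3 x 2
C = [[2, 1], [0, 3]]          # 2 x 2

assert mulmx(A, B) == [[4, 5], [10, 11]]
assert mulmx(A, mulmx(B, C)) == mulmx(mulmx(A, B), C)   # mulmxA
assert mulmx(scalar_mx(2, 1), A) == A                   # mul1mx
assert mulmx(A, scalar_mx(3, 1)) == A                   # mulmx1
```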

Lemma mul0mx m n p (A : 'M_(n, p)) : 0 ×m A = 0 :> 'M_(m, p).

Lemma mulmx0 m n p (A : 'M_(m, n)) : A ×m 0 = 0 :> 'M_(m, p).

Lemma mulmxN m n p (A : 'M_(m, n)) (B : 'M_(n, p)) : A ×m (- B) = - (A ×m B).

Lemma mulNmx m n p (A : 'M_(m, n)) (B : 'M_(n, p)) : - A ×m B = - (A ×m B).

Lemma mulmxDl m n p (A1 A2 : 'M_(m, n)) (B : 'M_(n, p)) :
  (A1 + A2) ×m B = A1 ×m B + A2 ×m B.

Lemma mulmxDr m n p (A : 'M_(m, n)) (B1 B2 : 'M_(n, p)) :
  A ×m (B1 + B2) = A ×m B1 + A ×m B2.

Lemma mulmxBl m n p (A1 A2 : 'M_(m, n)) (B : 'M_(n, p)) :
  (A1 - A2) ×m B = A1 ×m B - A2 ×m B.

Lemma mulmxBr m n p (A : 'M_(m, n)) (B1 B2 : 'M_(n, p)) :
  A ×m (B1 - B2) = A ×m B1 - A ×m B2.

Lemma mulmx_suml m n p (A : 'M_(n, p)) I r P (B_ : I → 'M_(m, n)) :
   (\sum_(i <- r | P i) B_ i) ×m A = \sum_(i <- r | P i) B_ i ×m A.

Lemma mulmx_sumr m n p (A : 'M_(m, n)) I r P (B_ : I → 'M_(n, p)) :
   A ×m (\sum_(i <- r | P i) B_ i) = \sum_(i <- r | P i) A ×m B_ i.

Lemma scalemxAl m n p a (A : 'M_(m, n)) (B : 'M_(n, p)) :
  a *: (A ×m B) = (a *: A) ×m B.
Right scaling associativity, a *: (A ×m B) = A ×m (a *: B), requires a commutative ring.

Lemma rowE m n i (A : 'M_(m, n)) : row i A = delta_mx 0 i ×m A.

Lemma row_mul m n p (i : 'I_m) A (B : 'M_(n, p)) :
  row i (A ×m B) = row i A ×m B.

Lemma mulmx_sum_row m n (u : 'rV_m) (A : 'M_(m, n)) :
  u ×m A = \sum_i u 0 i *: row i A.
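rowE says that selecting row i is left multiplication by the unit row vector delta_mx 0 i, and mulmx_sum_row reads u ×m A as a weighted sum of the rows of A. A plain-Python check on one example, with hypothetical helpers:

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def mulmx(A, B):
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

A = [[1, 2, 3], [4, 5, 6]]
u = [2, -1]

# rowE: row i A = delta_mx 0 i *m A (here i = 1)
assert mulmx([[0, 1]], A) == [A[1]]

# mulmx_sum_row: u *m A = \sum_i u 0 i *: row i A
acc = [0, 0, 0]
for i, ui in enumerate(u):
    acc = [s + ui * a for s, a in zip(acc, A[i])]
assert mulmx([u], A) == [acc] == [[-2, -1, 0]]
```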

Lemma mul_delta_mx_cond m n p (j1 j2 : 'I_n) (i1 : 'I_m) (k2 : 'I_p) :
  delta_mx i1 j1 ×m delta_mx j2 k2 = delta_mx i1 k2 *+ (j1 == j2).

Lemma mul_delta_mx m n p (j : 'I_n) (i : 'I_m) (k : 'I_p) :
  delta_mx i j ×m delta_mx j k = delta_mx i k.

Lemma mul_delta_mx_0 m n p (j1 j2 : 'I_n) (i1 : 'I_m) (k2 : 'I_p) :
  j1 != j2 → delta_mx i1 j1 ×m delta_mx j2 k2 = 0.
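mul_delta_mx_cond is the multiplication table of matrix units: the product is a unit when the inner indices agree and zero otherwise. This plain-Python sketch (hypothetical helpers, not MathComp) checks it exhaustively for m = 2, n = 3, p = 2.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def delta_mx(m, n, i0, j0):
    return [[1 if (i, j) == (i0, j0) else 0 for j in range(n)] for i in range(m)]

def mulmx(A, B):
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def zero(m, n):
    return [[0] * n for _ in range(m)]

m, n, p = 2, 3, 2
for i1 in range(m):
    for j1 in range(n):
        for j2 in range(n):
            for k2 in range(p):
                lhs = mulmx(delta_mx(m, n, i1, j1), delta_mx(n, p, j2, k2))
                # mul_delta_mx_cond: delta_mx i1 k2 *+ (j1 == j2)
                rhs = delta_mx(m, p, i1, k2) if j1 == j2 else zero(m, p)
                assert lhs == rhs
```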

Lemma mul_diag_mx m n d (A : 'M_(m, n)) :
  diag_mx d ×m A = \matrix_(i, j) (d 0 i × A i j).

Lemma mul_mx_diag m n (A : 'M_(m, n)) d :
  A ×m diag_mx d = \matrix_(i, j) (A i j × d 0 j).

Lemma mulmx_diag n (d e : 'rV_n) :
  diag_mx d ×m diag_mx e = diag_mx (\row_j (d 0 j × e 0 j)).

Lemma mul_scalar_mx m n a (A : 'M_(m, n)) : a%:M ×m A = a *: A.

Lemma scalar_mxM n a b : (a × b)%:M = a%:M ×m b%:M :> 'M_n.

Lemma mul1mx m n (A : 'M_(m, n)) : 1%:M ×m A = A.

Lemma mulmx1 m n (A : 'M_(m, n)) : A ×m 1%:M = A.

Lemma mul_col_perm m n p s (A : 'M_(m, n)) (B : 'M_(n, p)) :
  col_perm s A ×m B = A ×m row_perm s^-1 B.

Lemma mul_row_perm m n p s (A : 'M_(m, n)) (B : 'M_(n, p)) :
  A ×m row_perm s B = col_perm s^-1 A ×m B.

Lemma mul_xcol m n p j1 j2 (A : 'M_(m, n)) (B : 'M_(n, p)) :
  xcol j1 j2 A ×m B = A ×m xrow j1 j2 B.

Permutation matrix

Definition perm_mx n s : 'M_n := row_perm s 1%:M.

Definition tperm_mx n i1 i2 : 'M_n := perm_mx (tperm i1 i2).

Lemma col_permE m n s (A : 'M_(m, n)) : col_perm s A = A ×m perm_mx s^-1.

Lemma row_permE m n s (A : 'M_(m, n)) : row_perm s A = perm_mx s ×m A.

Lemma xcolE m n j1 j2 (A : 'M_(m, n)) : xcol j1 j2 A = A ×m tperm_mx j1 j2.

Lemma xrowE m n i1 i2 (A : 'M_(m, n)) : xrow i1 i2 A = tperm_mx i1 i2 ×m A.

Lemma tr_perm_mx n (s : 'S_n) : (perm_mx s)^T = perm_mx s^-1.

Lemma tr_tperm_mx n i1 i2 : (tperm_mx i1 i2)^T = tperm_mx i1 i2 :> 'M_n.

Lemma perm_mx1 n : perm_mx 1 = 1%:M :> 'M_n.

Lemma perm_mxM n (s t : 'S_n) : perm_mx (s × t) = perm_mx s ×m perm_mx t.
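perm_mx s is the identity with its rows permuted, so perm_mx s i j is 1 exactly when s i = j. The plain-Python sketch below assumes permutations are lists [s(0), ..., s(n-1)] and that products follow the convention (s * t) x = t (s x), assumed here to match MathComp's permM. That convention is what lets perm_mxM hold without reversing the factors. The helper names are hypothetical.

```python
# Plain-Python sketch, not MathComp: a permutation s is [s(0), ..., s(n-1)].
def mulmx(A, B):
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def perm_mx(s):
    # row_perm s 1%:M: entry (i, j) is 1 exactly when s(i) == j
    n = len(s)
    return [[1 if s[i] == j else 0 for j in range(n)] for i in range(n)]

def row_perm(s, A):
    # (row_perm s A) i j = A (s i) j
    return [A[s[i]] for i in range(len(s))]

def perm_mul(s, t):
    # assumed convention: (s * t) x = t (s x)
    return [t[s[x]] for x in range(len(s))]

def perm_inv(s):
    r = [0] * len(s)
    for i, x in enumerate(s):
        r[x] = i
    return r

def trmx(A):
    return [list(c) for c in zip(*A)]

s, t = [1, 2, 0], [0, 2, 1]
A = [[1, 2], [3, 4], [5, 6]]
assert row_perm(s, A) == mulmx(perm_mx(s), A)                     # row_permE
assert perm_mx(perm_mul(s, t)) == mulmx(perm_mx(s), perm_mx(t))   # perm_mxM
assert trmx(perm_mx(s)) == perm_mx(perm_inv(s))                   # tr_perm_mx
```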

Definition is_perm_mx n (A : 'M_n) := [∃ s, A == perm_mx s].

Lemma is_perm_mxP n (A : 'M_n) :
  reflect (∃ s, A = perm_mx s) (is_perm_mx A).

Lemma perm_mx_is_perm n (s : 'S_n) : is_perm_mx (perm_mx s).

Lemma is_perm_mx1 n : is_perm_mx (1%:M : 'M_n).

Lemma is_perm_mxMl n (A B : 'M_n) :
  is_perm_mx A → is_perm_mx (A ×m B) = is_perm_mx B.

Lemma is_perm_mx_tr n (A : 'M_n) : is_perm_mx A^T = is_perm_mx A.

Lemma is_perm_mxMr n (A B : 'M_n) :
  is_perm_mx B → is_perm_mx (A ×m B) = is_perm_mx A.

Partial identity matrix (used in rank decomposition).

Fact pid_mx_key : unit.
Definition pid_mx {m n} r : 'M[R]_(m, n) :=
  \matrix[pid_mx_key]_(i, j) ((i == j :> nat) && (i < r))%:R.

Lemma pid_mx_0 m n : pid_mx 0 = 0 :> 'M_(m, n).

Lemma pid_mx_1 r : pid_mx r = 1%:M :> 'M_r.

Lemma pid_mx_row n r : pid_mx r = row_mx 1%:M 0 :> 'M_(r, r + n).

Lemma pid_mx_col m r : pid_mx r = col_mx 1%:M 0 :> 'M_(r + m, r).

Lemma pid_mx_block m n r : pid_mx r = block_mx 1%:M 0 0 0 :> 'M_(r + m, r + n).

Lemma tr_pid_mx m n r : (pid_mx r)^T = pid_mx r :> 'M_(n, m).

Lemma pid_mx_minv m n r : pid_mx (minn m r) = pid_mx r :> 'M_(m, n).

Lemma pid_mx_minh m n r : pid_mx (minn n r) = pid_mx r :> 'M_(m, n).

Lemma mul_pid_mx m n p q r :
  (pid_mx q : 'M_(m, n)) ×m (pid_mx r : 'M_(n, p)) = pid_mx (minn n (minn q r)).

Lemma pid_mx_id m n p r :
  r ≤ n → (pid_mx r : 'M_(m, n)) ×m (pid_mx r : 'M_(n, p)) = pid_mx r.

Definition copid_mx {n} r : 'M_n := 1%:M - pid_mx r.

Lemma mul_copid_mx_pid m n r :
  r ≤ m → copid_mx r ×m pid_mx r = 0 :> 'M_(m, n).

Lemma mul_pid_mx_copid m n r :
  r ≤ n → pid_mx r ×m copid_mx r = 0 :> 'M_(m, n).

Lemma copid_mx_id n r :
  r ≤ n → copid_mx r ×m copid_mx r = copid_mx r :> 'M_n.
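pid_mx r is a rectangular identity truncated to its first r diagonal entries, and copid_mx r is its complement in 1%:M. The plain-Python sketch below (hypothetical helpers) checks mul_pid_mx exhaustively for small sizes. It then checks copid_mx_id and mul_pid_mx_copid under their side condition r ≤ n.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def mulmx(A, B):
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def pid_mx(m, n, r):
    # 1 at (i, i) for i < r, 0 elsewhere
    return [[1 if i == j and i < r else 0 for j in range(n)] for i in range(m)]

def copid_mx(n, r):
    # 1%:M - pid_mx r
    P = pid_mx(n, n, r)
    return [[(1 if i == j else 0) - P[i][j] for j in range(n)] for i in range(n)]

rng = range(1, 4)
for m in rng:
    for n in rng:
        for p in rng:
            for q in range(5):
                for r in range(5):
                    # mul_pid_mx
                    assert mulmx(pid_mx(m, n, q), pid_mx(n, p, r)) == pid_mx(m, p, min(n, q, r))

for n in rng:
    for r in range(n + 1):           # side condition r <= n
        C = copid_mx(n, r)
        assert mulmx(C, C) == C      # copid_mx_id
        for m in rng:
            # mul_pid_mx_copid; pid_mx 0 is the zero matrix
            assert mulmx(pid_mx(m, n, r), C) == pid_mx(m, n, 0)
```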

Block products; we cover all 1 x 2, 2 x 1, and 2 x 2 block products.
Lemma mul_mx_row m n p1 p2 (A : 'M_(m, n)) (Bl : 'M_(n, p1)) (Br : 'M_(n, p2)) :
  A ×m row_mx Bl Br = row_mx (A ×m Bl) (A ×m Br).

Lemma mul_col_mx m1 m2 n p (Au : 'M_(m1, n)) (Ad : 'M_(m2, n)) (B : 'M_(n, p)) :
  col_mx Au Ad ×m B = col_mx (Au ×m B) (Ad ×m B).

Lemma mul_row_col m n1 n2 p (Al : 'M_(m, n1)) (Ar : 'M_(m, n2))
                            (Bu : 'M_(n1, p)) (Bd : 'M_(n2, p)) :
  row_mx Al Ar ×m col_mx Bu Bd = Al ×m Bu + Ar ×m Bd.

Lemma mul_col_row m1 m2 n p1 p2 (Au : 'M_(m1, n)) (Ad : 'M_(m2, n))
                                (Bl : 'M_(n, p1)) (Br : 'M_(n, p2)) :
  col_mx Au Ad ×m row_mx Bl Br
     = block_mx (Au ×m Bl) (Au ×m Br) (Ad ×m Bl) (Ad ×m Br).

Lemma mul_row_block m n1 n2 p1 p2 (Al : 'M_(m, n1)) (Ar : 'M_(m, n2))
                                  (Bul : 'M_(n1, p1)) (Bur : 'M_(n1, p2))
                                  (Bdl : 'M_(n2, p1)) (Bdr : 'M_(n2, p2)) :
  row_mx Al Ar ×m block_mx Bul Bur Bdl Bdr
   = row_mx (Al ×m Bul + Ar ×m Bdl) (Al ×m Bur + Ar ×m Bdr).

Lemma mul_block_col m1 m2 n1 n2 p (Aul : 'M_(m1, n1)) (Aur : 'M_(m1, n2))
                                  (Adl : 'M_(m2, n1)) (Adr : 'M_(m2, n2))
                                  (Bu : 'M_(n1, p)) (Bd : 'M_(n2, p)) :
  block_mx Aul Aur Adl Adr ×m col_mx Bu Bd
   = col_mx (Aul ×m Bu + Aur ×m Bd) (Adl ×m Bu + Adr ×m Bd).

Lemma mulmx_block m1 m2 n1 n2 p1 p2 (Aul : 'M_(m1, n1)) (Aur : 'M_(m1, n2))
                                    (Adl : 'M_(m2, n1)) (Adr : 'M_(m2, n2))
                                    (Bul : 'M_(n1, p1)) (Bur : 'M_(n1, p2))
                                    (Bdl : 'M_(n2, p1)) (Bdr : 'M_(n2, p2)) :
  block_mx Aul Aur Adl Adr ×m block_mx Bul Bur Bdl Bdr
    = block_mx (Aul ×m Bul + Aur ×m Bdl) (Aul ×m Bur + Aur ×m Bdr)
               (Adl ×m Bul + Adr ×m Bdl) (Adl ×m Bur + Adr ×m Bdr).
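mulmx_block is the usual 2 x 2 block multiplication rule, valid whenever the inner block sizes n1 and n2 agree. Below is a plain-Python check on one example with m1 = 1, m2 = 2, n1 = 2, n2 = 1, p1 = 1, p2 = 2. The helpers block_mx, mulmx and addmx are hypothetical, not MathComp functions.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def mulmx(A, B):
    cols = list(zip(*B))
    return [[sum(a * b for a, b in zip(row, col)) for col in cols] for row in A]

def addmx(A, B):
    return [[a + b for a, b in zip(ra, rb)] for ra, rb in zip(A, B)]

def block_mx(Aul, Aur, Adl, Adr):
    return [ru + rr for ru, rr in zip(Aul, Aur)] + [rl + rr for rl, rr in zip(Adl, Adr)]

# m1 = 1, m2 = 2, n1 = 2, n2 = 1, p1 = 1, p2 = 2
Aul, Aur = [[1, 2]], [[3]]
Adl, Adr = [[0, 1], [4, -1]], [[2], [5]]
Bul, Bur = [[1], [0]], [[2, 0], [1, 1]]
Bdl, Bdr = [[3]], [[-1, 4]]

lhs = mulmx(block_mx(Aul, Aur, Adl, Adr), block_mx(Bul, Bur, Bdl, Bdr))
rhs = block_mx(addmx(mulmx(Aul, Bul), mulmx(Aur, Bdl)),
               addmx(mulmx(Aul, Bur), mulmx(Aur, Bdr)),
               addmx(mulmx(Adl, Bul), mulmx(Adr, Bdl)),
               addmx(mulmx(Adl, Bur), mulmx(Adr, Bdr)))
assert lhs == rhs   # mulmx_block
```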

Correspondence between matrices and linear functions on row vectors.
Section LinRowVector.

Variables m n : nat.

Fact lin1_mx_key : unit.
Definition lin1_mx (f : 'rV[R]_m → 'rV[R]_n) :=
  \matrix[lin1_mx_key]_(i, j) f (delta_mx 0 i) 0 j.

Variable f : {linear 'rV[R]_m → 'rV[R]_n}.

Lemma mul_rV_lin1 u : u ×m lin1_mx f = f u.
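lin1_mx f stores in row i the image by f of the i-th unit row vector. Multiplying a row vector on the right by lin1_mx f therefore reproduces f by linearity (mul_rV_lin1). A plain-Python sketch with a hypothetical linear map f from 'rV_2 to 'rV_3:

```python
# Plain-Python sketch, not MathComp: row vectors are flat lists.
def lin1_mx(f, m):
    # row i is f applied to the i-th unit row vector (delta_mx 0 i)
    return [f([1 if k == i else 0 for k in range(m)]) for i in range(m)]

def vecmul(u, M):
    # (u *m M) 0 j = \sum_i u 0 i * M i j
    return [sum(u[i] * M[i][j] for i in range(len(u))) for j in range(len(M[0]))]

def f(u):
    # a hypothetical linear map from 'rV_2 to 'rV_3
    return [u[0] + 2 * u[1], 3 * u[1], -u[0]]

M = lin1_mx(f, 2)
assert M == [[1, 0, -1], [2, 3, 0]]
for u in ([1, 0], [4, -7], [2, 5]):
    assert vecmul(u, M) == f(u)      # mul_rV_lin1
```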

End LinRowVector.

Correspondence between matrices and linear functions on matrices.
Section LinMatrix.

Variables m1 n1 m2 n2 : nat.

Definition lin_mx (f : 'M[R]_(m1, n1) → 'M[R]_(m2, n2)) :=
  lin1_mx (mxvec \o f \o vec_mx).

Variable f : {linear 'M[R]_(m1, n1) → 'M[R]_(m2, n2)}.

Lemma mul_rV_lin u : u ×m lin_mx f = mxvec (f (vec_mx u)).

Lemma mul_vec_lin A : mxvec A ×m lin_mx f = mxvec (f A).

Lemma mx_rV_lin u : vec_mx (u ×m lin_mx f) = f (vec_mx u).

Lemma mx_vec_lin A : vec_mx (mxvec A ×m lin_mx f) = f A.
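lin_mx reduces maps on matrices to maps on row vectors by flattening with mxvec and unflattening with vec_mx. The plain-Python sketch below assumes a row-major flattening (the lemmas hold for any fixed bijection of indices) and uses transposition of 2 x 3 matrices as the linear map. Every helper name is hypothetical.

```python
# Plain-Python sketch, not MathComp: matrices are lists of rows, vectors flat lists.
def mxvec(A):
    # assumption: row-major flattening
    return [x for row in A for x in row]

def vec_mx(v, m, n):
    return [v[i * n:(i + 1) * n] for i in range(m)]

def vecmul(u, M):
    return [sum(u[i] * M[i][j] for i in range(len(u))) for j in range(len(M[0]))]

def lin_mx(f, m1, n1):
    # lin1_mx (mxvec \o f \o vec_mx): row k is the image of the k-th unit vector
    size = m1 * n1
    units = ([1 if t == k else 0 for t in range(size)] for k in range(size))
    return [mxvec(f(vec_mx(e, m1, n1))) for e in units]

def tr(A):
    return [list(c) for c in zip(*A)]

L = lin_mx(tr, 2, 3)
A = [[1, 2, 3], [4, 5, 6]]
assert vecmul(mxvec(A), L) == mxvec(tr(A))            # mul_vec_lin
assert vec_mx(vecmul(mxvec(A), L), 3, 2) == tr(A)     # mx_vec_lin
```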

End LinMatrix.

Canonical mulmx_additive m n p A := Additive (@mulmxBr m n p A).

Section Mulmxr.

Variables m n p : nat.
Implicit Type A : 'M[R]_(m, n).
Implicit Type B : 'M[R]_(n, p).

Definition mulmxr_head t B A := let: tt := t in A ×m B.

Definition lin_mulmxr B := lin_mx (mulmxr B).

Lemma mulmxr_is_linear B : linear (mulmxr B).
Canonical mulmxr_additive B := Additive (mulmxr_is_linear B).
Canonical mulmxr_linear B := Linear (mulmxr_is_linear B).

Lemma lin_mulmxr_is_linear : linear lin_mulmxr.
Canonical lin_mulmxr_additive := Additive lin_mulmxr_is_linear.
Canonical lin_mulmxr_linear := Linear lin_mulmxr_is_linear.

End Mulmxr.

The trace.
Section Trace.

Variable n : nat.

Definition mxtrace (A : 'M[R]_n) := \sum_i A i i.

Lemma mxtrace_tr A : \tr A^T = \tr A.

Lemma mxtrace_is_scalar : scalar mxtrace.
Canonical mxtrace_additive := Additive mxtrace_is_scalar.
Canonical mxtrace_linear := Linear mxtrace_is_scalar.

Lemma mxtrace0 : \tr 0 = 0.
Lemma mxtraceD A B : \tr (A + B) = \tr A + \tr B.
Lemma mxtraceZ a A : \tr (a *: A) = a × \tr A.

Lemma mxtrace_diag D : \tr (diag_mx D) = \sum_j D 0 j.

Lemma mxtrace_scalar a : \tr a%:M = a *+ n.

Lemma mxtrace1 : \tr 1%:M = n%:R.

End Trace.

Lemma trace_mx11 (A : 'M_1) : \tr A = A 0 0.

Lemma mxtrace_block n1 n2 (Aul : 'M_n1) Aur Adl (Adr : 'M_n2) :
  \tr (block_mx Aul Aur Adl Adr) = \tr Aul + \tr Adr.
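Below is a plain-Python check of three trace lemmas on concrete integer matrices (not MathComp code; the helpers are hypothetical). The off-diagonal blocks of a block matrix do not contribute to its trace, transposition preserves the trace, and \tr a%:M = a *+ n.

```python
# Plain-Python sketch, not MathComp: a matrix is a list of rows of ints.
def mxtrace(A):
    # \tr A = \sum_i A i i
    return sum(A[i][i] for i in range(len(A)))

def block_mx(Aul, Aur, Adl, Adr):
    return [ru + rr for ru, rr in zip(Aul, Aur)] + [rl + rr for rl, rr in zip(Adl, Adr)]

def scalar_mx(n, x):
    return [[x if i == j else 0 for j in range(n)] for i in range(n)]

def trmx(A):
    return [list(c) for c in zip(*A)]

Aul, Aur = [[1, 2], [3, 4]], [[9], [9]]
Adl, Adr = [[7, 7]], [[5]]
B = block_mx(Aul, Aur, Adl, Adr)
assert mxtrace(B) == mxtrace(Aul) + mxtrace(Adr) == 10    # mxtrace_block
assert mxtrace(trmx(B)) == mxtrace(B)                      # mxtrace_tr
assert mxtrace(scalar_mx(4, 3)) == 3 * 4                   # mxtrace_scalar
```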

The matrix ring structure requires a structural condition (dimension of the form n.+1) to satisfy the nontriviality condition we have imposed.
Section MatrixRing.

Variable n' : nat.
Local Notation n := n'.+1.

Lemma matrix_nonzero1 : 1%:M != 0 :> 'M_n.

Definition matrix_ringMixin :=
  RingMixin (@mulmxA n n n n) (@mul1mx n n) (@mulmx1 n n)
            (@mulmxDl n n n) (@mulmxDr n n n) matrix_nonzero1.

Canonical matrix_ringType := Eval hnf in RingType 'M[R]_n matrix_ringMixin.
Canonical matrix_lAlgType := Eval hnf in LalgType R 'M[R]_n (@scalemxAl n n n).

Lemma mulmxE : mulmx = *%R.
Lemma idmxE : 1%:M = 1 :> 'M_n.

Lemma scalar_mx_is_multiplicative : multiplicative (@scalar_mx n).
Canonical scalar_mx_rmorphism := AddRMorphism scalar_mx_is_multiplicative.

End MatrixRing.

Section LiftPerm.

Block expression of a lifted permutation matrix, for the Cormen LUP.

Variable n : nat.

These could be in zmodp, but that would introduce a dependency on perm.

Definition lift0_perm s : 'S_n.+1 := lift_perm 0 0 s.

Lemma lift0_perm0 s : lift0_perm s 0 = 0.

Lemma lift0_perm_lift s k' :
  lift0_perm s (lift 0 k') = lift (0 : 'I_n.+1) (s k').

Lemma lift0_permK s : cancel (lift0_perm s) (lift0_perm s^-1).
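lift0_perm s extends a permutation s of 'I_n to 'I_n.+1 by fixing 0 and shifting everything else up by one, since lift 0 k = k + 1. A plain-Python sketch with permutations as lists (hypothetical helpers), checking lift0_perm0, lift0_perm_lift and lift0_permK on one example:

```python
# Plain-Python sketch, not MathComp: a permutation s is [s(0), ..., s(n-1)].
def lift0_perm(s):
    # fixes 0 and sends k + 1 to s(k) + 1, since lift 0 k = k + 1
    return [0] + [s[k] + 1 for k in range(len(s))]

def perm_inv(s):
    r = [0] * len(s)
    for i, x in enumerate(s):
        r[x] = i
    return r

s = [2, 0, 1]
ls = lift0_perm(s)
assert ls == [0, 3, 1, 2]
assert ls[0] == 0                                         # lift0_perm0
assert all(ls[k + 1] == s[k] + 1 for k in range(3))       # lift0_perm_lift
assert all(lift0_perm(perm_inv(s))[ls[x]] == x for x in range(4))  # lift0_permK
```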

Lemma lift0_perm_eq0 s i : (lift0_perm s i == 0) =