Library mathcomp.algebra.matrix

(* (c) Copyright 2006-2016 Microsoft Corporation and Inria.                  
 Distributed under the terms of CeCILL-B.                                  *)

From HB Require Import structures.
From mathcomp Require Import ssreflect ssrbool ssrfun eqtype ssrnat seq choice.
From mathcomp Require Import fintype finfun finset fingroup perm order div.
From mathcomp Require Import prime binomial ssralg countalg finalg zmodp bigop.

Basic concrete linear algebra: definition of the type of matrices, and all basic matrix operations including determinant, trace and support for block decomposition. Matrices are represented by a row-major list of their coefficients, but this implementation is hidden by three levels of wrappers (Matrix/Finfun/Tuple), so the matrix type should be treated as abstract and handled using only the operations described below:

  'M[R]_(m, n), 'M_(m, n) == the type of m rows by n columns matrices with coefficients in R; the [R] is optional and is usually omitted.
  'M[R]_n, 'M_n == the type of n x n square matrices.
  'rV[R]_n, 'rV_n == the type of 1 x n row vectors.
  'cV[R]_n, 'cV_n == the type of n x 1 column vectors.
  \matrix_(i < m, j < n) Expr(i, j) == the m x n matrix with general coefficient Expr(i, j), with i : 'I_m and j : 'I_n. The < m bound can be omitted if it is equal to n, though usually both bounds are omitted as they can be inferred from the context.
  \row_(j < n) Expr(j), \col_(i < m) Expr(i) == the row / column vectors with general term Expr; the parentheses can be omitted along with the bound.
  \matrix_(i < m) RowExpr(i) == the m x n matrix with row i given by RowExpr(i) : 'rV_n.
  A i j == the coefficient of matrix A : 'M_(m, n) in column j of row i, where i : 'I_m and j : 'I_n (via the coercion fun_of_matrix : matrix >-> Funclass).
  const_mx a == the constant matrix whose entries are all a (dimensions should be determined by context).
  map_mx f A == the pointwise image of A by f, i.e., the matrix Af congruent to A with Af i j = f (A i j) for all i and j.
  map2_mx f A B == the pointwise image of A and B by f, i.e., the matrix ABf congruent to A with ABf i j = f (A i j) (B i j) for all i and j.
  A^T == the matrix transpose of A.
  row i A == the i'th row of A (this is a row vector).
  col j A == the j'th column of A (a column vector).
  row' i A == A with the i'th row spliced out.
  col' j A == A with the j'th column spliced out.
  xrow i1 i2 A == A with rows i1 and i2 interchanged.
  xcol j1 j2 A == A with columns j1 and j2 interchanged.
  row_perm s A == A : 'M_(m, n) with rows permuted by s : 'S_m.
  col_perm s A == A : 'M_(m, n) with columns permuted by s : 'S_n.
  row_mx Al Ar == the row block matrix <Al Ar> obtained by concatenating two matrices Al and Ar of the same height.
  col_mx Au Ad == the column block matrix (Au and Ad must have the same width)
        / Au \
        \ Ad /
  block_mx Aul Aur Adl Adr == the block matrix
        / Aul Aur \
        \ Adl Adr /
  \mxblock_(i < m, j < n) B i j == the block matrix of type 'M_(\sum_i p_ i, \sum_j q_ j)
        / (B 0 0) ⋯ (B 0 j) ⋯ (B 0 n) \
        |   ...       ...       ...   |
        | (B i 0) ⋯ (B i j) ⋯ (B i n) |
        |   ...       ...       ...   |
        \ (B m 0) ⋯ (B m j) ⋯ (B m n) /
    where each block (B i j) has type 'M_(p_ i, q_ j).
  \mxdiag_(i < n) B i == the block square matrix of type 'M_(\sum_i p_ i)
        / (B 0)              0   \
        |       ...   ...        |
        |   0        (B i)   0   |
        |       ...   ...        |
        \   0              (B n) /
    where each block (B i) has type 'M_(p_ i).
  \mxrow_(j < n) B j == the block matrix of type 'M_(m, \sum_j q_ j)
        < (B 0) ... (B n) >
    where each block (B j) has type 'M_(m, q_ j).
  \mxcol_(i < m) B i == the block matrix of type 'M_(\sum_i p_ i, n)
        / (B 0) \
        |  ...  |
        \ (B m) /
    where each block (B i) has type 'M_(p_ i, n).
  [l|r]submx A == the left/right submatrices of a row block matrix A. Note that the type of A, 'M_(m, n1 + n2), indicates how A should be decomposed.
  [u|d]submx A == the up/down submatrices of a column block matrix A.
  [u|d][l|r]submx A == the upper left, etc. submatrices of a block matrix A.
  submxblock A i j == the block submatrix of type 'M_(p_ i, q_ j) of A. The type of A, 'M_(\sum_i p_ i, \sum_j q_ j), indicates how A should be decomposed. There is no analogue for mxdiag since one can use submxblock A i i to extract a diagonal block.
  submxrow A j == the submatrix of type 'M_(m, q_ j) of A. The type of A, 'M_(m, \sum_j q_ j), indicates how A should be decomposed.
  submxcol A i == the submatrix of type 'M_(p_ i, n) of A. The type of A, 'M_(\sum_i p_ i, n), indicates how A should be decomposed.
  mxsub f g A == generic reordered submatrix, given by functions f and g which specify which subset of rows and columns to take and how to reorder them, e.g. picking f and g to be increasing yields traditional submatrices.
              := \matrix_(i, j) A (f i) (g j)
  rowsub f A := mxsub f id A
  colsub g A := mxsub id g A
  castmx eq_mn A == A : 'M_(m, n) cast to 'M_(m', n') using the equation pair eq_mn : (m = m') * (n = n'). This is the usual workaround for the syntactic limitations of dependent types in Coq, and can be used to introduce a block decomposition. It simplifies to A when eq_mn is the pair (erefl m, erefl n) (using rewrite /castmx /=).
  conform_mx B A == A if A and B have the same dimensions, else B.
  mxvec A == a row vector of width m * n holding all the entries of the m x n matrix A.
  mxvec_index i j == the index of A i j in mxvec A.
  vec_mx v == the inverse of mxvec, reshaping a vector of width m * n back into an m x n rectangular matrix.

In 'M[R]_(m, n), R can be any type, but 'M[R]_(m, n) inherits the eqType, choiceType, countType, finType, zmodType structures of R; 'M[R]_(m, n) also has a natural lmodType R structure when R has a ringType structure. Because the type of matrices specifies their dimension, only non-trivial square matrices (of type 'M[R]_n.+1) can inherit the ring structure of R; indeed they then have an algebra structure (lalgType R, or algType R if R is a comRingType, or even unitAlgType if R is a comUnitRingType). We thus provide separate syntax for the general matrix multiplication, and other operations for matrices over a ringType R:

  A *m B == the matrix product of A and B; the width of A must be equal to the height of B.
  a%:M == the scalar matrix with a's on the main diagonal; in particular 1%:M denotes the identity matrix, and is equal to 1%R when n is of the form n'.+1 (i.e., n >= 1).
  is_scalar_mx A <=> A is a scalar matrix (A = a%:M for some a).
  diag_mx d == the diagonal matrix whose main diagonal is d : 'rV_n.
  is_diag_mx A <=> A is a diagonal matrix: forall i j, i != j -> A i j = 0
  is_trig_mx A <=> A is a triangular matrix: forall i j, i < j -> A i j = 0
  delta_mx i j == the matrix with a 1 in row i, column j and 0 elsewhere.
  pid_mx r == the partial identity matrix with 1s only on the r first coefficients of the main diagonal; the dimensions of pid_mx r are determined by the context, and pid_mx r can be rectangular.
  copid_mx r == the complement to 1%:M of pid_mx r: a square diagonal matrix with 1s on all but the first r coefficients on its main diagonal.
  perm_mx s == the n x n permutation matrix for s : 'S_n.
  tperm_mx i1 i2 == the permutation matrix that exchanges i1 i2 : 'I_n.
  is_perm_mx A == A is a permutation matrix.
  lift0_mx A == the 1 + n square matrix block_mx 1 0 0 A when A : 'M_n.
  \tr A == the trace of a square matrix A.
  \det A == the determinant of A, using the Leibniz formula.
  cofactor i j A == the i, j cofactor of A (the signed i, j minor of A).
  \adj A == the adjugate matrix of A (\adj A i j = cofactor j i A).
  A \in unitmx == A is invertible (R must be a comUnitRingType).
  invmx A == the inverse matrix of A if A \in unitmx, otherwise A.
  A \is a mxOver S == the matrix A has its coefficients in S.
  comm_mx A B := A *m B = B *m A
  comm_mxb A B := A *m B == B *m A
  all_comm_mx As fs := all2rel comm_mxb fs

The following operations provide a correspondence between linear functions and matrices:

  lin1_mx f == the m x n matrix that emulates via right product a (linear) function f : 'rV_m -> 'rV_n on ROW VECTORS.
  lin_mx f == the (m1 * n1) x (m2 * n2) matrix that emulates, via the right multiplication on the mxvec encodings, a linear function f : 'M_(m1, n1) -> 'M_(m2, n2).
  lin_mul_row u := lin1_mx (mulmx u \o vec_mx) (applies a row-encoded function to the row-vector u).
  mulmx A == partially applied matrix multiplication (mulmx A B is displayed as A *m B), with, for A : 'M_(m, n), a canonical {linear 'M_(n, p) -> 'M_(m, p)} structure.
  mulmxr A == self-simplifying right-hand matrix multiplication, i.e., mulmxr A B simplifies to B *m A, with, for A : 'M_(n, p), a canonical {linear 'M_(m, n) -> 'M_(m, p)} structure.
  lin_mulmx A := lin_mx (mulmx A).
  lin_mulmxr A := lin_mx (mulmxr A).

We also extend any finType structure of R to 'M[R]_(m, n), and define:

  {'GL_n[R]} == the finGroupType of units of 'M[R]_n.-1.+1.
  'GL_n[R] == the general linear group of all matrices in {'GL_n[R]}.
  'GL_n(p) == 'GL_n['F_p], the general linear group of a prime field.
  GLval u == the coercion of u : {'GL_n[R]} to a matrix.

In addition to the lemmas relevant to these definitions, this file also proves several classic results, including:
  • The determinant is an alternating multilinear form.
  • The Laplace determinant expansion formulas: expand_det_[row|col].
  • The Cramer rule: mul_mx_adj & mul_adj_mx.
  Vandermonde m a == the 'M[R]_(m, n) Vandermonde matrix, given a : 'rV_n
        /         1           ...            1             \
        |      (a 0 0)        ...      (a 0 (n - 1))       |
        |   (a 0 0 ^+ 2)      ...    (a 0 (n - 1) ^+ 2)    |
        |        ...          ...           ...            |
        \ (a 0 0 ^+ (m - 1))  ... (a 0 (n - 1) ^+ (m - 1)) /
      := \matrix_(i < m, j < n) a 0 j ^+ i.

Finally, as an example of the use of block products, we program and prove the correctness of a classical linear algebra algorithm:

  cormen_lup A == the triangular decomposition (L, U, P) of a nontrivial square matrix A into a lower triangular matrix L with 1s on the main diagonal, an upper triangular matrix U, and a permutation matrix P, such that P * A = L * U.

This is an example only; we use a different, more precise algorithm to develop the theory of matrix ranks and row spaces in mxalgebra.v.

Set Implicit Arguments.

Import GroupScope.
Import GRing.Theory.
Local Open Scope ring_scope.

Reserved Notation "''M_' n" (at level 8, n at level 2, format "''M_' n").
Reserved Notation "''rV_' n" (at level 8, n at level 2, format "''rV_' n").
Reserved Notation "''cV_' n" (at level 8, n at level 2, format "''cV_' n").
Reserved Notation "''M_' ( n )" (at level 8). (* only parsing *)
Reserved Notation "''M_' ( m , n )" (at level 8, format "''M_' ( m , n )").
Reserved Notation "''M[' R ]_ n" (at level 8, n at level 2). (* only parsing *)
Reserved Notation "''rV[' R ]_ n" (at level 8, n at level 2). (* only parsing *)
Reserved Notation "''cV[' R ]_ n" (at level 8, n at level 2). (* only parsing *)
Reserved Notation "''M[' R ]_ ( n )" (at level 8). (* only parsing *)
Reserved Notation "''M[' R ]_ ( m , n )" (at level 8). (* only parsing *)

Reserved Notation "\matrix_ i E"
  (at level 36, E at level 36, i at level 2,
   format "\matrix_ i E").
Reserved Notation "\matrix_ ( i < n ) E"
  (at level 36, E at level 36, i, n at level 50). (* only parsing *)
Reserved Notation "\matrix_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\matrix_ ( i , j ) E").
Reserved Notation "\matrix[ k ]_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\matrix[ k ]_ ( i , j ) E").
Reserved Notation "\matrix_ ( i < m , j < n ) E"
  (at level 36, E at level 36, i, m, j, n at level 50). (* only parsing *)
Reserved Notation "\matrix_ ( i , j < n ) E"
  (at level 36, E at level 36, i, j, n at level 50). (* only parsing *)
Reserved Notation "\row_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\row_ j E").
Reserved Notation "\row_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50). (* only parsing *)
Reserved Notation "\col_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\col_ j E").
Reserved Notation "\col_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50). (* only parsing *)
Reserved Notation "\mxblock_ ( i , j ) E"
  (at level 36, E at level 36, i, j at level 50,
   format "\mxblock_ ( i , j ) E").
Reserved Notation "\mxblock_ ( i < m , j < n ) E"
  (at level 36, E at level 36, i, m, j, n at level 50). (* only parsing *)
Reserved Notation "\mxblock_ ( i , j < n ) E"
  (at level 36, E at level 36, i, j, n at level 50). (* only parsing *)
Reserved Notation "\mxrow_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\mxrow_ j E").
Reserved Notation "\mxrow_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50). (* only parsing *)
Reserved Notation "\mxcol_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\mxcol_ j E").
Reserved Notation "\mxcol_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50). (* only parsing *)
Reserved Notation "\mxdiag_ j E"
  (at level 36, E at level 36, j at level 2,
   format "\mxdiag_ j E").
Reserved Notation "\mxdiag_ ( j < n ) E"
  (at level 36, E at level 36, j, n at level 50). (* only parsing *)

Reserved Notation "x %:M" (at level 8, format "x %:M").
Reserved Notation "A *m B" (at level 40, left associativity, format "A *m B").
Reserved Notation "A ^T" (at level 8, format "A ^T").
Reserved Notation "\tr A" (at level 10, A at level 8, format "\tr A").
Reserved Notation "\det A" (at level 10, A at level 8, format "\det A").
Reserved Notation "\adj A" (at level 10, A at level 8, format "\adj A").

Local Notation simp := (Monoid.Theory.simpm, oppr0).

Type Definition

Section MatrixDef.

Variable R : Type.
Variables m n : nat.

Basic linear algebra (matrices). We use dependent types (ordinals) for the indices so that ranges are mostly inferred automatically.

Variant matrix : predArgType := Matrix of {ffun 'I_m × 'I_n → R}.

Definition mx_val A := let: Matrix g := A in g.


Definition fun_of_matrix A (i : 'I_m) (j : 'I_n) := mx_val A (i, j).

Coercion fun_of_matrix : matrix >-> Funclass.

End MatrixDef.

Fact matrix_key : unit.

Canonical matrix_unlockable := Unlockable matrix_of_fun.unlock.

Section MatrixDef2.

Variable R : Type.
Variables m n : nat.
Implicit Type F : 'I_m → 'I_n → R.

Lemma mxE k F : matrix_of_fun k F =2 F.

Lemma matrixP (A B : matrix R m n) : A =2 B ↔ A = B.

Lemma eq_mx k F1 F2 : (F1 =2 F2) → matrix_of_fun k F1 = matrix_of_fun k F2.

End MatrixDef2.
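The extensionality lemma matrixP is the usual entry point for proving that two matrices are equal: it reduces the goal to a coefficient-wise equality, after which mxE exposes the general term. A minimal standalone sketch, assuming a MathComp installation that provides all_ssreflect and all_algebra (the lemma name matrix_eta is hypothetical and not part of the library):

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical lemma: rebuilding a matrix from its own coefficients
   gives the matrix back. *)
Lemma matrix_eta (R : Type) m n (A : 'M[R]_(m, n)) :
  \matrix_(i, j) A i j = A.
Proof.
(* matrixP turns the goal into: forall i j, (\matrix_(i, j) A i j) i j = A i j *)
apply/matrixP=> i j.
(* mxE evaluates the coefficient of a \matrix_ expression *)
by rewrite mxE.
Qed.
```

The same two-step pattern (apply/matrixP, then rewrite with mxE) recurs in most structural lemmas of this file.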

Arguments eq_mx {R m n k} [F1] F2 eq_F12.

Bind Scope ring_scope with matrix.

Notation "''M[' R ]_ ( m , n )" := (matrix R m n) (only parsing): type_scope.
Notation "''rV[' R ]_ n" := 'M[R]_(1, n) (only parsing) : type_scope.
Notation "''cV[' R ]_ n" := 'M[R]_(n, 1) (only parsing) : type_scope.
Notation "''M[' R ]_ n" := 'M[R]_(n, n) (only parsing) : type_scope.
Notation "''M[' R ]_ ( n )" := 'M[R]_n (only parsing) : type_scope.
Notation "''M_' ( m , n )" := 'M[_]_(m, n) : type_scope.
Notation "''rV_' n" := 'M_(1, n) : type_scope.
Notation "''cV_' n" := 'M_(n, 1) : type_scope.
Notation "''M_' n" := 'M_(n, n) : type_scope.
Notation "''M_' ( n )" := 'M_n (only parsing) : type_scope.

Notation "\matrix[ k ]_ ( i , j ) E" := (matrix_of_fun k (fun i j ⇒ E)) :
   ring_scope.

Notation "\matrix_ ( i < m , j < n ) E" :=
  (@matrix_of_fun _ m n matrix_key (fun i j ⇒ E)) (only parsing) : ring_scope.

Notation "\matrix_ ( i , j < n ) E" :=
  (\matrix_(i < n, j < n) E) (only parsing) : ring_scope.

Notation "\matrix_ ( i , j ) E" := (\matrix_(i < _, j < _) E) : ring_scope.

Notation "\matrix_ ( i < m ) E" :=
  (\matrix_(i < m, j < _) @fun_of_matrix _ 1 _ E 0 j)
  (only parsing) : ring_scope.
Notation "\matrix_ i E" := (\matrix_(i < _) E) : ring_scope.

Notation "\col_ ( i < n ) E" := (@matrix_of_fun _ n 1 matrix_key (fun i _ ⇒ E))
  (only parsing) : ring_scope.
Notation "\col_ i E" := (\col_(i < _) E) : ring_scope.

Notation "\row_ ( j < n ) E" := (@matrix_of_fun _ 1 n matrix_key (fun _ j ⇒ E))
  (only parsing) : ring_scope.
Notation "\row_ j E" := (\row_(j < _) E) : ring_scope.
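As an illustration of the notations above, the following standalone sketch builds a small matrix and a row vector from their general terms and reads a coefficient back through the function coercion. The names ex_mx, ex_mxE and ex_row are hypothetical, and the imports assume a MathComp installation that provides all_ssreflect and all_algebra.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical example: the 2 x 3 matrix over nat whose (i, j) entry is i + j.
   The bounds are inferred from the declared type. *)
Definition ex_mx : 'M[nat]_(2, 3) := \matrix_(i, j) (i + j)%N.

(* A matrix applied to two ordinals denotes its coefficient; mxE computes it. *)
Lemma ex_mxE i j : ex_mx i j = (i + j)%N.
Proof. by rewrite /ex_mx mxE. Qed.

(* A row vector given by its general term, with an explicit bound. *)
Definition ex_row : 'rV[nat]_4 := \row_(j < 4) j.+1.
```

Here nat is used as the coefficient type only to keep the sketch free of ring structures; any type R works for these constructions.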


Lemma card_mx (F : finType) m n : (#|{: 'M[F]_(m, n)}| = #|F| ^ (m × n))%N.
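For a concrete reading of card_mx, the sketch below instantiates it at fixed dimensions; the lemma name card_mx_2x3 is hypothetical and the imports are an assumption about the local MathComp installation.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical instance: there are #|F| ^ 6 matrices of shape 2 x 3 over a
   finite type F, since 2 * 3 computes to 6. *)
Lemma card_mx_2x3 (F : finType) : #|{: 'M[F]_(2, 3)}| = (#|F| ^ 6)%N.
Proof. by rewrite card_mx. Qed.
```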

Matrix structural operations (transpose, permutation, blocks)

Section MatrixStructural.

Variable R : Type.

Constant matrix
Fact const_mx_key : unit.
Definition const_mx m n a : 'M[R]_(m, n) := \matrix[const_mx_key]_(i, j) a.
Arguments const_mx {m n}.

Section FixedDim.
Definitions and properties for which we can work with fixed dimensions.

Variables m n : nat.
Implicit Type A : 'M[R]_(m, n).

Reshape a matrix, to accommodate the block functions for instance.
Definition castmx m' n' (eq_mn : (m = m') × (n = n')) A : 'M_(m', n') :=
  let: erefl in _ = m' := eq_mn.1 return 'M_(m', n') in
  let: erefl in _ = n' := eq_mn.2 return 'M_(m, n') in A.

Definition conform_mx m' n' B A :=
  match m =P m', n =P n' with
  | ReflectT eq_m, ReflectT eq_n ⇒ castmx (eq_m, eq_n) A
  | _, _ ⇒ B
  end.
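To make the role of castmx concrete, here is a standalone sketch that casts away a trailing + 0 in the width, which is not a definitional equality on nat, and then checks that casting back is the identity. The names drop_zero and drop_zeroK are hypothetical; the lemmas castmx_comp and castmx_id are the ones stated later in this file, and the imports assume a MathComp installation providing all_ssreflect and all_algebra.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical helper: n + 0 and n are only propositionally equal (addn0),
   so the two matrix types differ and a cast is needed. *)
Definition drop_zero {R : Type} {m n : nat}
    (A : 'M[R]_(m, (n + 0)%N)) : 'M[R]_(m, n) :=
  castmx (erefl m, addn0 n) A.

(* Casting back along the symmetric equations recovers the original matrix:
   the two casts are merged by castmx_comp, and a cast along equations of
   the form k = k is erased by castmx_id. *)
Lemma drop_zeroK (R : Type) m n (A : 'M[R]_(m, (n + 0)%N)) :
  castmx (erefl m, esym (addn0 n)) (drop_zero A) = A.
Proof. by rewrite /drop_zero castmx_comp castmx_id. Qed.
```

The proof never inspects the equality proofs themselves, which is why the rewrite with castmx_id goes through for any pair of reflexive-shaped equations.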

Transpose a matrix
Fact trmx_key : unit.
Definition trmx A := \matrix[trmx_key]_(i, j) A j i.

Permute a matrix vertically (rows) or horizontally (columns)
Fact row_perm_key : unit.
Definition row_perm (s : 'S_m) A := \matrix[row_perm_key]_(i, j) A (s i) j.
Fact col_perm_key : unit.
Definition col_perm (s : 'S_n) A := \matrix[col_perm_key]_(i, j) A i (s j).

Exchange two rows/columns of a matrix
Definition xrow i1 i2 := row_perm (tperm i1 i2).
Definition xcol j1 j2 := col_perm (tperm j1 j2).
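Since xrow is defined through row_perm and a transposition, its coefficients follow from mxE and the permutation lemma tpermL. A minimal sketch, with the hypothetical lemma name xrow_left and imports that assume a standard MathComp installation:

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical lemma: after exchanging rows i1 and i2, row i1 holds what
   used to be row i2. *)
Lemma xrow_left (R : Type) m n (i1 i2 : 'I_m) (A : 'M[R]_(m, n)) j :
  xrow i1 i2 A i1 j = A i2 j.
Proof.
(* unfold xrow to row_perm, compute the coefficient, then evaluate
   tperm i1 i2 i1 = i2 *)
by rewrite /xrow mxE tpermL.
Qed.
```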

Row/Column sub matrices of a matrix
Definition row i0 A := \row_j A i0 j.
Definition col j0 A := \col_i A i j0.

Removing a row/column from a matrix
Definition row' i0 A := \matrix_(i, j) A (lift i0 i) j.
Definition col' j0 A := \matrix_(i, j) A i (lift j0 j).
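The typing of row' is worth seeing on a concrete instance: splicing a row out of an m x n matrix yields a matrix with m.-1 rows. The sketch below is standalone; drop_first_row and row'_coef are hypothetical names, and the imports assume a MathComp installation providing all_ssreflect and all_algebra.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical helper: deleting the first row of a 3 x 4 matrix. The result
   type 'M_(3.-1, 4) is convertible to the declared 'M_(2, 4). *)
Definition drop_first_row {R : Type} (A : 'M[R]_(3, 4)) : 'M[R]_(2, 4) :=
  row' ord0 A.

(* Coefficients of row' i0 A skip index i0 by means of lift. *)
Lemma row'_coef (R : Type) m n (i0 : 'I_m) (A : 'M[R]_(m, n)) i j :
  row' i0 A i j = A (lift i0 i) j.
Proof. by rewrite mxE. Qed.
```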

Reindexing/subindexing a matrix
Definition mxsub m' n' f g A := \matrix_(i < m', j < n') A (f i) (g j).
Local Notation colsub g := (mxsub id g).
Local Notation rowsub f := (mxsub f id).
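As a sketch of how mxsub reorders indices, the following standalone example flips the rows of a matrix with rev_ord (from fintype) while leaving the columns unchanged. The lemma name flip_rowsE is hypothetical, and the imports assume a MathComp installation providing all_ssreflect and all_algebra.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical lemma: row i of the reindexed matrix is row (rev_ord i) of A,
   i.e. the rows appear in reverse order; columns are kept by id. *)
Lemma flip_rowsE (R : Type) m n (A : 'M[R]_(m, n)) i j :
  mxsub (@rev_ord m) id A i j = A (rev_ord i) j.
Proof. by rewrite mxE. Qed.
```

Because f and g are arbitrary functions between ordinal types, the same definition covers selection, duplication and permutation of rows and columns.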

Lemma castmx_const m' n' (eq_mn : (m = m') × (n = n')) a :
  castmx eq_mn (const_mx a) = const_mx a.

Lemma trmx_const a : trmx (const_mx a) = const_mx a.

Lemma row_perm_const s a : row_perm s (const_mx a) = const_mx a.

Lemma col_perm_const s a : col_perm s (const_mx a) = const_mx a.

Lemma xrow_const i1 i2 a : xrow i1 i2 (const_mx a) = const_mx a.

Lemma xcol_const j1 j2 a : xcol j1 j2 (const_mx a) = const_mx a.

Lemma rowP (u v : 'rV[R]_n) : u 0 =1 v 0 ↔ u = v.

Lemma rowK u_ i0 : row i0 (\matrix_i u_ i) = u_ i0.

Lemma row_matrixP A B : (∀ i, row i A = row i B) ↔ A = B.

Lemma colP (u v : 'cV[R]_m) : u^~ 0 =1 v^~ 0 ↔ u = v.

Lemma row_const i0 a : row i0 (const_mx a) = const_mx a.

Lemma col_const j0 a : col j0 (const_mx a) = const_mx a.

Lemma row'_const i0 a : row' i0 (const_mx a) = const_mx a.

Lemma col'_const j0 a : col' j0 (const_mx a) = const_mx a.

Lemma col_perm1 A : col_perm 1 A = A.

Lemma row_perm1 A : row_perm 1 A = A.

Lemma col_permM s t A : col_perm (s × t) A = col_perm s (col_perm t A).

Lemma row_permM s t A : row_perm (s × t) A = row_perm s (row_perm t A).

Lemma col_row_permC s t A :
  col_perm s (row_perm t A) = row_perm t (col_perm s A).

Lemma rowEsub i : row i = rowsub (fun⇒ i).
Lemma colEsub j : col j = colsub (fun⇒ j).

Lemma row'Esub i : row' i = rowsub (lift i).
Lemma col'Esub j : col' j = colsub (lift j).

Lemma row_permEsub s : row_perm s = rowsub s.

Lemma col_permEsub s : col_perm s = colsub s.

Lemma xrowEsub i1 i2 : xrow i1 i2 = rowsub (tperm i1 i2).

Lemma xcolEsub j1 j2 : xcol j1 j2 = colsub (tperm j1 j2).

Lemma mxsub_id : mxsub id id =1 id.

Lemma eq_mxsub m' n' f f' g g' : f =1 f' → g =1 g' →
  @mxsub m' n' f g =1 mxsub f' g'.

Lemma eq_rowsub m' (f f' : 'I_m' → 'I_m) : f =1 f' → rowsub f =1 rowsub f'.

Lemma eq_colsub n' (g g' : 'I_n' → 'I_n) : g =1 g' → colsub g =1 colsub g'.

Lemma mxsub_eq_id f g : f =1 id → g =1 id → mxsub f g =1 id.

Lemma mxsub_eq_colsub n' f g : f =1 id → @mxsub _ n' f g =1 colsub g.

Lemma mxsub_eq_rowsub m' f g : g =1 id → @mxsub m' _ f g =1 rowsub f.

Lemma mxsub_ffunl m' n' f g : @mxsub m' n' (finfun f) g =1 mxsub f g.

Lemma mxsub_ffunr m' n' f g : @mxsub m' n' f (finfun g) =1 mxsub f g.

Lemma mxsub_ffun m' n' f g : @mxsub m' n' (finfun f) (finfun g) =1 mxsub f g.

Lemma mxsub_const m' n' f g a : @mxsub m' n' f g (const_mx a) = const_mx a.

End FixedDim.

Local Notation colsub g := (mxsub id g).
Local Notation rowsub f := (mxsub f id).
Local Notation "A ^T" := (trmx A) : ring_scope.

Lemma castmx_id m n erefl_mn (A : 'M_(m, n)) : castmx erefl_mn A = A.

Lemma castmx_comp m1 n1 m2 n2 m3 n3 (eq_m1 : m1 = m2) (eq_n1 : n1 = n2)
                                    (eq_m2 : m2 = m3) (eq_n2 : n2 = n3) A :
  castmx (eq_m2, eq_n2) (castmx (eq_m1, eq_n1) A)
    = castmx (etrans eq_m1 eq_m2, etrans eq_n1 eq_n2) A.

Lemma castmxK m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) :
  cancel (castmx (eq_m, eq_n)) (castmx (esym eq_m, esym eq_n)).

Lemma castmxKV m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) :
  cancel (castmx (esym eq_m, esym eq_n)) (castmx (eq_m, eq_n)).

This can be used to reverse an equation that involves a cast.
Lemma castmx_sym m1 n1 m2 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) A1 A2 :
  A1 = castmx (eq_m, eq_n) A2 → A2 = castmx (esym eq_m, esym eq_n) A1.

Lemma eq_castmx m1 n1 m2 n2 (eq_mn eq_mn' : (m1 = m2) × (n1 = n2)) :
  castmx eq_mn =1 castmx eq_mn'.

Lemma castmxE m1 n1 m2 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A i j :
  castmx eq_mn A i j =
     A (cast_ord (esym eq_mn.1) i) (cast_ord (esym eq_mn.2) j).

Lemma conform_mx_id m n (B A : 'M_(m, n)) : conform_mx B A = A.

Lemma nonconform_mx m m' n n' (B : 'M_(m', n')) (A : 'M_(m, n)) :
  (m != m') || (n != n') → conform_mx B A = B.

Lemma conform_castmx m1 n1 m2 n2 m3 n3
                     (e_mn : (m2 = m3) × (n2 = n3)) (B : 'M_(m1, n1)) A :
  conform_mx B (castmx e_mn A) = conform_mx B A.

Lemma trmxK m n : cancel (@trmx m n) (@trmx n m).

Lemma trmx_inj m n : injective (@trmx m n).

Lemma trmx_cast m1 n1 m2 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A :
  (castmx eq_mn A)^T = castmx (eq_mn.2, eq_mn.1) A^T.

Lemma trmx_conform m' n' m n (B : 'M_(m', n')) (A : 'M_(m, n)) :
  (conform_mx B A)^T = conform_mx B^T A^T.

Lemma tr_row_perm m n s (A : 'M_(m, n)) : (row_perm s A)^T = col_perm s A^T.

Lemma tr_col_perm m n s (A : 'M_(m, n)) : (col_perm s A)^T = row_perm s A^T.

Lemma tr_xrow m n i1 i2 (A : 'M_(m, n)) : (xrow i1 i2 A)^T = xcol i1 i2 A^T.

Lemma tr_xcol m n j1 j2 (A : 'M_(m, n)) : (xcol j1 j2 A)^T = xrow j1 j2 A^T.

Lemma row_id n i (V : 'rV_n) : row i V = V.

Lemma col_id n j (V : 'cV_n) : col j V = V.

Lemma row_eq m1 m2 n i1 i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row i1 A1 = row i2 A2 → A1 i1 =1 A2 i2.

Lemma col_eq m n1 n2 j1 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col j1 A1 = col j2 A2 → A1^~ j1 =1 A2^~ j2.

Lemma row'_eq m n i0 (A B : 'M_(m, n)) :
  row' i0 A = row' i0 B → {in predC1 i0, A =2 B}.

Lemma col'_eq m n j0 (A B : 'M_(m, n)) :
  col' j0 A = col' j0 B → ∀ i, {in predC1 j0, A i =1 B i}.

Lemma tr_row m n i0 (A : 'M_(m, n)) : (row i0 A)^T = col i0 A^T.

Lemma tr_row' m n i0 (A : 'M_(m, n)) : (row' i0 A)^T = col' i0 A^T.

Lemma tr_col m n j0 (A : 'M_(m, n)) : (col j0 A)^T = row j0 A^T.

Lemma tr_col' m n j0 (A : 'M_(m, n)) : (col' j0 A)^T = row' j0 A^T.

Lemma mxsub_comp m1 m2 m3 n1 n2 n3
  (f : 'I_m2 → 'I_m1) (f' : 'I_m3 → 'I_m2)
  (g : 'I_n2 → 'I_n1) (g' : 'I_n3 → 'I_n2) (A : 'M_(m1, n1)) :
  mxsub (f \o f') (g \o g') A = mxsub f' g' (mxsub f g A).

Lemma rowsub_comp m1 m2 m3 n
  (f : 'I_m2 → 'I_m1) (f' : 'I_m3 → 'I_m2) (A : 'M_(m1, n)) :
  rowsub (f \o f') A = rowsub f' (rowsub f A).

Lemma colsub_comp m n n2 n3
  (g : 'I_n2 → 'I_n) (g' : 'I_n3 → 'I_n2) (A : 'M_(m, n)) :
  colsub (g \o g') A = colsub g' (colsub g A).

Lemma mxsubrc m1 m2 n n2 f g (A : 'M_(m1, n)) :
  mxsub f g A = rowsub f (colsub g A) :> 'M_(m2, n2).

Lemma mxsubcr m1 m2 n n2 f g (A : 'M_(m1, n)) :
  mxsub f g A = colsub g (rowsub f A) :> 'M_(m2, n2).

Lemma rowsub_cast m1 m2 n (eq_m : m1 = m2) (A : 'M_(m2, n)) :
  rowsub (cast_ord eq_m) A = castmx (esym eq_m, erefl) A.

Lemma colsub_cast m n1 n2 (eq_n : n1 = n2) (A : 'M_(m, n2)) :
  colsub (cast_ord eq_n) A = castmx (erefl, esym eq_n) A.

Lemma mxsub_cast m1 m2 n1 n2 (eq_m : m1 = m2) (eq_n : n1 = n2) A :
  mxsub (cast_ord eq_m) (cast_ord eq_n) A = castmx (esym eq_m, esym eq_n) A.

Lemma castmxEsub m1 m2 n1 n2 (eq_mn : (m1 = m2) × (n1 = n2)) A :
  castmx eq_mn A = mxsub (cast_ord (esym eq_mn.1)) (cast_ord (esym eq_mn.2)) A.

Lemma trmx_mxsub m1 m2 n1 n2 f g (A : 'M_(m1, n1)) :
  (mxsub f g A)^T = mxsub g f A^T :> 'M_(n2, m2).

Lemma row_mxsub m1 m2 n1 n2
    (f : 'I_m2 → 'I_m1) (g : 'I_n2 → 'I_n1) (A : 'M_(m1, n1)) i :
  row i (mxsub f g A) = row (f i) (colsub g A).

Lemma col_mxsub m1 m2 n1 n2
    (f : 'I_m2 → 'I_m1) (g : 'I_n2 → 'I_n1) (A : 'M_(m1, n1)) i :
 col i (mxsub f g A) = col (g i) (rowsub f A).

Lemma row_rowsub m1 m2 n (f : 'I_m2 → 'I_m1) (A : 'M_(m1, n)) i :
  row i (rowsub f A) = row (f i) A.

Lemma col_colsub m n1 n2 (g : 'I_n2 → 'I_n1) (A : 'M_(m, n1)) i :
  col i (colsub g A) = col (g i) A.

Ltac split_mxE := apply/matrixP⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Section CutPaste.

Variables m m1 m2 n n1 n2 : nat.

Concatenating two matrices, in either direction.

Fact row_mx_key : unit.
Definition row_mx (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) : 'M[R]_(m, n1 + n2) :=
  \matrix[row_mx_key]_(i, j)
     match split j with inl j1 ⇒ A1 i j1 | inr j2 ⇒ A2 i j2 end.

Fact col_mx_key : unit.
Definition col_mx (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) : 'M[R]_(m1 + m2, n) :=
  \matrix[col_mx_key]_(i, j)
     match split i with inl i1 ⇒ A1 i1 j | inr i2 ⇒ A2 i2 j end.

Left/Right | Up/Down submatrices of a rows | columns matrix. The shape of the (dependent) width parameters of the type of A determines which submatrix is selected.

Fact lsubmx_key : unit.
Definition lsubmx (A : 'M[R]_(m, n1 + n2)) :=
  \matrix[lsubmx_key]_(i, j) A i (lshift n2 j).

Fact rsubmx_key : unit.
Definition rsubmx (A : 'M[R]_(m, n1 + n2)) :=
  \matrix[rsubmx_key]_(i, j) A i (rshift n1 j).

Fact usubmx_key : unit.
Definition usubmx (A : 'M[R]_(m1 + m2, n)) :=
  \matrix[usubmx_key]_(i, j) A (lshift m2 i) j.

Fact dsubmx_key : unit.
Definition dsubmx (A : 'M[R]_(m1 + m2, n)) :=
  \matrix[dsubmx_key]_(i, j) A (rshift m1 i) j.

Lemma row_mxEl A1 A2 i j : row_mx A1 A2 i (lshift n2 j) = A1 i j.

Lemma row_mxKl A1 A2 : lsubmx (row_mx A1 A2) = A1.

Lemma row_mxEr A1 A2 i j : row_mx A1 A2 i (rshift n1 j) = A2 i j.

Lemma row_mxKr A1 A2 : rsubmx (row_mx A1 A2) = A2.

Lemma hsubmxK A : row_mx (lsubmx A) (rsubmx A) = A.
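The cancellation lemmas above say that row_mx and the pair (lsubmx, rsubmx) are mutually inverse. A minimal standalone sketch of both directions follows; the lemma names glue_then_split and split_then_glue are hypothetical, and the imports assume a MathComp installation providing all_ssreflect and all_algebra.

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Hypothetical lemma: gluing two matrices side by side and cutting the
   result recovers each part. *)
Lemma glue_then_split (R : Type) m n1 n2
    (A1 : 'M[R]_(m, n1)) (A2 : 'M[R]_(m, n2)) :
  lsubmx (row_mx A1 A2) = A1 /\ rsubmx (row_mx A1 A2) = A2.
Proof. by rewrite row_mxKl row_mxKr. Qed.

(* Hypothetical lemma: the width n1 + n2 in the type of A tells lsubmx and
   rsubmx where to cut, and gluing the parts back gives A. *)
Lemma split_then_glue (R : Type) m n1 n2 (A : 'M[R]_(m, n1 + n2)) :
  row_mx (lsubmx A) (rsubmx A) = A.
Proof. exact: hsubmxK. Qed.
```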

Lemma col_mxEu A1 A2 i j : col_mx A1 A2 (lshift m2 i) j = A1 i j.

Lemma col_mxKu A1 A2 : usubmx (col_mx A1 A2) = A1.

Lemma col_mxEd A1 A2 i j : col_mx A1 A2 (rshift m1 i) j = A2 i j.

Lemma col_mxKd A1 A2 : dsubmx (col_mx A1 A2) = A2.

Lemma lsubmxEsub : lsubmx = colsub (lshift _).

Lemma rsubmxEsub : rsubmx = colsub (@rshift _ _).

Lemma usubmxEsub : usubmx = rowsub (lshift _).

Lemma dsubmxEsub : dsubmx = rowsub (@rshift _ _).

Lemma eq_row_mx A1 A2 B1 B2 : row_mx A1 A2 = row_mx B1 B2 → A1 = B1 ∧ A2 = B2.

Lemma eq_col_mx A1 A2 B1 B2 : col_mx A1 A2 = col_mx B1 B2 → A1 = B1 ∧ A2 = B2.

Lemma row_mx_const a : row_mx (const_mx a) (const_mx a) = const_mx a.

Lemma col_mx_const a : col_mx (const_mx a) (const_mx a) = const_mx a.

Lemma row_usubmx A i : row i (usubmx A) = row (lshift m2 i) A.

Lemma row_dsubmx A i : row i (dsubmx A) = row (rshift m1 i) A.

Lemma col_lsubmx A i : col i (lsubmx A) = col (lshift n2 i) A.

Lemma col_rsubmx A i : col i (rsubmx A) = col (rshift n1 i) A.

End CutPaste.

Lemma row_thin_mx m n (A : 'M_(m,0)) (B : 'M_(m,n)) : row_mx A B = B.

Lemma col_flat_mx m n (A : 'M_(0,n)) (B : 'M_(m,n)) : col_mx A B = B.

Lemma trmx_lsub m n1 n2 (A : 'M_(m, n1 + n2)) : (lsubmx A)^T = usubmx A^T.

Lemma trmx_rsub m n1 n2 (A : 'M_(m, n1 + n2)) : (rsubmx A)^T = dsubmx A^T.

Lemma tr_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  (row_mx A1 A2)^T = col_mx A1^T A2^T.

Lemma tr_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  (col_mx A1 A2)^T = row_mx A1^T A2^T.

Lemma trmx_usub m1 m2 n (A : 'M_(m1 + m2, n)) : (usubmx A)^T = lsubmx A^T.

Lemma trmx_dsub m1 m2 n (A : 'M_(m1 + m2, n)) : (dsubmx A)^T = rsubmx A^T.

Lemma vsubmxK m1 m2 n (A : 'M_(m1 + m2, n)) : col_mx (usubmx A) (dsubmx A) = A.

Lemma cast_row_mx m m' n1 n2 (eq_m : m = m') A1 A2 :
  castmx (eq_m, erefl _) (row_mx A1 A2)
    = row_mx (castmx (eq_m, erefl n1) A1) (castmx (eq_m, erefl n2) A2).

Lemma cast_col_mx m1 m2 n n' (eq_n : n = n') A1 A2 :
  castmx (erefl _, eq_n) (col_mx A1 A2)
    = col_mx (castmx (erefl m1, eq_n) A1) (castmx (erefl m2, eq_n) A2).

This lemma has Prenex Implicits to help RL rewriting with castmx_sym.
Lemma row_mxA m n1 n2 n3 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) (A3 : 'M_(m, n3)) :
  let cast := (erefl m, esym (addnA n1 n2 n3)) in
  row_mx A1 (row_mx A2 A3) = castmx cast (row_mx (row_mx A1 A2) A3).
Definition row_mxAx := row_mxA. (* bypass Prenex Implicits. *)

This lemma has Prenex Implicits to help RL rewriting with castmx_sym.
Lemma col_mxA m1 m2 m3 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) (A3 : 'M_(m3, n)) :
  let cast := (esym (addnA m1 m2 m3), erefl n) in
  col_mx A1 (col_mx A2 A3) = castmx cast (col_mx (col_mx A1 A2) A3).
Definition col_mxAx := col_mxA. (* bypass Prenex Implicits. *)

Lemma row_row_mx m n1 n2 i0 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  row i0 (row_mx A1 A2) = row_mx (row i0 A1) (row i0 A2).

Lemma col_col_mx m1 m2 n j0 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  col j0 (col_mx A1 A2) = col_mx (col j0 A1) (col j0 A2).

Lemma row'_row_mx m n1 n2 i0 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  row' i0 (row_mx A1 A2) = row_mx (row' i0 A1) (row' i0 A2).

Lemma col'_col_mx m1 m2 n j0 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  col' j0 (col_mx A1 A2) = col_mx (col' j0 A1) (col' j0 A2).

Lemma colKl m n1 n2 j1 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col (lshift n2 j1) (row_mx A1 A2) = col j1 A1.

Lemma colKr m n1 n2 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col (rshift n1 j2) (row_mx A1 A2) = col j2 A2.

Lemma rowKu m1 m2 n i1 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row (lshift m2 i1) (col_mx A1 A2) = row i1 A1.

Lemma rowKd m1 m2 n i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row (rshift m1 i2) (col_mx A1 A2) = row i2 A2.

Lemma col'Kl m n1 n2 j1 (A1 : 'M_(m, n1.+1)) (A2 : 'M_(m, n2)) :
  col' (lshift n2 j1) (row_mx A1 A2) = row_mx (col' j1 A1) A2.

Lemma row'Ku m1 m2 n i1 (A1 : 'M_(m1.+1, n)) (A2 : 'M_(m2, n)) :
  row' (lshift m2 i1) (@col_mx m1.+1 m2 n A1 A2) = col_mx (row' i1 A1) A2.

Lemma mx'_cast m n : 'I_n → (m + n.-1)%N = (m + n).-1.

Lemma col'Kr m n1 n2 j2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) :
  col' (rshift n1 j2) (@row_mx m n1 n2 A1 A2)
    = castmx (erefl m, mx'_cast n1 j2) (row_mx A1 (col' j2 A2)).

Lemma row'Kd m1 m2 n i2 (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) :
  row' (rshift m1 i2) (col_mx A1 A2)
    = castmx (mx'_cast m1 i2, erefl n) (col_mx A1 (row' i2 A2)).

Section Block.

Variables m1 m2 n1 n2 : nat.

Building a block matrix from 4 matrices : up left, up right, down left and down right components

Definition block_mx Aul Aur Adl Adr : 'M_(m1 + m2, n1 + n2) :=
  col_mx (row_mx Aul Aur) (row_mx Adl Adr).

Lemma eq_block_mx Aul Aur Adl Adr Bul Bur Bdl Bdr :
 block_mx Aul Aur Adl Adr = block_mx Bul Bur Bdl Bdr →
  [/\ Aul = Bul, Aur = Bur, Adl = Bdl & Adr = Bdr].

Lemma block_mx_const a :
  block_mx (const_mx a) (const_mx a) (const_mx a) (const_mx a) = const_mx a.

Section CutBlock.

Variable A : matrix R (m1 + m2) (n1 + n2).

Definition ulsubmx := lsubmx (usubmx A).
Definition ursubmx := rsubmx (usubmx A).
Definition dlsubmx := lsubmx (dsubmx A).
Definition drsubmx := rsubmx (dsubmx A).

Lemma submxK : block_mx ulsubmx ursubmx dlsubmx drsubmx = A.

Lemma ulsubmxEsub : ulsubmx = mxsub (lshift _) (lshift _) A.

Lemma dlsubmxEsub : dlsubmx = mxsub (@rshift _ _) (lshift _) A.

Lemma ursubmxEsub : ursubmx = mxsub (lshift _) (@rshift _ _) A.

Lemma drsubmxEsub : drsubmx = mxsub (@rshift _ _) (@rshift _ _) A.

End CutBlock.

Section CatBlock.

Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

Let A := block_mx Aul Aur Adl Adr.

Lemma block_mxEul i j : A (lshift m2 i) (lshift n2 j) = Aul i j.
Lemma block_mxKul : ulsubmx A = Aul.

Lemma block_mxEur i j : A (lshift m2 i) (rshift n1 j) = Aur i j.
Lemma block_mxKur : ursubmx A = Aur.

Lemma block_mxEdl i j : A (rshift m1 i) (lshift n2 j) = Adl i j.
Lemma block_mxKdl : dlsubmx A = Adl.

Lemma block_mxEdr i j : A (rshift m1 i) (rshift n1 j) = Adr i j.
Lemma block_mxKdr : drsubmx A = Adr.

Lemma block_mxEv : A = col_mx (row_mx Aul Aur) (row_mx Adl Adr).

End CatBlock.

End Block.
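
The cut and paste operations of this section are inverse to each other. A minimal sketch of how the lemmas above can be used, assuming a MathComp 2.x installation where `all_ssreflect` and `all_algebra` are available; the section name and the example names `cut_ul` and `cut_then_cat` are hypothetical:

```coq
From mathcomp Require Import all_ssreflect all_algebra.

Section BlockSketch.
Variables (R : Type) (m1 m2 n1 n2 : nat).
Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

(* Cutting the upper-left corner out of an assembled block matrix. *)
Example cut_ul : ulsubmx (block_mx Aul Aur Adl Adr) = Aul.
Proof. exact: block_mxKul. Qed.

(* Reassembling the four corners of any matrix gives the matrix back. *)
Example cut_then_cat (A : 'M[R]_(m1 + m2, n1 + n2)) :
  block_mx (ulsubmx A) (ursubmx A) (dlsubmx A) (drsubmx A) = A.
Proof. exact: submxK. Qed.

End BlockSketch.
```

The same pattern should apply to the other three corners through `block_mxKur`, `block_mxKdl` and `block_mxKdr`.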

Section TrCutBlock.

Variables m1 m2 n1 n2 : nat.
Variable A : 'M[R]_(m1 + m2, n1 + n2).

Lemma trmx_ulsub : (ulsubmx A)^T = ulsubmx A^T.

Lemma trmx_ursub : (ursubmx A)^T = dlsubmx A^T.

Lemma trmx_dlsub : (dlsubmx A)^T = ursubmx A^T.

Lemma trmx_drsub : (drsubmx A)^T = drsubmx A^T.

End TrCutBlock.

Section TrBlock.
Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).

Lemma tr_block_mx :
 (block_mx Aul Aur Adl Adr)^T = block_mx Aul^T Adl^T Aur^T Adr^T.

Lemma block_mxEh :
  block_mx Aul Aur Adl Adr = row_mx (col_mx Aul Adl) (col_mx Aur Adr).
End TrBlock.

This lemma has Prenex Implicits to help RL rewriting with castmx_sym.
Lemma block_mxA m1 m2 m3 n1 n2 n3
   (A11 : 'M_(m1, n1)) (A12 : 'M_(m1, n2)) (A13 : 'M_(m1, n3))
   (A21 : 'M_(m2, n1)) (A22 : 'M_(m2, n2)) (A23 : 'M_(m2, n3))
   (A31 : 'M_(m3, n1)) (A32 : 'M_(m3, n2)) (A33 : 'M_(m3, n3)) :
  let cast := (esym (addnA m1 m2 m3), esym (addnA n1 n2 n3)) in
  let row1 := row_mx A12 A13 in let col1 := col_mx A21 A31 in
  let row3 := row_mx A31 A32 in let col3 := col_mx A13 A23 in
  block_mx A11 row1 col1 (block_mx A22 A23 A32 A33)
    = castmx cast (block_mx (block_mx A11 A12 A21 A22) col3 row3 A33).
Definition block_mxAx := block_mxA. (* Bypass Prenex Implicits *)

Section Induction.

Lemma row_ind m (P : ∀ n, 'M[R]_(m, n) → Type) :
    (∀ A, P 0 A) →
    (∀ n c A, P n A → P (1 + n)%N (row_mx c A)) →
  ∀ n A, P n A.

Lemma col_ind n (P : ∀ m, 'M[R]_(m, n) → Type) :
    (∀ A, P 0 A) →
    (∀ m r A, P m A → P (1 + m)%N (col_mx r A)) →
  ∀ m A, P m A.

Lemma mx_ind (P : ∀ m n, 'M[R]_(m, n) → Type) :
    (∀ m A, P m 0 A) →
    (∀ n A, P 0 n A) →
    (∀ m n x r c A, P m n A → P (1 + m)%N (1 + n)%N (block_mx x r c A)) →
  ∀ m n A, P m n A.
Definition matrix_rect := mx_ind.
Definition matrix_rec := mx_ind.
Definition matrix_ind := mx_ind.

Lemma sqmx_ind (P : ∀ n, 'M[R]_n → Type) :
    (∀ A, P 0 A) →
    (∀ n x r c A, P n A → P (1 + n)%N (block_mx x r c A)) →
  ∀ n A, P n A.

Lemma ringmx_ind (P : ∀ n, 'M[R]_n.+1 → Type) :
    (∀ x, P 0 x) →
    (∀ n x (r : 'rV_n.+1) (c : 'cV_n.+1) A,
       P n A → P (1 + n)%N (block_mx x r c A)) →
  ∀ n A, P n A.

Lemma mxsub_ind
    (weight : ∀ m n, 'M[R]_(m, n) → nat)
    (sub : ∀ m n m' n', ('I_m' → 'I_m) → ('I_n' → 'I_n) → Prop)
    (P : ∀ m n, 'M[R]_(m, n) → Type) :
    (∀ m n (A : 'M[R]_(m, n)),
      (∀ m' n' f g, weight m' n' (mxsub f g A) < weight m n A →
                         sub m n m' n' f g →
                         P m' n' (mxsub f g A)) → P m n A) →
  ∀ m n A, P m n A.

End Induction.

Bijections mxvec : 'M_(m, n) <----> 'rV_(m * n) : vec_mx
Section VecMatrix.

Variables m n : nat.

Lemma mxvec_cast : #|{:'I_m × 'I_n}| = (m × n)%N.

Definition mxvec_index (i : 'I_m) (j : 'I_n) :=
  cast_ord mxvec_cast (enum_rank (i, j)).

Variant is_mxvec_index : 'I_(m × n) → Type :=
  isMxvecIndex i j : is_mxvec_index (mxvec_index i j).

Lemma mxvec_indexP k : is_mxvec_index k.

Coercion pair_of_mxvec_index k (i_k : is_mxvec_index k) :=
  let: isMxvecIndex i j := i_k in (i, j).

Definition mxvec (A : 'M[R]_(m, n)) :=
  castmx (erefl _, mxvec_cast) (\row_k A (enum_val k).1 (enum_val k).2).

Fact vec_mx_key : unit.
Definition vec_mx (u : 'rV[R]_(m × n)) :=
  \matrix[vec_mx_key]_(i, j) u 0 (mxvec_index i j).

Lemma mxvecE A i j : mxvec A 0 (mxvec_index i j) = A i j.

Lemma mxvecK : cancel mxvec vec_mx.

Lemma vec_mxK : cancel vec_mx mxvec.

Lemma curry_mxvec_bij : {on 'I_(m × n), bijective (uncurry mxvec_index)}.

End VecMatrix.
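
The two cancellation lemmas say that `mxvec` and `vec_mx` are mutually inverse. A small sketch of both round trips, assuming MathComp 2.x; the section and example names are hypothetical, and the product is written with ASCII `*` as in actual Coq source:

```coq
From mathcomp Require Import all_ssreflect all_algebra.

Section VecSketch.
Variables (R : Type) (m n : nat).

(* Flattening a matrix to a row vector and rebuilding it is the identity. *)
Example vec_round_trip (A : 'M[R]_(m, n)) : vec_mx (mxvec A) = A.
Proof. exact: mxvecK. Qed.

(* Conversely, reshaping a row vector and flattening it again is the identity. *)
Example mx_round_trip (u : 'rV[R]_(m * n)%N) : mxvec (vec_mx u) = u.
Proof. exact: vec_mxK. Qed.

End VecSketch.
```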

End MatrixStructural.

Arguments const_mx {R m n}.
Arguments row_mxA {R m n1 n2 n3 A1 A2 A3}.
Arguments col_mxA {R m1 m2 m3 n A1 A2 A3}.
Arguments block_mxA
  {R m1 m2 m3 n1 n2 n3 A11 A12 A13 A21 A22 A23 A31 A32 A33}.
Arguments trmx_inj {R m n} [A1 A2] eqA12t : rename.

Notation "A ^T" := (trmx A) : ring_scope.
Notation colsub g := (mxsub id g).
Notation rowsub f := (mxsub f id).

Arguments eq_mxsub [R m n m' n' f] f' [g] g' _.
Arguments eq_rowsub [R m n m' f] f' _.
Arguments eq_colsub [R m n n' g] g' _.

Matrix parametricity.
Section MapMatrix.

Variables (aT rT : Type) (f : aT → rT).

Fact map_mx_key : unit.
Definition map_mx m n (A : 'M_(m, n)) := \matrix[map_mx_key]_(i, j) f (A i j).

Notation "A ^f" := (map_mx A) : ring_scope.

Section OneMatrix.

Variables (m n : nat) (A : 'M[aT]_(m, n)).

Lemma map_trmx : A^f^T = A^T^f.

Lemma map_const_mx a : (const_mx a)^f = const_mx (f a) :> 'M_(m, n).

Lemma map_row i : (row i A)^f = row i A^f.

Lemma map_col j : (col j A)^f = col j A^f.

Lemma map_row' i0 : (row' i0 A)^f = row' i0 A^f.

Lemma map_col' j0 : (col' j0 A)^f = col' j0 A^f.

Lemma map_mxsub m' n' g h : (@mxsub _ _ _ m' n' g h A)^f = mxsub g h A^f.

Lemma map_row_perm s : (row_perm s A)^f = row_perm s A^f.

Lemma map_col_perm s : (col_perm s A)^f = col_perm s A^f.

Lemma map_xrow i1 i2 : (xrow i1 i2 A)^f = xrow i1 i2 A^f.

Lemma map_xcol j1 j2 : (xcol j1 j2 A)^f = xcol j1 j2 A^f.

Lemma map_castmx m' n' c : (castmx c A)^f = castmx c A^f :> 'M_(m', n').

Lemma map_conform_mx m' n' (B : 'M_(m', n')) :
  (conform_mx B A)^f = conform_mx B^f A^f.

Lemma map_mxvec : (mxvec A)^f = mxvec A^f.

Lemma map_vec_mx (v : 'rV_(m × n)) : (vec_mx v)^f = vec_mx v^f.

End OneMatrix.

Section Block.

Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[aT]_(m1, n1)) (Aur : 'M[aT]_(m1, n2)).
Variables (Adl : 'M[aT]_(m2, n1)) (Adr : 'M[aT]_(m2, n2)).
Variables (Bh : 'M[aT]_(m1, n1 + n2)) (Bv : 'M[aT]_(m1 + m2, n1)).
Variable B : 'M[aT]_(m1 + m2, n1 + n2).

Lemma map_row_mx : (row_mx Aul Aur)^f = row_mx Aul^f Aur^f.

Lemma map_col_mx : (col_mx Aul Adl)^f = col_mx Aul^f Adl^f.

Lemma map_block_mx :
  (block_mx Aul Aur Adl Adr)^f = block_mx Aul^f Aur^f Adl^f Adr^f.

Lemma map_lsubmx : (lsubmx Bh)^f = lsubmx Bh^f.

Lemma map_rsubmx : (rsubmx Bh)^f = rsubmx Bh^f.

Lemma map_usubmx : (usubmx Bv)^f = usubmx Bv^f.

Lemma map_dsubmx : (dsubmx Bv)^f = dsubmx Bv^f.

Lemma map_ulsubmx : (ulsubmx B)^f = ulsubmx B^f.

Lemma map_ursubmx : (ursubmx B)^f = ursubmx B^f.

Lemma map_dlsubmx : (dlsubmx B)^f = dlsubmx B^f.

Lemma map_drsubmx : (drsubmx B)^f = drsubmx B^f.

End Block.

End MapMatrix.
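
The `A^f` notation is declared inside the section, so client code outside it would presumably write `map_mx f A`. A brief sketch under that assumption (MathComp 2.x; section and example names are hypothetical), showing the coefficient-wise reading of `map_mx` and its commutation with transposition:

```coq
From mathcomp Require Import all_ssreflect all_algebra.

Section MapSketch.
Variables (aT rT : Type) (f : aT -> rT) (m n : nat) (A : 'M[aT]_(m, n)).

(* Each coefficient of the image is f applied to the matching coefficient. *)
Example map_coeff i j : (map_mx f A) i j = f (A i j).
Proof. by rewrite mxE. Qed.

(* Mapping commutes with transposition; trmx is the function behind A^T. *)
Example map_tr : trmx (map_mx f A) = map_mx f (trmx A).
Proof. exact: map_trmx. Qed.

End MapSketch.
```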

Arguments map_mx {aT rT} f {m n} A.

Section MultipleMapMatrix.
Context {R S T : Type} {m n : nat}.
Local Notation "M ^ phi" := (map_mx phi M).

Lemma map_mx_comp (f : R → S) (g : S → T)
  (M : 'M_(m, n)) : M ^ (g \o f) = (M ^ f) ^ g.

Lemma eq_in_map_mx (g f : R → S) (M : 'M_(m, n)) :
  (∀ i j, f (M i j) = g (M i j)) → M ^ f = M ^ g.

Lemma eq_map_mx (g f : R → S) : f =1 g →
  ∀ (M : 'M_(m, n)), M ^ f = M ^ g.

Lemma map_mx_id_in (f : R → R) (M : 'M_(m, n)) :
  (∀ i j, f (M i j) = M i j) → M ^ f = M.

Lemma map_mx_id (f : R → R) : f =1 id → ∀ M : 'M_(m, n), M ^ f = M.

End MultipleMapMatrix.
Arguments eq_map_mx {R S m n} g [f].
Arguments eq_in_map_mx {R S m n} g [f M].
Arguments map_mx_id_in {R m n} [f M].
Arguments map_mx_id {R m n} [f].

Matrix lifted laws

Section Map2Matrix.
Context {R S T : Type} (f : R → S → T).

Fact map2_mx_key : unit.
Definition map2_mx m n (A : 'M_(m, n)) (B : 'M_(m, n)) :=
  \matrix[map2_mx_key]_(i, j) f (A i j) (B i j).

Section OneMatrix.

Variables (m n : nat) (A : 'M[R]_(m, n)) (B : 'M[S]_(m, n)).

Lemma map2_trmx : (map2_mx A B)^T = map2_mx A^T B^T.

Lemma map2_const_mx a b :
  map2_mx (const_mx a) (const_mx b) = const_mx (f a b) :> 'M_(m, n).

Lemma map2_row i : map2_mx (row i A) (row i B) = row i (map2_mx A B).

Lemma map2_col j : map2_mx (col j A) (col j B) = col j (map2_mx A B).

Lemma map2_row' i0 : map2_mx (row' i0 A) (row' i0 B) = row' i0 (map2_mx A B).

Lemma map2_col' j0 : map2_mx (col' j0 A) (col' j0 B) = col' j0 (map2_mx A B).

Lemma map2_mxsub m' n' g h :
  map2_mx (@mxsub _ _ _ m' n' g h A) (@mxsub _ _ _ m' n' g h B) =
  mxsub g h (map2_mx A B).

Lemma map2_row_perm s :
  map2_mx (row_perm s A) (row_perm s B) = row_perm s (map2_mx A B).

Lemma map2_col_perm s :
  map2_mx (col_perm s A) (col_perm s B) = col_perm s (map2_mx A B).

Lemma map2_xrow i1 i2 :
  map2_mx (xrow i1 i2 A) (xrow i1 i2 B) = xrow i1 i2 (map2_mx A B).

Lemma map2_xcol j1 j2 :
  map2_mx (xcol j1 j2 A) (xcol j1 j2 B) = xcol j1 j2 (map2_mx A B).

Lemma map2_castmx m' n' c :
  map2_mx (castmx c A) (castmx c B) = castmx c (map2_mx A B) :> 'M_(m', n').

Lemma map2_conform_mx m' n' (A' : 'M_(m', n')) (B' : 'M_(m', n')) :
  map2_mx (conform_mx A' A) (conform_mx B' B) =
  conform_mx (map2_mx A' B') (map2_mx A B).

Lemma map2_mxvec : map2_mx (mxvec A) (mxvec B) = mxvec (map2_mx A B).

Lemma map2_vec_mx (v : 'rV_(m × n)) (w : 'rV_(m × n)) :
  map2_mx (vec_mx v) (vec_mx w) = vec_mx (map2_mx v w).

End OneMatrix.

Section Block.

Variables m1 m2 n1 n2 : nat.
Variables (Aul : 'M[R]_(m1, n1)) (Aur : 'M[R]_(m1, n2)).
Variables (Adl : 'M[R]_(m2, n1)) (Adr : 'M[R]_(m2, n2)).
Variables (Bh : 'M[R]_(m1, n1 + n2)) (Bv : 'M[R]_(m1 + m2, n1)).
Variable B : 'M[R]_(m1 + m2, n1 + n2).
Variables (A'ul : 'M[S]_(m1, n1)) (A'ur : 'M[S]_(m1, n2)).
Variables (A'dl : 'M[S]_(m2, n1)) (A'dr : 'M[S]_(m2, n2)).
Variables (B'h : 'M[S]_(m1, n1 + n2)) (B'v : 'M[S]_(m1 + m2, n1)).
Variable B' : 'M[S]_(m1 + m2, n1 + n2).

Lemma map2_row_mx :
  map2_mx (row_mx Aul Aur) (row_mx A'ul A'ur) =
  row_mx (map2_mx Aul A'ul) (map2_mx Aur A'ur).

Lemma map2_col_mx :
  map2_mx (col_mx Aul Adl) (col_mx A'ul A'dl) =
  col_mx (map2_mx Aul A'ul) (map2_mx Adl A'dl).

Lemma map2_block_mx :
  map2_mx (block_mx Aul Aur Adl Adr) (block_mx A'ul A'ur A'dl A'dr) =
  block_mx
   (map2_mx Aul A'ul) (map2_mx Aur A'ur) (map2_mx Adl A'dl) (map2_mx Adr A'dr).

Lemma map2_lsubmx : map2_mx (lsubmx Bh) (lsubmx B'h) = lsubmx (map2_mx Bh B'h).

Lemma map2_rsubmx : map2_mx (rsubmx Bh) (rsubmx B'h) = rsubmx (map2_mx Bh B'h).

Lemma map2_usubmx : map2_mx (usubmx Bv) (usubmx B'v) = usubmx (map2_mx Bv B'v).

Lemma map2_dsubmx : map2_mx (dsubmx Bv) (dsubmx B'v) = dsubmx (map2_mx Bv B'v).

Lemma map2_ulsubmx : map2_mx (ulsubmx B) (ulsubmx B') = ulsubmx (map2_mx B B').

Lemma map2_ursubmx : map2_mx (ursubmx B) (ursubmx B') = ursubmx (map2_mx B B').

Lemma map2_dlsubmx : map2_mx (dlsubmx B) (dlsubmx B') = dlsubmx (map2_mx B B').

Lemma map2_drsubmx : map2_mx (drsubmx B) (drsubmx B') = drsubmx (map2_mx B B').

End Block.

End Map2Matrix.
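
All the `map2_` lemmas above follow from the coefficient-wise definition of `map2_mx`. A one-lemma sketch of that reading, assuming a MathComp version that provides `map2_mx` as defined here; the example name `map2_coeff` is hypothetical:

```coq
From mathcomp Require Import all_ssreflect all_algebra.

(* The (i, j) coefficient of the pointwise image is f applied to the
   (i, j) coefficients of the two arguments. *)
Example map2_coeff (R S T : Type) (f : R -> S -> T) (m n : nat)
    (A : 'M[R]_(m, n)) (B : 'M[S]_(m, n)) i j :
  (map2_mx f A B) i j = f (A i j) (B i j).
Proof. by rewrite mxE. Qed.
```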

Section Map2Eq.

Context {R S T : Type} {m n : nat}.

Lemma eq_in_map2_mx (f g : R → S → T) (M : 'M[R]_(m, n)) (M' : 'M[S]_(m, n)) :
  (∀ i j, f (M i j) (M' i j) = g (M i j) (M' i j)) →
  map2_mx f M M' = map2_mx g M M'.

Lemma eq_map2_mx (f g : R → S → T) : f =2 g →
  @map2_mx _ _ _ f m n =2 @map2_mx _ _ _ g m n.

Lemma map2_mx_left_in (f : R → R → R) (M : 'M_(m, n)) (M' : 'M_(m, n)) :
  (∀ i j, f (M i j) (M' i j) = M i j) → map2_mx f M M' = M.

Lemma map2_mx_left (f : R → R → R) : f =2 (fun x _ ⇒ x) →
  ∀ (M : 'M_(m, n)) (M' : 'M_(m, n)), map2_mx f M M' = M.

Lemma map2_mx_right_in (f : R → R → R) (M : 'M_(m, n)) (M' : 'M_(m, n)) :
  (∀ i j, f (M i j) (M' i j) = M' i j) → map2_mx f M M' = M'.

Lemma map2_mx_right (f : R → R → R) : f =2 (fun _ x ⇒ x) →
  ∀ (M : 'M_(m, n)) (M' : 'M_(m, n)), map2_mx f M M' = M'.

End Map2Eq.

Section MatrixLaws.

Context {T : Type} {m n : nat} {idm : T}.

Lemma map2_mxA {opm : Monoid.law idm} : associative (@map2_mx _ _ _ opm m n).

Lemma map2_1mx {opm : Monoid.law idm} :
  left_id (const_mx idm) (@map2_mx _ _ _ opm m n).

Lemma map2_mx1 {opm : Monoid.law idm} :
  right_id (const_mx idm) (@map2_mx _ _ _ opm m n).


Lemma map2_mxC {opm : Monoid.com_law idm} :
  commutative (@map2_mx _ _ _ opm m n).


Lemma map2_0mx {opm : Monoid.mul_law idm} :
  left_zero (const_mx idm) (@map2_mx _ _ _ opm m n).

Lemma map2_mx0 {opm : Monoid.mul_law idm} :
  right_zero (const_mx idm) (@map2_mx _ _ _ opm m n).


Lemma map2_mxDl {mul : T → T → T} {add : Monoid.add_law idm mul} :
  left_distributive (@map2_mx _ _ _ mul m n) (@map2_mx _ _ _ add m n).

Lemma map2_mxDr {mul : T → T → T} {add : Monoid.add_law idm mul} :
  right_distributive (@map2_mx _ _ _ mul m n) (@map2_mx _ _ _ add m n).


End MatrixLaws.

Matrix Nmodule (additive abelian monoid) structure

Section MatrixNmodule.

Variable V : nmodType.

Section FixedDim.

Variables m n : nat.
Implicit Types A B : 'M[V]_(m, n).

Fact addmx_key : unit.
Definition addmx := @map2_mx V V V +%R m n.

Definition addmxA : associative addmx := map2_mxA.
Definition addmxC : commutative addmx := map2_mxC.
Definition add0mx : left_id (const_mx 0) addmx := map2_1mx.


Lemma mulmxnE A d i j : (A *+ d) i j = A i j *+ d.

Lemma summxE I r (P : pred I) (E : I → 'M_(m, n)) i j :
  (\sum_(k <- r | P k) E k) i j = \sum_(k <- r | P k) E k i j.

Lemma const_mx_is_semi_additive : semi_additive const_mx.

End FixedDim.
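
Because `addmx` is `map2_mx` applied to `+%R`, matrix addition can be read off coefficient by coefficient. A hedged sketch, assuming MathComp 2.x and that `+` on `'M[V]_(m, n)` resolves to the Nmodule instance built from `addmx` (the instance declaration is not visible in this rendering); the example name `add_coeff` is hypothetical:

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

(* Assumption: unfolding + on matrices reaches addmx, hence map2_mx,
   so the generic coefficient lemma mxE applies. *)
Example add_coeff (V : nmodType) (m n : nat) (A B : 'M[V]_(m, n)) i j :
  (A + B) i j = A i j + B i j.
Proof. by rewrite mxE. Qed.
```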

Section SemiAdditive.

Variables (m n p q : nat) (f : 'I_p → 'I_q → 'I_m) (g : 'I_p → 'I_q → 'I_n).

Definition swizzle_mx k (A : 'M[V]_(m, n)) :=
  \matrix[k]_(i, j) A (f i j) (g i j).

Lemma swizzle_mx_is_semi_additive k : semi_additive (swizzle_mx k).

End SemiAdditive.

Local Notation SwizzleAdd op := (GRing.Additive.copy op (swizzle_mx _ _ _)).


Lemma flatmx0 n : all_equal_to (0 : 'M_(0, n)).

Lemma thinmx0 n : all_equal_to (0 : 'M_(n, 0)).

Lemma trmx0 m n : (0 : 'M_(m, n))^T = 0.

Lemma row0 m n i0 : row i0 (0 : 'M_(m, n)) = 0.

Lemma col0 m n j0 : col j0 (0 : 'M_(m, n)) = 0.

Lemma mxvec_eq0 m n (A : 'M_(m, n)) : (mxvec A == 0) = (A == 0).

Lemma vec_mx_eq0 m n (v : 'rV_(m × n)) : (vec_mx v == 0) = (v == 0).

Lemma row_mx0 m n1 n2 : row_mx 0 0 = 0 :> 'M_(m, n1 + n2).

Lemma col_mx0 m1 m2 n : col_mx 0 0 = 0 :> 'M_(m1 + m2, n).

Lemma block_mx0 m1 m2 n1 n2 : block_mx 0 0 0 0 = 0 :> 'M_(m1 + m2, n1 + n2).

Ltac split_mxE := apply/matrixP⇒ i j; do ![rewrite mxE | case: split ⇒ ?].

Lemma add_row_mx m n1 n2 (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)) B1 B2 :
  row_mx A1 A2 + row_mx B1 B2 = row_mx (A1 + B1) (A2 + B2).

Lemma add_col_mx m1 m2 n (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)) B1 B2 :
  col_mx A1 A2 + col_mx B1 B2 = col_mx (A1 + B1) (A2 + B2).

Lemma add_block_mx m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2))
                   Bul Bur Bdl Bdr :
  let A := block_mx Aul Aur Adl Adr in let B := block_mx Bul Bur Bdl Bdr in
  A + B = block_mx (Aul + Bul) (Aur + Bur) (Adl + Bdl) (Adr + Bdr).

Lemma row_mx_eq0 (m n1 n2 : nat) (A1 : 'M_(m, n1)) (A2 : 'M_(m, n2)):
  (row_mx A1 A2 == 0) = (A1 == 0) && (A2 == 0).

Lemma col_mx_eq0 (m1 m2 n : nat) (A1 : 'M_(m1, n)) (A2 : 'M_(m2, n)):
  (col_mx A1 A2 == 0) = (A1 == 0) && (A2 == 0).

Lemma block_mx_eq0 m1 m2 n1 n2 (Aul : 'M_(m1, n1)) Aur Adl (Adr : 'M_(m2, n2)) :
  (block_mx Aul Aur Adl Adr == 0) =
  [&& Aul == 0, Aur == 0, Adl == 0 & Adr == 0].

Lemma trmx_eq0 m n (A : 'M_(m, n)) : (A^T == 0) = (A == 0).

Lemma matrix_eq0 m n (A : 'M_(m, n)) :
  (A == 0) = [∀ i, ∀ j, A i j == 0].

Lemma matrix0Pn m n (A : 'M_(m, n)) : reflect (∃ i j, A i j != 0) (A != 0).

Lemma rV0Pn n (v : 'rV_n) : reflect (∃ i, v 0 i != 0) (v != 0).

Lemma cV0Pn n (v : 'cV_n) : reflect (∃ i, v i 0 != 0) (v != 0).

Definition nz_row m n (A : 'M_(m, n)) :=
  oapp (fun i ⇒ row i A) 0 [pick i | row i A != 0].

Lemma nz_row_eq0 m n (A : 'M_(m, n)) : (nz_row A == 0) = (A == 0).

Definition is_diag_mx m n (A : 'M[V]_(m, n)) :=
  [∀ i : 'I__, ∀ j : 'I__, (i != j :> nat) ==> (A i j == 0)].

Lemma is_diag_mxP m n (A : 'M[V]_(m, n)) :
  reflect (∀ i j : 'I__, i != j :> nat → A i j = 0) (is_diag_mx A).

Lemma mx0_is_diag m n : is_diag_mx (0 : 'M[V]_(m, n)).

Lemma mx11_is_diag (M : 'M_1) : is_diag_mx M.

Definition is_trig_mx m n (A : 'M[V]_(m, n)) :=
  [∀ i : 'I__, ∀ j : 'I__, (i < j)%N ==> (A i j == 0)].

Lemma is_trig_mxP m n (A : 'M[V]_(m, n)) :
  reflect (∀ i j : 'I__, (i < j)%N → A i j = 0) (is_trig_mx A).

Lemma is_diag_mx_is_trig m n (A : 'M[V]_(m, n)) : is_diag_mx A → is_trig_mx A.

Lemma mx0_is_trig m n : is_trig_mx (0 : 'M[V]_(m, n)).

Lemma mx11_is_trig (M : 'M_1) : is_trig_mx M.

Lemma is_diag_mxEtrig m n (A : 'M[V]_(m, n)) :
  is_diag_mx A = is_trig_mx A && is_trig_mx A^T.

Lemma is_diag_trmx m n (A : 'M[V]_(m, n)) : is_diag_mx A^T = is_diag_mx A.
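
The reflection lemma `is_trig_mxP` turns the boolean predicate into a usable statement about coefficients: `is_trig_mx A` forces every entry strictly above the diagonal to vanish. A small sketch, assuming MathComp 2.x; the section and example names are hypothetical:

```coq
From mathcomp Require Import all_ssreflect all_algebra.
Local Open Scope ring_scope.

Section TrigSketch.
Variables (V : nmodType) (m n : nat) (A : 'M[V]_(m, n)).

(* An entry whose row index is below its column index is zero. *)
Example trig_entry (i : 'I_m) (j : 'I_n) :
  is_trig_mx A -> (i < j)%N -> A i j = 0.
Proof. by move=> /is_trig_mxP trigA ltij; apply: trigA. Qed.

(* Every diagonal matrix is in particular triangular in this sense. *)
Example diag_is_trig : is_diag_mx A -> is_trig_mx A.
Proof. exact: is_diag_mx_is_trig. Qed.

End TrigSketch.
```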

Lemma ursubmx_trig m1 m2 n1 n2 (A : 'M[V]_(m1 + m2, n1 + n2)) :
  m1 ≤ n1 → is_trig_mx A → ursubmx A = 0.

Lemma dlsubmx_diag m1 m2 n1 n2 (A : 'M[V]_(m1 + m2, n1 + n2)) :
  n1 ≤ m1 → is_diag_mx A → dlsubmx A = 0.

Lemma ulsubmx_trig m1 m2 n1 n2 (A : 'M[V]_(m1 + m2, n1 + n2)) :
  is_trig_mx A → is_trig_mx (ulsubmx A).

Lemma drsubmx_trig m1 m2 n1 n2 (A : 'M[V]_(m1 + m2, n1 + n2)) :
  m1 ≤ n1 → is_trig_mx A → is_trig_mx (drsubmx A).

Lemma ulsubmx_diag m1 m2 n1 n2 (A : 'M[V]_(m1 + m2, n1 + n2)) :
  is_diag_mx A → is_diag_mx (ulsubmx A).

Lemma drsubmx_diag m1 m2 n1 n2 (A : 'M[V]_(m1 + m2, n1 + n2)) :
  m1 = n1 → is_diag_mx A → is_diag_mx (drsubmx A).

Lemma is_trig_block_mx m1 m2 n1 n2 ul ur dl dr : m1 = n1 →
  @is_trig_mx (m1 + m2) (n1 + n2) (block_mx ul ur dl dr) =
  [&& ur == 0, is_trig_mx ul & is_trig_mx dr].

Lemma trigmx_ind (P : ∀ m n, 'M_(m, n) → Type) :
  (∀ m, P m 0 0) →
  (∀ n, P 0 n 0) →
  (∀ m n x c A, is_trig_mx A →
    P m n A → P (1 + m)%N (1 + n)%N (block_mx x 0 c A)) →
  ∀ m n A, is_trig_mx A → P m n A.

Lemma trigsqmx_ind (P : ∀ n, 'M[V]_n → Type) : (P 0 0) →
  (∀ n x c A, is_trig_mx A → P n A → P (1 + n)%N (block_mx x 0 c A)) →
  ∀ n A, is_trig_mx A → P n A.

Lemma is_diag_block_mx m1 m2 n1 n2 ul ur dl dr : m1 = n1 →
  @is_diag_mx (m1 + m2) (n1 + n2) (block_mx ul ur dl dr) =
  [&& ur == 0, dl == 0, is_diag_mx ul & is_diag_mx dr].

Lemma diagmx_ind (P : ∀ m n, 'M_(m, n) → Type) :
  (∀ m, P m 0 0) →
  (∀ n, P 0 n 0) →
  (∀ m n x c A, is_diag_mx A →
    P m n A → P (1 + m)%N (1 + n)%N (block_mx x 0 c A)) →
  ∀ m n A, is_diag_mx A → P m n A.

Lemma diagsqmx_ind (P : ∀ n, 'M[V]_n → Type) :
    (P 0 0) →
  (∀ n x c A, is_diag_mx A → P n A → P (1 + n)%N (block_mx x 0 c A)) →
  ∀ n A, is_diag_mx A → P n A.

Diagonal matrices