25 #ifndef OPM_TRIDIAGONAL_MATRIX_HH
26 #define OPM_TRIDIAGONAL_MATRIX_HH
/*!
 * \brief Provides a tridiagonal matrix that also supports non-zero entries in
 *        the upper right and lower left corners.
 *
 * The entries are stored in three vectors of length n = size(), indexed by
 * column: entry (i, j) with |i - j| <= 1 is kept in diag_[1 + j - i][j], so
 * diag_[0] holds the sub-diagonal, diag_[1] the main diagonal and diag_[2]
 * the super-diagonal. The two slots this leaves over, diag_[2][0] and
 * diag_[0][n - 1], hold the upper right entry (0, n - 1) and the lower left
 * entry (n - 1, 0). For n <= 2 those positions are ordinary tridiagonal
 * positions, so the corner slots are ignored in that case.
 *
 * Copying and moving are compiler-generated (the storage is three
 * std::vector objects).
 */
template <class Scalar>
class TridiagonalMatrix
{
    /*!
     * \brief Proxy object for a single row; doubles as the row iterator.
     */
    class TridiagRow_
    {
    public:
        TridiagRow_(TridiagonalMatrix& m, size_t rowIdx)
            : matrix_(m)
            , rowIdx_(rowIdx)
        {}

        //! \brief Writable access to the entry in column colIdx of this row.
        Scalar& operator[](size_t colIdx)
        { return matrix_.at(rowIdx_, colIdx); }

        //! \brief Read-only access to the entry in column colIdx of this row.
        Scalar operator[](size_t colIdx) const
        { return matrix_.at(rowIdx_, colIdx); }

        //! \brief Prefix increment: advance to the next row.
        TridiagRow_& operator++()
        {
            ++rowIdx_;
            return *this;
        }

        //! \brief Prefix decrement: go back to the previous row.
        TridiagRow_& operator--()
        {
            --rowIdx_;
            return *this;
        }

        //! \brief Rows compare equal if they denote the same row of the same matrix.
        bool operator==(const TridiagRow_& other) const
        { return other.rowIdx_ == rowIdx_ && &other.matrix_ == &matrix_; }

        //! \brief Negation of operator==.
        bool operator!=(const TridiagRow_& other) const
        { return !operator==(other); }

        //! \brief Dereference operator: the iterator is its own row object.
        TridiagRow_& operator*()
        { return *this; }

        //! \brief Index of this row (0 is the first row).
        size_t index() const
        { return rowIdx_; }

    private:
        TridiagonalMatrix& matrix_;
        mutable size_t rowIdx_;
    };

public:
    typedef Scalar FieldType;
    typedef TridiagRow_ RowType;
    typedef size_t SizeType;
    typedef TridiagRow_ iterator;
    // NOTE(review): const_iterator is the same type as iterator, so writes
    // through a row obtained from a const matrix are not prevented.
    typedef TridiagRow_ const_iterator;

    /*!
     * \brief Create a numRows x numRows matrix; all entries are value-initialized.
     */
    TridiagonalMatrix(size_t numRows = 0)
    { resize(numRows); }

    /*!
     * \brief Create a numRows x numRows matrix with every storage slot set to value.
     */
    TridiagonalMatrix(size_t numRows, Scalar value)
    {
        resize(numRows);
        operator=(value);
    }

    /*!
     * \brief Return the number of rows/columns of the matrix.
     */
    size_t size() const
    { return diag_[0].size(); }

    /*!
     * \brief Return the number of rows of the matrix.
     */
    size_t rows() const
    { return size(); }

    /*!
     * \brief Return the number of columns of the matrix.
     */
    size_t cols() const
    { return size(); }

    /*!
     * \brief Change the number of rows of the matrix.
     *
     * Existing slots are kept, new slots are value-initialized.
     */
    void resize(size_t n)
    {
        for (auto& diag : diag_)
            diag.resize(n);
    }

    /*!
     * \brief Access an entry.
     *
     * Only entries on the three central diagonals and, for size() > 2, the
     * upper right and lower left corners exist; requesting any other entry
     * violates an assertion.
     */
    Scalar& at(size_t rowIdx, size_t colIdx)
    {
        const Slot_ slot = slotOf_(rowIdx, colIdx);
        return diag_[slot.diagIdx][slot.idx];
    }

    /*!
     * \brief Access an entry (read-only).
     */
    Scalar at(size_t rowIdx, size_t colIdx) const
    {
        const Slot_ slot = slotOf_(rowIdx, colIdx);
        return diag_[slot.diagIdx][slot.idx];
    }

    /*!
     * \brief Assignment operator from a Scalar: sets every storage slot.
     */
    TridiagonalMatrix& operator=(Scalar value)
    {
        const size_t n = size();
        for (auto& diag : diag_)
            diag.assign(n, value);
        return *this;
    }

    /*!
     * \brief Iterator pointing to the first row.
     */
    iterator begin()
    { return TridiagRow_(*this, 0); }

    /*!
     * \brief Iterator pointing to the first row.
     */
    const_iterator begin() const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), 0); }

    /*!
     * \brief Iterator pointing one past the last row.
     */
    const_iterator end() const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), size()); }

    /*!
     * \brief Row access operator.
     */
    TridiagRow_ operator[](size_t rowIdx)
    { return TridiagRow_(*this, rowIdx); }

    /*!
     * \brief Row access operator (read-only).
     *
     * The row proxy stores a non-const reference, hence the const_cast; the
     * returned const proxy only exposes the by-value element access.
     */
    const TridiagRow_ operator[](size_t rowIdx) const
    { return TridiagRow_(const_cast<TridiagonalMatrix&>(*this), rowIdx); }

    /*!
     * \brief Multiplication with a Scalar.
     */
    TridiagonalMatrix& operator*=(Scalar alpha)
    {
        for (auto& diag : diag_)
            for (auto& entry : diag)
                entry *= alpha;
        return *this;
    }

    /*!
     * \brief Division by a Scalar.
     */
    TridiagonalMatrix& operator/=(Scalar alpha)
    {
        for (auto& diag : diag_)
            for (auto& entry : diag)
                entry /= alpha;
        return *this;
    }

    /*!
     * \brief Subtraction operator.
     */
    TridiagonalMatrix& operator-=(const TridiagonalMatrix& other)
    { return axpy(-1.0, other); }

    /*!
     * \brief Addition operator.
     */
    TridiagonalMatrix& operator+=(const TridiagonalMatrix& other)
    { return axpy(1.0, other); }

    /*!
     * \brief Multiply and add the matrix entries of another tridiagonal matrix:
     *        this := this + alpha*other.
     */
    TridiagonalMatrix& axpy(Scalar alpha, const TridiagonalMatrix& other)
    {
        assert(size() == other.size());

        const size_t n = size();
        for (unsigned diagIdx = 0; diagIdx < 3; ++diagIdx)
            for (size_t i = 0; i < n; ++i)
                // work on the storage of `other` directly; other[diagIdx]
                // would be a *row* access
                diag_[diagIdx][i] += alpha*other.diag_[diagIdx][i];

        return *this;
    }

    /*!
     * \brief Matrix-vector product: dest := A*source.
     */
    template<class Vector>
    void mv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx)
            dest[rowIdx] = rowTimes_(rowIdx, source);
    }

    /*!
     * \brief Additive matrix-vector product: dest += A*source.
     */
    template<class Vector>
    void umv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx)
            dest[rowIdx] += rowTimes_(rowIdx, source);
    }

    /*!
     * \brief Subtractive matrix-vector product: dest -= A*source.
     */
    template<class Vector>
    void mmv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx)
            dest[rowIdx] -= rowTimes_(rowIdx, source);
    }

    /*!
     * \brief Scaled additive matrix-vector product: dest += alpha*A*source.
     */
    template<class Vector>
    void usmv(Scalar alpha, const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx)
            dest[rowIdx] += alpha*rowTimes_(rowIdx, source);
    }

    /*!
     * \brief Transposed matrix-vector product: dest := A^T*source.
     */
    template<class Vector>
    void mtv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t colIdx = 0; colIdx < size(); ++colIdx)
            dest[colIdx] = colTimes_(colIdx, source);
    }

    /*!
     * \brief Transposed additive matrix-vector product: dest += A^T*source.
     */
    template<class Vector>
    void umtv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t colIdx = 0; colIdx < size(); ++colIdx)
            dest[colIdx] += colTimes_(colIdx, source);
    }

    /*!
     * \brief Transposed subtractive matrix-vector product: dest -= A^T*source.
     */
    template<class Vector>
    void mmtv(const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t colIdx = 0; colIdx < size(); ++colIdx)
            dest[colIdx] -= colTimes_(colIdx, source);
    }

    /*!
     * \brief Transposed scaled additive matrix-vector product:
     *        dest += alpha*A^T*source.
     */
    template<class Vector>
    void usmtv(Scalar alpha, const Vector& source, Vector& dest) const
    {
        assert(source.size() == size());
        assert(dest.size() == size());

        for (size_t colIdx = 0; colIdx < size(); ++colIdx)
            dest[colIdx] += alpha*colTimes_(colIdx, source);
    }

    /*!
     * \brief Calculate the frobenius norm.
     */
    Scalar frobeniusNorm() const
    {
        using std::sqrt;
        return sqrt(frobeniusNormSquared());
    }

    /*!
     * \brief Calculate the squared frobenius norm (sum of the squared entries).
     */
    Scalar frobeniusNormSquared() const
    {
        Scalar result(0);
        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx)
            forEachInRow_(rowIdx, [&result](size_t, const Scalar& value) {
                result += value*value;
            });
        return result;
    }

    /*!
     * \brief Calculate the infinity norm (largest absolute row sum).
     */
    Scalar infinityNorm() const
    {
        using std::abs;

        Scalar result(0);
        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx) {
            Scalar rowSum(0);
            forEachInRow_(rowIdx, [&rowSum](size_t, const Scalar& value) {
                rowSum += abs(value);
            });
            if (rowSum > result)
                result = rowSum;
        }
        return result;
    }

    /*!
     * \brief Calculate the solution for a linear system of equations A*x = b.
     *
     * Gaussian elimination without pivoting, specialised to the tridiagonal
     * structure plus corners. x must already have size() entries.
     */
    template <class XVector, class BVector>
    void solve(XVector& x, const BVector& b) const
    {
        const size_t n = size();
        assert(static_cast<size_t>(b.size()) == n);
        if (n == 0)
            return;

        std::vector<Scalar> bStar(n);
        std::copy(b.begin(), b.end(), bStar.begin());

        // Rows 0 .. n - 2 are kept as: diagonal entry, entry right of the
        // diagonal (only for rows below n - 2, whose right neighbour is not
        // the last column) and entry in the last column (which receives
        // fill-in from the upper right corner).
        std::vector<Scalar> mainDiag(diag_[1]);
        std::vector<Scalar> upperDiag(n, Scalar(0));
        std::vector<Scalar> lastColumn(n, Scalar(0));
        // The last row is kept densely: eliminating the lower left corner
        // creates fill-in all along it.
        std::vector<Scalar> lastRow(n, Scalar(0));

        for (size_t i = 0; i + 2 < n; ++i)
            upperDiag[i] = diag_[2][i + 1];
        lastRow[n - 1] = diag_[1][n - 1];
        if (n >= 2) {
            lastColumn[n - 2] = diag_[2][n - 1];
            lastRow[n - 2] = diag_[0][n - 2];
        }
        if (hasCorners_()) {
            lastColumn[0] = diag_[2][0];
            lastRow[0] = diag_[0][n - 1];
        }

        // forward elimination with the pivot rows 0 .. n - 3
        for (size_t i = 0; i + 2 < n; ++i) {
            // eliminate the sub-diagonal entry (i + 1, i)
            const Scalar alpha = diag_[0][i]/mainDiag[i];
            mainDiag[i + 1] -= alpha*upperDiag[i];
            lastColumn[i + 1] -= alpha*lastColumn[i];
            bStar[i + 1] -= alpha*bStar[i];

            // eliminate the entry (n - 1, i) of the last row
            const Scalar beta = lastRow[i]/mainDiag[i];
            lastRow[i + 1] -= beta*upperDiag[i];
            lastRow[n - 1] -= beta*lastColumn[i];
            bStar[n - 1] -= beta*bStar[i];
        }

        // pivot row n - 2: its only off-diagonal entry is in the last column
        if (n >= 2) {
            const Scalar beta = lastRow[n - 2]/mainDiag[n - 2];
            lastRow[n - 1] -= beta*lastColumn[n - 2];
            bStar[n - 1] -= beta*bStar[n - 2];
        }

        // back substitution
        std::vector<Scalar> xStar(n, Scalar(0));
        xStar[n - 1] = bStar[n - 1]/lastRow[n - 1];
        if (n >= 2) {
            xStar[n - 2] = (bStar[n - 2] - lastColumn[n - 2]*xStar[n - 1])/mainDiag[n - 2];
            for (size_t i = n - 2; i > 0; --i) {
                const size_t row = i - 1;
                xStar[row] = (bStar[row]
                              - upperDiag[row]*xStar[row + 1]
                              - lastColumn[row]*xStar[n - 1])/mainDiag[row];
            }
        }

        for (size_t i = 0; i < n; ++i)
            x[i] = xStar[i];
    }

    /*!
     * \brief Print the matrix to a given output stream.
     *
     * One line per row; the stored entries of a row are written in ascending
     * column order, separated by tabs.
     */
    void print(std::ostream& os = std::cout) const
    {
        for (size_t rowIdx = 0; rowIdx < size(); ++rowIdx) {
            bool first = true;
            forEachInRow_(rowIdx, [&os, &first](size_t, const Scalar& value) {
                if (!first)
                    os << "\t";
                os << value;
                first = false;
            });
            os << "\n";
        }
    }

private:
    //! Location of an entry in the storage: diag_[diagIdx][idx].
    struct Slot_
    {
        size_t diagIdx;
        size_t idx;
    };

    //! The corner entries are distinct from the band only for size() > 2.
    bool hasCorners_() const
    { return size() > 2; }

    //! Map the position (rowIdx, colIdx) to its storage slot.
    Slot_ slotOf_(size_t rowIdx, size_t colIdx) const
    {
        const size_t n = size();
        assert(rowIdx < n && colIdx < n);

        if (hasCorners_()) {
            if (rowIdx == 0 && colIdx == n - 1)
                return Slot_{2, 0};
            if (rowIdx == n - 1 && colIdx == 0)
                return Slot_{0, n - 1};
        }

        // unsigned arithmetic: wraps to a huge value if colIdx + 1 < rowIdx,
        // which the assertion below catches
        const size_t diagIdx = 1 + colIdx - rowIdx;
        assert(diagIdx < 3);
        return Slot_{diagIdx, colIdx};
    }

    //! Call visit(colIdx, value) for every stored entry of row rowIdx,
    //! in ascending column order.
    template <class Visitor>
    void forEachInRow_(size_t rowIdx, Visitor visit) const
    {
        const size_t n = size();
        if (hasCorners_() && rowIdx == n - 1)
            visit(size_t{0}, diag_[0][n - 1]);
        if (rowIdx > 0)
            visit(rowIdx - 1, diag_[0][rowIdx - 1]);
        visit(rowIdx, diag_[1][rowIdx]);
        if (rowIdx + 1 < n)
            visit(rowIdx + 1, diag_[2][rowIdx + 1]);
        if (hasCorners_() && rowIdx == 0)
            visit(n - 1, diag_[2][0]);
    }

    //! Call visit(rowIdx, value) for every stored entry of column colIdx,
    //! in ascending row order.
    template <class Visitor>
    void forEachInColumn_(size_t colIdx, Visitor visit) const
    {
        const size_t n = size();
        if (hasCorners_() && colIdx == n - 1)
            visit(size_t{0}, diag_[2][0]);
        if (colIdx > 0)
            visit(colIdx - 1, diag_[2][colIdx]);
        visit(colIdx, diag_[1][colIdx]);
        if (colIdx + 1 < n)
            visit(colIdx + 1, diag_[0][colIdx]);
        if (hasCorners_() && colIdx == 0)
            visit(n - 1, diag_[0][n - 1]);
    }

    //! Dot product of row rowIdx with x.
    template <class Vector>
    Scalar rowTimes_(size_t rowIdx, const Vector& x) const
    {
        Scalar result(0);
        forEachInRow_(rowIdx, [&result, &x](size_t colIdx, const Scalar& value) {
            result += value*x[colIdx];
        });
        return result;
    }

    //! Dot product of column colIdx with x (i.e. row colIdx of A^T).
    template <class Vector>
    Scalar colTimes_(size_t colIdx, const Vector& x) const
    {
        Scalar result(0);
        forEachInColumn_(colIdx, [&result, &x](size_t rowIdx, const Scalar& value) {
            result += value*x[rowIdx];
        });
        return result;
    }

    // diag_[0]: sub-diagonal (+ lower left corner in slot n - 1),
    // diag_[1]: main diagonal,
    // diag_[2]: super-diagonal (+ upper right corner in slot 0)
    std::vector<Scalar> diag_[3];
};
// Stream-insertion operator for Opm::TridiagonalMatrix.
// NOTE(review): the function body is not part of this listing; presumably it
// forwards to mat.print(os) and returns os -- confirm against the full header.
846 template <
class Scalar>
847 std::ostream &operator<<(std::ostream &os, const Opm::TridiagonalMatrix<Scalar> &mat)
void solve(XVector &x, const BVector &b) const
Calculate the solution for a linear system of equations.
Definition: TridiagonalMatrix.hpp:712
Scalar FieldType
Definition: TridiagonalMatrix.hpp:106
Provides a tridiagonal matrix that also supports non-zero entries in the upper right and lower left corners.
Definition: TridiagonalMatrix.hpp:48
TridiagonalMatrix(size_t numRows=0)
Definition: TridiagonalMatrix.hpp:112
Evaluation< Scalar, VarSetTag, numVars > operator*(const ScalarA &a, const Evaluation< Scalar, VarSetTag, numVars > &b)
Definition: Evaluation.hpp:403
const TridiagRow_ operator[](size_t rowIdx) const
Row access operator.
Definition: TridiagonalMatrix.hpp:248
Definition: Air_Mesitylene.hpp:31
TridiagonalMatrix & operator*=(Scalar alpha)
Multiplication with a Scalar.
Definition: TridiagonalMatrix.hpp:254
TridiagonalMatrix(const TridiagonalMatrix &source)
Copy constructor.
Definition: TridiagonalMatrix.hpp:126
Scalar frobeniusNormSquared() const
Calculate the squared frobenius norm.
Definition: TridiagonalMatrix.hpp:660
void umtv(const Vector &source, Vector &dest) const
Transposed additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:536
void mv(const Vector &source, Vector &dest) const
Matrix-vector product.
Definition: TridiagonalMatrix.hpp:333
void mtv(const Vector &source, Vector &dest) const
Transposed matrix-vector product.
Definition: TridiagonalMatrix.hpp:496
Evaluation< Scalar, VarSetTag, numVars > max(const Evaluation< Scalar, VarSetTag, numVars > &x1, const Evaluation< Scalar, VarSetTag, numVars > &x2)
Definition: Math.hpp:114
Scalar & at(size_t rowIdx, size_t colIdx)
Access an entry.
Definition: TridiagonalMatrix.hpp:162
Evaluation< Scalar, VarSetTag, numVars > sqrt(const Evaluation< Scalar, VarSetTag, numVars > &x)
Definition: Math.hpp:278
void mmv(const Vector &source, Vector &dest) const
Subtractive matrix-vector product.
Definition: TridiagonalMatrix.hpp:413
iterator begin()
Definition: TridiagonalMatrix.hpp:224
void usmtv(Scalar alpha, const Vector &source, Vector &dest) const
Transposed scaled additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:616
void resize(size_t n)
Change the number of rows of the matrix.
Definition: TridiagonalMatrix.hpp:150
TridiagonalMatrix & operator=(const TridiagonalMatrix &source)
Assignment operator from another tridiagonal matrix.
Definition: TridiagonalMatrix.hpp:202
void print(std::ostream &os=std::cout) const
Print the matrix to a given output stream.
Definition: TridiagonalMatrix.hpp:723
bool operator!=(const ScalarA &a, const Evaluation< Scalar, VarSetTag, numVars > &b)
Definition: Evaluation.hpp:362
size_t cols() const
Return the number of columns of the matrix.
Definition: TridiagonalMatrix.hpp:144
Evaluation< Scalar, VarSetTag, numVars > abs(const Evaluation< Scalar, VarSetTag, numVars > &)
Definition: Math.hpp:41
TridiagonalMatrix & operator+=(const TridiagonalMatrix &other)
Addition operator.
Definition: TridiagonalMatrix.hpp:291
Scalar infinityNorm() const
Calculate the infinity norm.
Definition: TridiagonalMatrix.hpp:677
TridiagRow_ RowType
Definition: TridiagonalMatrix.hpp:107
void mmtv(const Vector &source, Vector &dest) const
Transposed subtractive matrix-vector product.
Definition: TridiagonalMatrix.hpp:576
const_iterator end() const
Definition: TridiagonalMatrix.hpp:236
TridiagonalMatrix & operator-=(const TridiagonalMatrix &other)
Subtraction operator.
Definition: TridiagonalMatrix.hpp:285
TridiagonalMatrix & operator/=(Scalar alpha)
Division by a Scalar.
Definition: TridiagonalMatrix.hpp:269
size_t size() const
Return the number of rows/columns of the matrix.
Definition: TridiagonalMatrix.hpp:132
Scalar frobeniusNorm() const
Calculate the frobenius norm.
Definition: TridiagonalMatrix.hpp:652
TridiagonalMatrix & operator=(Scalar value)
Assignment operator from a Scalar.
Definition: TridiagonalMatrix.hpp:213
void umv(const Vector &source, Vector &dest) const
Additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:373
TridiagRow_ const_iterator
Definition: TridiagonalMatrix.hpp:110
TridiagonalMatrix(size_t numRows, Scalar value)
Definition: TridiagonalMatrix.hpp:117
size_t rows() const
Return the number of rows of the matrix.
Definition: TridiagonalMatrix.hpp:138
TridiagonalMatrix & axpy(Scalar alpha, const TridiagonalMatrix &other)
Multiply and add the matrix entries of another tridiagonal matrix.
Definition: TridiagonalMatrix.hpp:308
void usmv(Scalar alpha, const Vector &source, Vector &dest) const
Scaled additive matrix-vector product.
Definition: TridiagonalMatrix.hpp:453
Scalar at(size_t rowIdx, size_t colIdx) const
Access an entry.
Definition: TridiagonalMatrix.hpp:183
size_t SizeType
Definition: TridiagonalMatrix.hpp:108
TridiagRow_ operator[](size_t rowIdx)
Row access operator.
Definition: TridiagonalMatrix.hpp:242
const_iterator begin() const
Definition: TridiagonalMatrix.hpp:230
TridiagRow_ iterator
Definition: TridiagonalMatrix.hpp:109