25 #ifndef EIGEN_MATRIX_FUNCTION
26 #define EIGEN_MATRIX_FUNCTION
28 #include "StemFunction.h"
29 #include "MatrixFunctionAtomic.h"
// MatrixFunction: computes a function of a square matrix. This is the primary
// template (fragment - the class-head line and the rest of the body are not
// visible in this view). The IsComplex parameter defaults to
// NumTraits<Scalar>::IsComplex for the scalar type of MatrixType; the
// complex-scalar specialization further down is
// MatrixFunction<MatrixType, AtomicType, 1>.
// NOTE(review): the specializations take an AtomicType parameter as well; its
// line in this primary template appears to be missing from this view - confirm.
49 template <
typename MatrixType,
51 int IsComplex = NumTraits<typename internal::traits<MatrixType>::Scalar>::IsComplex>
// Compute the matrix function into \p result (declaration only; the
// specializations provide the implementation).
74 template <
typename ResultType>
75 void compute(ResultType &result);
// Specialization of MatrixFunction for real scalar types (fragment).
// NOTE(review): the `class MatrixFunction<...>` head line is not visible in
// this view. The "real scalars" reading follows from compute(), which promotes
// the matrix to std::complex<Scalar>, runs the complex-scalar MatrixFunction
// and keeps the real part of the result.
82 template <
typename MatrixType,
typename AtomicType>
// Compile-time properties of MatrixType, reused so that the complex working
// matrix has the same fixed/dynamic sizes, storage options and maximum sizes.
87 typedef internal::traits<MatrixType> Traits;
88 typedef typename Traits::Scalar Scalar;
89 static const int Rows = Traits::RowsAtCompileTime;
90 static const int Cols = Traits::ColsAtCompileTime;
91 static const int Options = MatrixType::Options;
92 static const int MaxRows = Traits::MaxRowsAtCompileTime;
93 static const int MaxCols = Traits::MaxColsAtCompileTime;
// Complex counterpart of MatrixType, used as the working type in compute().
95 typedef std::complex<Scalar> ComplexScalar;
96 typedef Matrix<ComplexScalar, Rows, Cols, Options, MaxRows, MaxCols> ComplexMatrix;
// Constructor.
// \param A       the matrix whose function is computed (stored in m_A).
// \param atomic  evaluator for the function on diagonal blocks. Only a
//                reference is kept (m_atomic is AtomicType&), so it must
//                outlive this object.
105 MatrixFunction(
const MatrixType& A, AtomicType& atomic) : m_A(A), m_atomic(atomic) { }
116 template <
typename ResultType>
117 void compute(ResultType& result)
119 ComplexMatrix CA = m_A.template cast<ComplexScalar>();
120 ComplexMatrix Cresult;
121 MatrixFunction<ComplexMatrix, AtomicType> mf(CA, m_atomic);
123 result = Cresult.real();
// Matrix argument, stored as internal::nested<MatrixType>::type.
127 typename internal::nested<MatrixType>::type m_A;
// Caller-owned atomic evaluator; held by reference, not copied.
128 AtomicType& m_atomic;
// Specialization of MatrixFunction for complex scalar types (fragment).
// The matrix function is computed from the complex Schur form A = U T U^*.
// Clustered eigenvalues are grouped into contiguous diagonal blocks of T.
// The function is evaluated on each diagonal block by the atomic evaluator,
// and the off-diagonal blocks are obtained from Sylvester equations.
// NOTE(review): the declarations of the constructor, of permuteSchur() and of
// the members m_T, m_U and m_fT (all used by the definitions below) are not
// visible in this view - confirm they are present.
137 template <
typename MatrixType,
typename AtomicType>
138 class MatrixFunction<MatrixType, AtomicType, 1>
// Types and compile-time sizes derived from MatrixType.
142 typedef internal::traits<MatrixType> Traits;
143 typedef typename MatrixType::Scalar Scalar;
144 typedef typename MatrixType::Index Index;
145 static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
146 static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
147 static const int Options = MatrixType::Options;
148 typedef typename NumTraits<Scalar>::Real RealScalar;
// Column vector of scalars (used for the diagonal of T).
149 typedef Matrix<Scalar, Traits::RowsAtCompileTime, 1> VectorType;
// Index vectors: the permutation has one entry per row; the per-cluster
// vectors are dynamic because the number of clusters is only known at run time.
150 typedef Matrix<Index, Traits::RowsAtCompileTime, 1> IntVectorType;
151 typedef Matrix<Index, Dynamic, 1> DynamicIntVectorType;
// A cluster is a list of eigenvalues; the partition is a list of clusters.
152 typedef std::list<Scalar> Cluster;
153 typedef std::list<Cluster> ListOfClusters;
// Dynamic-size matrix (bounded by the sizes of MatrixType) for block-sized temporaries.
154 typedef Matrix<Scalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
// Compute the matrix function into \p result.
159 template <
typename ResultType>
void compute(ResultType& result);
// Steps of the algorithm, in the order compute() runs them.
163 void computeSchurDecomposition();   // ComplexSchur of m_A -> m_T, m_U
164 void partitionEigenvalues();        // group diag(m_T) into m_clusters
165 typename ListOfClusters::iterator findCluster(Scalar key);  // cluster containing key, or m_clusters.end()
166 void computeClusterSize();          // m_clusterSize, m_eivalToCluster
167 void computeBlockStart();           // m_blockStart (first index of each cluster's block)
168 void constructPermutation();        // m_permutation (target position of each diagonal entry)
170 void swapEntriesInSchur(Index index);  // swap diagonal entries index, index+1 of m_T, updating m_U
171 void computeBlockAtomic();          // atomic evaluator on the diagonal blocks -> m_fT
// Sub-block of A at cluster rows i and cluster columns j.
172 Block<MatrixType>
block(MatrixType& A, Index i, Index j);
173 void computeOffDiagonal();          // off-diagonal blocks of m_fT via Sylvester equations
// Solve A X + X B = C for upper-triangular A and B.
174 DynMatrixType solveTriangularSylvester(
const DynMatrixType& A,
const DynMatrixType& B,
const DynMatrixType& C);
// The matrix argument and the caller-owned atomic evaluator (held by reference).
176 typename internal::nested<MatrixType>::type m_A;
177 AtomicType& m_atomic;
// Clustering state produced by the steps above.
181 ListOfClusters m_clusters;
182 DynamicIntVectorType m_eivalToCluster;
183 DynamicIntVectorType m_clusterSize;
184 DynamicIntVectorType m_blockStart;
185 IntVectorType m_permutation;
// Eigenvalues whose difference has absolute value at most this are put in the
// same cluster by partitionEigenvalues().
193 static const RealScalar separation() {
return static_cast<RealScalar
>(0.1); }
// Out-of-class constructor of the complex-scalar specialization (fragment:
// the qualified-name/parameter line and the body are not visible in this view).
// The visible initializer list stores the matrix in m_A and binds m_atomic to
// the caller's atomic evaluator.
203 template <
typename MatrixType,
typename AtomicType>
205 : m_A(A), m_atomic(atomic)
215 template <
typename MatrixType,
typename AtomicType>
216 template <
typename ResultType>
219 computeSchurDecomposition();
220 partitionEigenvalues();
221 computeClusterSize();
223 constructPermutation();
225 computeBlockAtomic();
226 computeOffDiagonal();
227 result = m_U * m_fT * m_U.adjoint();
231 template <
typename MatrixType,
typename AtomicType>
232 void MatrixFunction<MatrixType,AtomicType,1>::computeSchurDecomposition()
234 const ComplexSchur<MatrixType> schurOfA(m_A);
235 m_T = schurOfA.matrixT();
236 m_U = schurOfA.matrixU();
250 template <
typename MatrixType,
typename AtomicType>
251 void MatrixFunction<MatrixType,AtomicType,1>::partitionEigenvalues()
253 const Index rows = m_T.rows();
254 VectorType diag = m_T.diagonal();
256 for (Index i=0; i<rows; ++i) {
258 typename ListOfClusters::iterator qi = findCluster(diag(i));
259 if (qi == m_clusters.end()) {
261 l.push_back(diag(i));
262 m_clusters.push_back(l);
263 qi = m_clusters.end();
268 for (Index j=i+1; j<rows; ++j) {
269 if (internal::abs(diag(j) - diag(i)) <= separation() && std::find(qi->begin(), qi->end(), diag(j)) == qi->end()) {
270 typename ListOfClusters::iterator qj = findCluster(diag(j));
271 if (qj == m_clusters.end()) {
272 qi->push_back(diag(j));
274 qi->insert(qi->end(), qj->begin(), qj->end());
275 m_clusters.erase(qj);
287 template <
typename MatrixType,
typename AtomicType>
288 typename MatrixFunction<MatrixType,AtomicType,1>::ListOfClusters::iterator MatrixFunction<MatrixType,AtomicType,1>::findCluster(Scalar key)
290 typename Cluster::iterator j;
291 for (
typename ListOfClusters::iterator i = m_clusters.begin(); i != m_clusters.end(); ++i) {
292 j = std::find(i->begin(), i->end(), key);
296 return m_clusters.end();
300 template <
typename MatrixType,
typename AtomicType>
301 void MatrixFunction<MatrixType,AtomicType,1>::computeClusterSize()
303 const Index rows = m_T.rows();
304 VectorType diag = m_T.diagonal();
305 const Index numClusters =
static_cast<Index
>(m_clusters.size());
307 m_clusterSize.setZero(numClusters);
308 m_eivalToCluster.resize(rows);
309 Index clusterIndex = 0;
310 for (
typename ListOfClusters::const_iterator cluster = m_clusters.begin(); cluster != m_clusters.end(); ++cluster) {
311 for (Index i = 0; i < diag.rows(); ++i) {
312 if (std::find(cluster->begin(), cluster->end(), diag(i)) != cluster->end()) {
313 ++m_clusterSize[clusterIndex];
314 m_eivalToCluster[i] = clusterIndex;
322 template <
typename MatrixType,
typename AtomicType>
323 void MatrixFunction<MatrixType,AtomicType,1>::computeBlockStart()
325 m_blockStart.resize(m_clusterSize.rows());
327 for (Index i = 1; i < m_clusterSize.rows(); i++) {
328 m_blockStart(i) = m_blockStart(i-1) + m_clusterSize(i-1);
333 template <
typename MatrixType,
typename AtomicType>
334 void MatrixFunction<MatrixType,AtomicType,1>::constructPermutation()
336 DynamicIntVectorType indexNextEntry = m_blockStart;
337 m_permutation.resize(m_T.rows());
338 for (Index i = 0; i < m_T.rows(); i++) {
339 Index cluster = m_eivalToCluster[i];
340 m_permutation[i] = indexNextEntry[cluster];
341 ++indexNextEntry[cluster];
346 template <
typename MatrixType,
typename AtomicType>
347 void MatrixFunction<MatrixType,AtomicType,1>::permuteSchur()
349 IntVectorType p = m_permutation;
350 for (Index i = 0; i < p.rows() - 1; i++) {
352 for (j = i; j < p.rows(); j++) {
353 if (p(j) == i)
break;
355 eigen_assert(p(j) == i);
356 for (Index k = j-1; k >= i; k--) {
357 swapEntriesInSchur(k);
358 std::swap(p.coeffRef(k), p.coeffRef(k+1));
364 template <
typename MatrixType,
typename AtomicType>
365 void MatrixFunction<MatrixType,AtomicType,1>::swapEntriesInSchur(Index index)
367 JacobiRotation<Scalar> rotation;
368 rotation.makeGivens(m_T(index, index+1), m_T(index+1, index+1) - m_T(index, index));
369 m_T.applyOnTheLeft(index, index+1, rotation.adjoint());
370 m_T.applyOnTheRight(index, index+1, rotation);
371 m_U.applyOnTheRight(index, index+1, rotation);
380 template <
typename MatrixType,
typename AtomicType>
381 void MatrixFunction<MatrixType,AtomicType,1>::computeBlockAtomic()
383 m_fT.resize(m_T.rows(), m_T.cols());
385 for (Index i = 0; i < m_clusterSize.rows(); ++i) {
386 block(m_fT, i, i) = m_atomic.compute(
block(m_T, i, i));
391 template <
typename MatrixType,
typename AtomicType>
392 Block<MatrixType> MatrixFunction<MatrixType,AtomicType,1>::block(MatrixType& A, Index i, Index j)
394 return A.block(m_blockStart(i), m_blockStart(j), m_clusterSize(i), m_clusterSize(j));
404 template <
typename MatrixType,
typename AtomicType>
405 void MatrixFunction<MatrixType,AtomicType,1>::computeOffDiagonal()
407 for (Index diagIndex = 1; diagIndex < m_clusterSize.rows(); diagIndex++) {
408 for (Index blockIndex = 0; blockIndex < m_clusterSize.rows() - diagIndex; blockIndex++) {
410 DynMatrixType A =
block(m_T, blockIndex, blockIndex);
411 DynMatrixType B = -
block(m_T, blockIndex+diagIndex, blockIndex+diagIndex);
412 DynMatrixType C =
block(m_fT, blockIndex, blockIndex) *
block(m_T, blockIndex, blockIndex+diagIndex);
413 C -=
block(m_T, blockIndex, blockIndex+diagIndex) *
block(m_fT, blockIndex+diagIndex, blockIndex+diagIndex);
414 for (Index k = blockIndex + 1; k < blockIndex + diagIndex; k++) {
415 C +=
block(m_fT, blockIndex, k) *
block(m_T, k, blockIndex+diagIndex);
416 C -=
block(m_T, blockIndex, k) *
block(m_fT, k, blockIndex+diagIndex);
418 block(m_fT, blockIndex, blockIndex+diagIndex) = solveTriangularSylvester(A, B, C);
446 template <
typename MatrixType,
typename AtomicType>
447 typename MatrixFunction<MatrixType,AtomicType,1>::DynMatrixType MatrixFunction<MatrixType,AtomicType,1>::solveTriangularSylvester(
448 const DynMatrixType& A,
449 const DynMatrixType& B,
450 const DynMatrixType& C)
452 eigen_assert(A.rows() == A.cols());
453 eigen_assert(A.isUpperTriangular());
454 eigen_assert(B.rows() == B.cols());
455 eigen_assert(B.isUpperTriangular());
456 eigen_assert(C.rows() == A.rows());
457 eigen_assert(C.cols() == B.rows());
461 DynMatrixType X(m, n);
463 for (Index i = m - 1; i >= 0; --i) {
464 for (Index j = 0; j < n; ++j) {
471 Matrix<Scalar,1,1> AXmatrix = A.row(i).tail(m-1-i) * X.col(j).tail(m-1-i);
480 Matrix<Scalar,1,1> XBmatrix = X.row(i).head(j) * B.col(j).head(j);
484 X(i,j) = (C(i,j) - AX - XB) / (A(i,i) + B(j,j));
// Fragment of MatrixFunctionReturnValue<Derived>: a ReturnByValue proxy that
// wraps a matrix expression (m_A) and a scalar stem function. The class-head
// line and several member lines are not visible in this view.
503 :
public ReturnByValue<MatrixFunctionReturnValue<Derived> >
507 typedef typename Derived::Scalar Scalar;
508 typedef typename Derived::Index Index;
// Type of the scalar stem function (internal::stem_function<Scalar>::type).
509 typedef typename internal::stem_function<Scalar>::type StemFunction;
// Evaluation hook of the ReturnByValue proxy; presumably writes f(m_A) into
// \p result.
// NOTE(review): several pieces are not visible in this view:
//   - the AtomicType typedef used below;
//   - the declaration of m_f;
//   - the lines that build a MatrixFunction from Aevaluated and call
//     compute(result).
// As shown, result is never written - confirm those lines are present.
524 template <
typename ResultType>
525 inline void evalTo(ResultType& result)
const
527 typedef typename Derived::PlainObject PlainObject;
528 typedef internal::traits<PlainObject> Traits;
529 static const int RowsAtCompileTime = Traits::RowsAtCompileTime;
530 static const int ColsAtCompileTime = Traits::ColsAtCompileTime;
531 static const int Options = PlainObject::Options;
// Complex dynamic-size matrix type bounded by PlainObject's sizes.
532 typedef std::complex<typename NumTraits<Scalar>::Real> ComplexScalar;
533 typedef Matrix<ComplexScalar, Dynamic, Dynamic, Options, RowsAtCompileTime, ColsAtCompileTime> DynMatrixType;
535 AtomicType atomic(m_f);
// Evaluate the wrapped expression into a plain matrix once.
537 const PlainObject Aevaluated = m_A.eval();
// Dimensions of the result equal those of the wrapped matrix.
542 Index rows()
const {
return m_A.rows(); }
543 Index cols()
const {
return m_A.cols(); }
// The wrapped matrix expression.
546 typename internal::nested<Derived>::type m_A;
// internal::traits specialization for the return-value proxy (fragment).
// Its ReturnType is Derived::PlainObject.
553 template<
typename Derived>
554 struct traits<MatrixFunctionReturnValue<Derived> >
556 typedef typename Derived::PlainObject ReturnType;
564 template <
typename Derived>
565 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::matrixFunction(
typename internal::stem_function<
typename internal::traits<Derived>::Scalar>::type f)
const
567 eigen_assert(rows() == cols());
568 return MatrixFunctionReturnValue<Derived>(derived(), f);
571 template <
typename Derived>
572 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::sin()
const
574 eigen_assert(rows() == cols());
575 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
576 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::sin);
579 template <
typename Derived>
580 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::cos()
const
582 eigen_assert(rows() == cols());
583 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
584 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::cos);
587 template <
typename Derived>
588 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::sinh()
const
590 eigen_assert(rows() == cols());
591 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
592 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::sinh);
595 template <
typename Derived>
596 const MatrixFunctionReturnValue<Derived> MatrixBase<Derived>::cosh()
const
598 eigen_assert(rows() == cols());
599 typedef typename internal::stem_function<Scalar>::ComplexScalar ComplexScalar;
600 return MatrixFunctionReturnValue<Derived>(derived(), StdStemFunctions<ComplexScalar>::cosh);
605 #endif // EIGEN_MATRIX_FUNCTION