25 #ifndef EIGEN_GENERAL_MATRIX_MATRIX_H
26 #define EIGEN_GENERAL_MATRIX_MATRIX_H
// Forward declaration of the blocking-buffer holder, parameterized on the LHS
// and RHS scalar types. The run() signatures below take a
// level3_blocking<...>& argument before the class body appears later in this
// file, hence the early declaration.
32 template<
typename _LhsScalar,
typename _RhsScalar>
class level3_blocking;
// Specialization of general_matrix_matrix_product whose last template argument
// (the result storage order) is RowMajor.
// Its run() does no arithmetic itself: it forwards to another
// general_matrix_matrix_product instantiation with the roles of the two
// operands exchanged (see the call at the bottom of this block).
37 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
38 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs>
39 struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,
RowMajor>
// Scalar type of the result, as reported by scalar_product_traits for the
// (LhsScalar, RhsScalar) pair.
41 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Problem dimensions followed by (pointer, stride) pairs for lhs, rhs and the
// result buffer.
43 Index rows, Index cols, Index depth,
44 const LhsScalar* lhs, Index lhsStride,
45 const RhsScalar* rhs, Index rhsStride,
46 ResScalar* res, Index resStride,
// Note the scalar order <RhsScalar,LhsScalar>: the blocking object is typed
// for the swapped operands that the forwarded call works with.
48 level3_blocking<RhsScalar,LhsScalar>& blocking,
// Optional parallel bookkeeping; defaults to a null pointer.
49 GemmParallelInfo<Index>* info = 0)
// Forward with rows<->cols and lhs<->rhs (pointers and strides) swapped; res,
// resStride, alpha, blocking and info are passed through unchanged.
// NOTE(review): the template arguments of the target instantiation are not
// visible here; presumably it is the ColMajor specialization evaluating the
// transposed product -- confirm against the full file.
52 general_matrix_matrix_product<Index,
56 ::run(cols,rows,depth,rhs,rhsStride,lhs,lhsStride,res,resStride,alpha,blocking,info);
// Specialization of general_matrix_matrix_product whose last template argument
// (the result storage order) is ColMajor. run() below is the blocked product
// driver: it walks the depth dimension in slices of kc, packs panels of lhs and
// rhs into contiguous buffers, and hands them to the gebp kernel, which
// accumulates into res.
64 typename LhsScalar,
int LhsStorageOrder,
bool ConjugateLhs,
65 typename RhsScalar,
int RhsStorageOrder,
bool ConjugateRhs>
66 struct general_matrix_matrix_product<Index,LhsScalar,LhsStorageOrder,ConjugateLhs,RhsScalar,RhsStorageOrder,ConjugateRhs,
ColMajor>
// Scalar type of the result, as reported by scalar_product_traits.
68 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// rows/cols/depth: problem dimensions.
// _lhs/_rhs/res with their strides: raw operand and result buffers.
// blocking: supplies kc(), mc() and the packed-block buffers.
// info: per-thread bookkeeping array indexed by thread id in the OpenMP path;
//       defaults to a null pointer.
69 static void run(Index rows, Index cols, Index depth,
70 const LhsScalar* _lhs, Index lhsStride,
71 const RhsScalar* _rhs, Index rhsStride,
72 ResScalar* res, Index resStride,
74 level3_blocking<LhsScalar,RhsScalar>& blocking,
75 GemmParallelInfo<Index>* info = 0)
// Wrap the raw pointers so that lhs(i,j) / rhs(i,j) can be addressed according
// to each operand's storage order.
77 const_blas_data_mapper<LhsScalar, Index, LhsStorageOrder> lhs(_lhs,lhsStride);
78 const_blas_data_mapper<RhsScalar, Index, RhsStorageOrder> rhs(_rhs,rhsStride);
80 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
// Depth slice size comes straight from the blocking object; the row block size
// is additionally capped by the number of rows actually requested.
82 Index kc = blocking.kc();
83 Index mc = (std::min)(rows,blocking.mc());
// Packing routines and the block-times-panel kernel, configured from Traits.
86 gemm_pack_lhs<LhsScalar, Index, Traits::mr, Traits::LhsProgress, LhsStorageOrder> pack_lhs;
87 gemm_pack_rhs<RhsScalar, Index, Traits::nr, RhsStorageOrder> pack_rhs;
88 gebp_kernel<LhsScalar, RhsScalar, Index, Traits::mr, Traits::nr, ConjugateLhs, ConjugateRhs> gebp;
// ---- Code compiled only when EIGEN_HAS_OPENMP is defined ----
90 #ifdef EIGEN_HAS_OPENMP
// Identify the calling thread and the team size.
94 Index tid = omp_get_thread_num();
95 Index threads = omp_get_num_threads();
97 std::size_t sizeA = kc*mc;
98 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
// The packed rhs buffer is taken from the blocking object; each thread works
// on its own slice of it, located at info[...].rhs_start*actual_kc (below).
102 RhsScalar* blockB = blocking.blockB();
// Walk the depth dimension in slices of at most kc.
106 for(Index k=0; k<depth; k+=kc)
// Width of the current slice; shorter than kc only for the last one.
108 const Index actual_kc = (std::min)(k+kc,depth)-k;
// Pack the first mc rows of lhs for this depth slice into blockA.
112 pack_lhs(blockA, &lhs(0,k), lhsStride, actual_kc, mc);
// Busy-wait until this thread's `users` counter is back to zero, then claim it
// for `threads` readers before overwriting the thread's rhs slice.
// NOTE(review): the matching decrement is not visible in this view.
120 while(info[tid].users!=0) {}
121 info[tid].users += threads;
// Pack this thread's column range [rhs_start, rhs_start+rhs_length) of rhs.
123 pack_rhs(blockB+info[tid].rhs_start*actual_kc, &rhs(k,info[tid].rhs_start), rhsStride, actual_kc, info[tid].rhs_length);
// Visit every thread's rhs slice, starting with our own (shift==0) and
// wrapping around modulo the team size.
129 for(Index shift=0; shift<threads; ++shift)
131 Index j = (tid+shift)%threads;
// Busy-wait until thread j's `sync` field equals the current depth offset k.
// NOTE(review): presumably thread j stores k there once its slice is packed;
// the store is not visible here.
137 while(info[j].sync!=k) {}
// Multiply the packed lhs block by slice j of the packed rhs, accumulating
// into the result columns starting at info[j].rhs_start.
139 gebp(res+info[j].rhs_start*resStride, resStride, blockA, blockB+info[j].rhs_start*actual_kc, mc, actual_kc, info[j].rhs_length, alpha, -1,-1,0,0, w);
// Remaining row blocks: rows [0,mc) were handled above, so start at i=mc.
143 for(Index i=mc; i<rows; i+=mc)
// Height of the current row block; shorter than mc only for the last one.
145 const Index actual_mc = (std::min)(i+mc,rows)-i;
148 pack_lhs(blockA, &lhs(i,k), lhsStride, actual_kc, actual_mc);
// Here the full packed rhs (all `cols` columns) is used.
151 gebp(res+i, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1,-1,0,0, w);
// Loop over all threads of the team; its body is not visible in this view.
156 for(Index j=0; j<threads; ++j)
162 #endif // EIGEN_HAS_OPENMP
// ---- Code after the OpenMP-only section ----
// Element counts for the packed lhs block, the packed rhs panel and the kernel
// workspace.
167 std::size_t sizeA = kc*mc;
168 std::size_t sizeB = kc*cols;
169 std::size_t sizeW = kc*Traits::WorkSpaceFactor;
// Walk the depth dimension in slices of at most kc.
177 for(Index k2=0; k2<depth; k2+=kc)
179 const Index actual_kc = (std::min)(k2+kc,depth)-k2;
// Pack the actual_kc x cols horizontal panel of rhs once per depth slice; it
// is reused by every row block of the inner loop.
185 pack_rhs(blockB, &rhs(k2,0), rhsStride, actual_kc, cols);
// Walk the rows in blocks of at most mc.
190 for(Index i2=0; i2<rows; i2+=mc)
192 const Index actual_mc = (std::min)(i2+mc,rows)-i2;
// Pack the actual_mc x actual_kc block of lhs, then accumulate its product
// with the packed rhs panel into the rows of res starting at i2.
197 pack_lhs(blockA, &lhs(i2,k2), lhsStride, actual_kc, actual_mc);
200 gebp(res+i2, resStride, blockA, blockB, actual_mc, actual_kc, cols, alpha, -1, -1, 0, 0, blockW);
// Traits declaration templated on Lhs and Rhs that inherits everything from
// traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >.
// NOTE(review): the line naming the specialized type is not visible here;
// presumably it is traits<GeneralProduct<Lhs,Rhs,GemmProduct> > -- confirm.
214 template<
typename Lhs,
typename Rhs>
216 :
traits<ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs> >
// Callable that bundles the operands, the destination, the scaling factor and
// the blocking object of one product, and evaluates a sub-block of it through
// Gemm::run when invoked.
219 template<
typename Scalar,
typename Index,
typename Gemm,
typename Lhs,
typename Rhs,
typename Dest,
typename BlockingType>
// The constructor only records its arguments in the corresponding members.
222 gemm_functor(
const Lhs& lhs,
const Rhs& rhs, Dest& dest, Scalar actualAlpha,
223 BlockingType& blocking)
224 : m_lhs(lhs), m_rhs(rhs), m_dest(dest), m_actualAlpha(actualAlpha), m_blocking(blocking)
// Asks the blocking object to allocate its B block (allocateB()).
227 void initParallelSession()
const
229 m_blocking.allocateB();
// Evaluates the destination sub-block that starts at (row, col) and spans
// rows x cols. `col` defaults to 0, `cols` to -1 and `info` to a null pointer.
// NOTE(review): how the cols==-1 default is resolved is not visible here.
232 void operator() (Index
row, Index rows, Index
col=0, Index cols=-1, GemmParallelInfo<Index>* info=0)
const
// The depth of the product is the number of columns of lhs. Operand and
// destination pointers are taken at the sub-block origin, with outerStride()
// as the stride passed to the kernel driver.
237 Gemm::run(rows, cols, m_lhs.cols(),
238 &m_lhs.coeffRef(row,0), m_lhs.outerStride(),
239 &m_rhs.coeffRef(0,
col), m_rhs.outerStride(),
240 (Scalar*)&(m_dest.coeffRef(row,
col)), m_dest.outerStride(),
241 m_actualAlpha, m_blocking, info);
// Scaling factor is stored by value; the blocking object is held by reference.
248 Scalar m_actualAlpha;
249 BlockingType& m_blocking;
// Primary declaration of gemm_blocking_space.
// KcFactor defaults to 1. FiniteAtCompileTime defaults to true exactly when
// MaxRows, MaxCols and MaxDepth are all different from Dynamic; the two partial
// specializations later in this file are selected on that last flag
// (true / false).
252 template<
int StorageOrder,
typename LhsScalar,
typename RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
int KcFactor=1,
253 bool FiniteAtCompileTime = MaxRows!=
Dynamic && MaxCols!=
Dynamic && MaxDepth !=
Dynamic>
class gemm_blocking_space;
// Holder for the three work buffers (A, B, W) and the three block sizes
// (mc, nc, kc) used by the blocked product routines. The gemm_blocking_space
// specializations below derive from it and fill these fields in.
255 template<
typename _LhsScalar,
typename _RhsScalar>
256 class level3_blocking
258 typedef _LhsScalar LhsScalar;
259 typedef _RhsScalar RhsScalar;
// Initial state: all buffer pointers null and all block sizes zero.
273 : m_blockA(0), m_blockB(0), m_blockW(0), m_mc(0), m_nc(0), m_kc(0)
// Raw accessors to the buffers. blockA is typed on LhsScalar, while blockB and
// blockW are both typed on RhsScalar.
280 inline LhsScalar* blockA() {
return m_blockA; }
281 inline RhsScalar* blockB() {
return m_blockB; }
282 inline RhsScalar* blockW() {
return m_blockW; }
// Specialization of gemm_blocking_space for FiniteAtCompileTime == true, i.e.
// all maximum sizes are known at compile time. Block sizes are the compile-time
// maxima and the buffer pointers refer to the m_static* members, so nothing is
// allocated on demand.
285 template<
int StorageOrder,
typename _LhsScalar,
typename _RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
int KcFactor>
286 class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, true>
287 :
// For a RowMajor StorageOrder the base class is instantiated with the two
// scalar types exchanged.
public level3_blocking<
288 typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
289 typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
// Row/column maxima, exchanged when Transpose is set.
// NOTE(review): the definition of Transpose is not visible here; presumably it
// is StorageOrder==RowMajor, mirroring the base-class conditional -- confirm.
293 ActualRows = Transpose ? MaxCols : MaxRows,
294 ActualCols = Transpose ? MaxRows : MaxCols
296 typedef typename conditional<Transpose,_RhsScalar,_LhsScalar>::type LhsScalar;
297 typedef typename conditional<Transpose,_LhsScalar,_RhsScalar>::type RhsScalar;
298 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
// Compile-time element counts of the three buffers.
300 SizeA = ActualRows * MaxDepth,
301 SizeB = ActualCols * MaxDepth,
302 SizeW = MaxDepth * Traits::WorkSpaceFactor
// Block sizes are the full (maximum) dimensions, and the base-class pointers
// are aimed at the m_static* storage.
313 this->m_mc = ActualRows;
314 this->m_nc = ActualCols;
315 this->m_kc = MaxDepth;
316 this->m_blockA = m_staticA;
317 this->m_blockB = m_staticB;
318 this->m_blockW = m_staticW;
// The buffer pointers are already assigned above, so the allocate* hooks are
// intentionally empty.
321 inline void allocateA() {}
322 inline void allocateB() {}
323 inline void allocateW() {}
324 inline void allocateAll() {}
// Specialization of gemm_blocking_space for FiniteAtCompileTime == false, i.e.
// at least one maximum size is Dynamic. Block sizes are computed at run time
// and each buffer is allocated with aligned_new only if its pointer is still
// null.
327 template<
int StorageOrder,
typename _LhsScalar,
typename _RhsScalar,
int MaxRows,
int MaxCols,
int MaxDepth,
int KcFactor>
328 class gemm_blocking_space<StorageOrder,_LhsScalar,_RhsScalar,MaxRows, MaxCols, MaxDepth, KcFactor, false>
329 :
// For a RowMajor StorageOrder the base class is instantiated with the two
// scalar types exchanged.
public level3_blocking<
330 typename conditional<StorageOrder==RowMajor,_RhsScalar,_LhsScalar>::type,
331 typename conditional<StorageOrder==RowMajor,_LhsScalar,_RhsScalar>::type>
// NOTE(review): the definition of Transpose is not visible here; presumably it
// is StorageOrder==RowMajor -- confirm.
336 typedef typename conditional<Transpose,_RhsScalar,_LhsScalar>::type LhsScalar;
337 typedef typename conditional<Transpose,_LhsScalar,_RhsScalar>::type RhsScalar;
338 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
// Start from the requested dimensions (exchanged when Transpose is set) ...
348 this->m_mc = Transpose ? cols : rows;
349 this->m_nc = Transpose ? rows : cols;
// ... hand kc/mc/nc to computeProductBlockingSizes, then derive the buffer
// element counts from the values those fields hold afterwards.
352 computeProductBlockingSizes<LhsScalar,RhsScalar,KcFactor>(this->m_kc, this->m_mc, this->m_nc);
353 m_sizeA = this->m_mc * this->m_kc;
354 m_sizeB = this->m_kc * this->m_nc;
355 m_sizeW = this->m_kc*Traits::WorkSpaceFactor;
// Each buffer is allocated only when its pointer is still null, so a buffer
// that already exists is never re-allocated.
360 if(this->m_blockA==0)
361 this->m_blockA = aligned_new<LhsScalar>(m_sizeA);
366 if(this->m_blockB==0)
367 this->m_blockB = aligned_new<RhsScalar>(m_sizeB);
372 if(this->m_blockW==0)
373 this->m_blockW = aligned_new<RhsScalar>(m_sizeW);
// Destructor; its body is not visible in this view (presumably it releases the
// buffers obtained with aligned_new -- confirm).
383 ~gemm_blocking_space()
// Product expression class templated on Lhs and Rhs, derived from
// ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>. Its
// scaleAndAddTo() sets up a blocking object and a gemm_functor and dispatches
// them through internal::parallelize_gemm.
393 template<
typename Lhs,
typename Rhs>
395 :
public ProductBase<GeneralProduct<Lhs,Rhs,GemmProduct>, Lhs, Rhs>
409 typedef internal::scalar_product_op<LhsScalar,RhsScalar> BinOp;
// dst:   destination expression; its size must match lhs.rows() x rhs.cols().
// alpha: scaling factor, combined below with any factors extracted from the
//        operands.
413 template<
typename Dest>
void scaleAndAddTo(Dest& dst,
Scalar alpha)
const
415 eigen_assert(dst.rows()==m_lhs.rows() && dst.cols()==m_rhs.cols());
// Obtain the operands through the BlasTraits extract() helpers.
417 typename internal::add_const_on_value_type<ActualLhsType>::type lhs = LhsBlasTraits::extract(m_lhs);
418 typename internal::add_const_on_value_type<ActualRhsType>::type rhs = RhsBlasTraits::extract(m_rhs);
// Fold the scalar factors reported for both operands into alpha.
420 Scalar actualAlpha = alpha * LhsBlasTraits::extractScalarFactor(m_lhs)
421 * RhsBlasTraits::extractScalarFactor(m_rhs);
// Tail of the BlockingType typedef: sized from the destination's compile-time
// maxima and MaxDepthAtCompileTime (the start of the typedef is not visible).
424 Dest::MaxRowsAtCompileTime,Dest::MaxColsAtCompileTime,MaxDepthAtCompileTime> BlockingType;
426 typedef internal::gemm_functor<
428 internal::general_matrix_matrix_product<
// Blocking sized from the run-time destination dimensions and the product
// depth (lhs.cols()).
435 BlockingType blocking(dst.rows(), dst.cols(), lhs.cols());
// The boolean template argument is true when the destination's maximum row
// count exceeds 32 or is Dynamic.
// NOTE(review): presumably this gates whether parallel execution is
// considered -- confirm in parallelize_gemm.
// The last argument is non-zero when the destination has RowMajorBit set.
437 internal::parallelize_gemm<(Dest::MaxRowsAtCompileTime>32 || Dest::MaxRowsAtCompileTime==
Dynamic)>(GemmFunctor(lhs, rhs, dst, actualAlpha, blocking), this->rows(), this->cols(), Dest::Flags&
RowMajorBit);
443 #endif // EIGEN_GENERAL_MATRIX_MATRIX_H