25 #ifndef EIGEN_GENERAL_BLOCK_PANEL_H
26 #define EIGEN_GENERAL_BLOCK_PANEL_H
// Forward declaration of gebp_traits: per-scalar-pair traits describing how a
// LhsScalar x RhsScalar product kernel is vectorized (specialized further below).
32 template<
typename _LhsScalar,
typename _RhsScalar,
bool _ConjLhs=false,
bool _ConjRhs=false>
// Cached CPU cache sizes (bytes) used when computing blocking sizes.
// 0 means "not yet initialized".  NOTE(review): the enclosing function/class
// of these statics is not visible in this fragment.
45 static std::ptrdiff_t m_l1CacheSize = 0;
46 static std::ptrdiff_t m_l2CacheSize = 0;
// computeProductBlockingSizes: shrink the blocking sizes k (depth) and m (rows)
// in-place so that the packed panels fit in the L1/L2 caches.  KcFactor scales
// the per-k rhs footprint (used by callers that pack extra data per k).
87 template<
typename LhsScalar,
typename RhsScalar,
int KcFactor>
98 std::ptrdiff_t l1, l2;
100 typedef gebp_traits<LhsScalar,RhsScalar> Traits;
// kdiv: approximate bytes of packed rhs consumed per unit of depth
// (nr columns, RhsProgress-wide packets, times a 2x margin and KcFactor).
102 kdiv = KcFactor * 2 * Traits::nr
103 * Traits::RhsProgress *
sizeof(RhsScalar),
104 mr = gebp_traits<LhsScalar,RhsScalar>::mr,
// mr_mask rounds a row count down to a multiple of mr
// (works because (0xffffffff/mr)*mr clears the remainder).
105 mr_mask = (0xffffffff/mr)*mr
// Cap k so the rhs panel fits in L1; then cap m so that roughly
// 4 * m * k * sizeof(LhsScalar) fits in L2, rounded down to a multiple of mr.
109 k = std::min<std::ptrdiff_t>(k, l1/kdiv);
110 std::ptrdiff_t _m = k>0 ? l2/(4 *
sizeof(LhsScalar) * k) : 0;
111 if(_m<m) m = _m & mr_mask;
// Convenience overload: same blocking computation with KcFactor == 1.
114 template<
typename LhsScalar,
typename RhsScalar>
117 computeProductBlockingSizes<LhsScalar,RhsScalar,1>(k, m, n);
// MADD(CJ,A,B,C,T) abstracts a (conjugate-aware) multiply-accumulate C += A*B.
// When the platform fuses mul+add (EIGEN_HAS_FUSE_CJMADD) it maps straight
// onto CJ.pmadd; otherwise gebp_madd routes through gebp_madd_selector, using
// the temporary T to help the compiler schedule mul and add independently.
120 #ifdef EIGEN_HAS_FUSE_CJMADD
121 #define MADD(CJ,A,B,C,T) C = CJ.pmadd(A,B,C);
// Primary template: mixed operand types (generic fallback).
126 template<
typename CJ,
typename A,
typename B,
typename C,
typename T>
struct gebp_madd_selector {
// Specialization when all operands share one type: copy b into the spare
// register t, multiply, then add into the accumulator.
133 template<
typename CJ,
typename T>
struct gebp_madd_selector<CJ,T,T,T,T> {
136 t = b; t = cj.pmul(a,t); c =
padd(c,t);
// gebp_madd: free-function front-end dispatching to the selector.
140 template<
typename CJ,
typename A,
typename B,
typename C,
typename T>
143 gebp_madd_selector<CJ,A,B,C,T>::run(cj,a,b,c,t);
146 #define MADD(CJ,A,B,C,T) gebp_madd(CJ,A,B,C,T);
// Generic gebp_traits: register-level kernel geometry and micro-operations
// for a plain LhsScalar x RhsScalar product (no real/complex mixing).
160 template<
typename _LhsScalar,
typename _RhsScalar,
bool _ConjLhs,
bool _ConjRhs>
164 typedef _LhsScalar LhsScalar;
165 typedef _RhsScalar RhsScalar;
166 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
// Vectorize only if both operand types have vectorizable packets.
171 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
172 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
173 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
174 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
// mr x nr is the register-blocking tile: mr lhs rows by nr rhs columns.
179 nr = NumberOfRegisters/4,
182 mr = 2 * LhsPacketSize,
// Scratch space needed per depth step by unpackRhs.
184 WorkSpaceFactor = nr * RhsPacketSize,
186 LhsProgress = LhsPacketSize,
187 RhsProgress = RhsPacketSize
190 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
191 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
192 typedef typename packet_traits<ResScalar>::type _ResPacket;
// Degenerate to scalar "packets" when vectorization is off.
194 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
195 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
196 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
198 typedef ResPacket AccPacket;
// initAcc: zero an accumulator register.
202 p = pset1<ResPacket>(ResScalar(0));
// unpackRhs: broadcast each packed rhs coefficient into a full packet slot.
208 pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
// loadRhs / loadLhs: aligned packet loads from the packed buffers.
213 dest = pload<RhsPacket>(b);
218 dest = pload<LhsPacket>(a);
// madd: c += a*b, staging through the spare register tmp.
221 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, AccPacket& tmp)
const
223 tmp = b; tmp =
pmul(a,tmp); c =
padd(c,tmp);
// acc: r += alpha * c (apply the scalar factor while writing back).
226 EIGEN_STRONG_INLINE void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const
228 r =
pmadd(c,alpha,r);
// Specialization: complex lhs x real rhs.  The real rhs packet is multiplied
// against the raw representation (.v) of the complex lhs packet.
236 template<
typename RealScalar,
bool _ConjLhs>
237 class gebp_traits<std::complex<RealScalar>, RealScalar, _ConjLhs, false>
240 typedef std::complex<RealScalar> LhsScalar;
241 typedef RealScalar RhsScalar;
242 typedef typename scalar_product_traits<LhsScalar, RhsScalar>::ReturnType ResScalar;
247 Vectorizable = packet_traits<LhsScalar>::Vectorizable && packet_traits<RhsScalar>::Vectorizable,
248 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
249 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
250 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
253 nr = NumberOfRegisters/4,
254 mr = 2 * LhsPacketSize,
255 WorkSpaceFactor = nr*RhsPacketSize,
257 LhsProgress = LhsPacketSize,
258 RhsProgress = RhsPacketSize
261 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
262 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
263 typedef typename packet_traits<ResScalar>::type _ResPacket;
265 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
266 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
267 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
269 typedef ResPacket AccPacket;
273 p = pset1<ResPacket>(ResScalar(0));
279 pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
284 dest = pload<RhsPacket>(b);
289 dest = pload<LhsPacket>(a);
// madd dispatches at compile time on Vectorizable via tag dispatch.
292 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp)
const
294 madd_impl(a, b, c, tmp,
typename conditional<Vectorizable,true_type,false_type>::type());
// Vectorized path: operate on the underlying real packet (.v) of the
// complex accumulator, since the rhs is purely real.
297 EIGEN_STRONG_INLINE void madd_impl(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp,
const true_type&)
const
299 tmp = b; tmp =
pmul(a.v,tmp); c.v =
padd(c.v,tmp);
// Scalar fallback path (body not visible in this fragment).
302 EIGEN_STRONG_INLINE void madd_impl(
const LhsScalar& a,
const RhsScalar& b, ResScalar& c, RhsScalar& ,
const false_type&)
const
// acc: r += alpha * conj?(c); conjugation of the lhs is handled by cj.
307 EIGEN_STRONG_INLINE void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const
309 r = cj.pmadd(c,alpha,r);
// Conjugation helper honouring ConjLhs (rhs is real, never conjugated).
313 conj_helper<ResPacket,ResPacket,ConjLhs,false> cj;
// Specialization: complex x complex.  Real and imaginary parts are kept in
// separate real packets (DoublePacket {first,second}) so products can be
// formed without in-register shuffles; conjugation is resolved in acc().
316 template<
typename RealScalar,
bool _ConjLhs,
bool _ConjRhs>
317 class gebp_traits<std::complex<RealScalar>, std::complex<RealScalar>, _ConjLhs, _ConjRhs >
320 typedef std::complex<RealScalar> Scalar;
321 typedef std::complex<RealScalar> LhsScalar;
322 typedef std::complex<RealScalar> RhsScalar;
323 typedef std::complex<RealScalar> ResScalar;
// Requires both the real packets and the complex packets to be vectorizable.
328 Vectorizable = packet_traits<RealScalar>::Vectorizable
329 && packet_traits<Scalar>::Vectorizable,
330 RealPacketSize = Vectorizable ? packet_traits<RealScalar>::size : 1,
331 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
334 mr = 2 * ResPacketSize,
// Twice the space when vectorized: separate real and imaginary panels.
335 WorkSpaceFactor = Vectorizable ? 2*nr*RealPacketSize : nr,
337 LhsProgress = ResPacketSize,
338 RhsProgress = Vectorizable ? 2*ResPacketSize : 1
341 typedef typename packet_traits<RealScalar>::type RealPacket;
342 typedef typename packet_traits<Scalar>::type ScalarPacket;
349 typedef typename conditional<Vectorizable,RealPacket, Scalar>::type LhsPacket;
350 typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type RhsPacket;
351 typedef typename conditional<Vectorizable,ScalarPacket,Scalar>::type ResPacket;
352 typedef typename conditional<Vectorizable,DoublePacket,Scalar>::type AccPacket;
// initAcc: zero both halves of the split accumulator.
358 p.first = pset1<RealPacket>(RealScalar(0));
359 p.second = pset1<RealPacket>(RealScalar(0));
// unpackRhs: broadcast real parts then imaginary parts into adjacent panels.
372 pstore1<RealPacket>((RealScalar*)&b[k*ResPacketSize*2+0],
real(rhs[k]));
373 pstore1<RealPacket>((RealScalar*)&b[k*ResPacketSize*2+ResPacketSize],
imag(rhs[k]));
// loadRhs: load the (real, imag) packet pair.
384 dest.first = pload<RealPacket>((
const RealScalar*)b);
385 dest.second = pload<RealPacket>((
const RealScalar*)(b+ResPacketSize));
391 dest = pload<LhsPacket>((
const typename unpacket_traits<LhsPacket>::type*)(a));
// madd (vectorized): accumulate a*real(b) and a*imag(b) separately;
// no conjugation here -- it is applied once, in acc().
394 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, DoublePacket& c, RhsPacket& )
const
396 c.first =
padd(
pmul(a,b.first), c.first);
397 c.second =
padd(
pmul(a,b.second),c.second);
// madd (scalar overload; body not visible in this fragment).
400 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, ResPacket& c, RhsPacket& )
const
// acc, scalar version: r += alpha * c.
405 EIGEN_STRONG_INLINE void acc(
const Scalar& c,
const Scalar& alpha, Scalar& r)
const
405b // (see line 405 above)
407 EIGEN_STRONG_INLINE void acc(
const DoublePacket& c,
const ResPacket& alpha, ResPacket& r)
const
// Recombine the split accumulator into a complex packet, applying the
// four conjugation combinations; then r += alpha * tmp.
411 if((!ConjLhs)&&(!ConjRhs))
414 tmp =
padd(ResPacket(c.first),tmp);
416 else if((!ConjLhs)&&(ConjRhs))
419 tmp =
padd(ResPacket(c.first),tmp);
421 else if((ConjLhs)&&(!ConjRhs))
424 tmp =
padd(
pconj(ResPacket(c.first)),tmp);
426 else if((ConjLhs)&&(ConjRhs))
429 tmp =
psub(
pconj(ResPacket(c.first)),tmp);
432 r =
pmadd(tmp,alpha,r);
436 conj_helper<LhsScalar,RhsScalar,ConjLhs,ConjRhs> cj;
// Specialization: real lhs x complex rhs.  The real lhs is duplicated
// (ploaddup) so each real coefficient lines up with a full complex rhs entry.
439 template<
typename RealScalar,
bool _ConjRhs>
440 class gebp_traits<RealScalar, std::complex<RealScalar>, false, _ConjRhs >
443 typedef std::complex<RealScalar> Scalar;
444 typedef RealScalar LhsScalar;
445 typedef Scalar RhsScalar;
446 typedef Scalar ResScalar;
451 Vectorizable = packet_traits<RealScalar>::Vectorizable
452 && packet_traits<Scalar>::Vectorizable,
453 LhsPacketSize = Vectorizable ? packet_traits<LhsScalar>::size : 1,
454 RhsPacketSize = Vectorizable ? packet_traits<RhsScalar>::size : 1,
455 ResPacketSize = Vectorizable ? packet_traits<ResScalar>::size : 1,
459 mr = 2*ResPacketSize,
460 WorkSpaceFactor = nr*RhsPacketSize,
462 LhsProgress = ResPacketSize,
463 RhsProgress = ResPacketSize
466 typedef typename packet_traits<LhsScalar>::type _LhsPacket;
467 typedef typename packet_traits<RhsScalar>::type _RhsPacket;
468 typedef typename packet_traits<ResScalar>::type _ResPacket;
470 typedef typename conditional<Vectorizable,_LhsPacket,LhsScalar>::type LhsPacket;
471 typedef typename conditional<Vectorizable,_RhsPacket,RhsScalar>::type RhsPacket;
472 typedef typename conditional<Vectorizable,_ResPacket,ResScalar>::type ResPacket;
474 typedef ResPacket AccPacket;
478 p = pset1<ResPacket>(ResScalar(0));
484 pstore1<RhsPacket>(&b[k*RhsPacketSize], rhs[k]);
489 dest = pload<RhsPacket>(b);
// loadLhs duplicates each real lhs coefficient to match complex lanes.
494 dest = ploaddup<LhsPacket>(a);
// madd dispatches on Vectorizable via tag dispatch.
497 EIGEN_STRONG_INLINE void madd(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp)
const
499 madd_impl(a, b, c, tmp,
typename conditional<Vectorizable,true_type,false_type>::type());
// Vectorized path: multiply the real lhs against the raw complex rhs (.v).
502 EIGEN_STRONG_INLINE void madd_impl(
const LhsPacket& a,
const RhsPacket& b, AccPacket& c, RhsPacket& tmp,
const true_type&)
const
504 tmp = b; tmp.v =
pmul(a,tmp.v); c =
padd(c,tmp);
// Scalar fallback path (body not visible in this fragment).
507 EIGEN_STRONG_INLINE void madd_impl(
const LhsScalar& a,
const RhsScalar& b, ResScalar& c, RhsScalar& ,
const false_type&)
const
// acc: r += alpha * conj?(c); note the swapped operand order vs. the
// complex-lhs specialization so ConjRhs applies to the right factor.
512 EIGEN_STRONG_INLINE void acc(
const AccPacket& c,
const ResPacket& alpha, ResPacket& r)
const
514 r = cj.pmadd(alpha,c,r);
518 conj_helper<ResPacket,ResPacket,false,ConjRhs> cj;
// gebp_kernel: the general block-panel product kernel.  Computes
//   res += alpha * blockA * blockB
// where blockA is an mr-packed lhs block and blockB an nr-packed rhs panel.
// NOTE(review): this is a fragmentary extraction -- the embedded original
// line numbers jump, so local declarations (e.g. LhsPacket A0, A1;
// RhsPacket B_0; RhsPacket T0;), several if(nr==...) guards and braces are
// not visible here.
528 template<
typename LhsScalar,
typename RhsScalar,
typename Index,
int mr,
int nr,
bool ConjugateLhs,
bool ConjugateRhs>
531 typedef gebp_traits<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> Traits;
532 typedef typename Traits::ResScalar ResScalar;
533 typedef typename Traits::LhsPacket LhsPacket;
534 typedef typename Traits::RhsPacket RhsPacket;
535 typedef typename Traits::ResPacket ResPacket;
536 typedef typename Traits::AccPacket AccPacket;
539 Vectorizable = Traits::Vectorizable,
540 LhsProgress = Traits::LhsProgress,
541 RhsProgress = Traits::RhsProgress,
542 ResPacketSize = Traits::ResPacketSize
// res is column-major with leading dimension resStride; strides/offsets
// default to a dense depth-major packing when left at -1/0.
546 void operator()(ResScalar* res, Index resStride,
const LhsScalar* blockA,
const RhsScalar* blockB, Index rows, Index depth, Index cols, ResScalar alpha,
547 Index strideA=-1, Index strideB=-1, Index offsetA=0, Index offsetB=0, RhsScalar* unpackedB = 0)
551 if(strideA==-1) strideA = depth;
552 if(strideB==-1) strideB = depth;
553 conj_helper<LhsScalar,RhsScalar,ConjugateLhs,ConjugateRhs> cj;
// packet_cols: columns handled nr at a time; peeled_mc: rows handled mr at
// a time; peeled_mc2 adds one LhsProgress-wide strip; peeled_kc: depth
// unrolled 4 k-steps at a time.
555 Index packet_cols = (cols/nr) * nr;
556 const Index peeled_mc = (rows/mr)*mr;
558 const Index peeled_mc2 = peeled_mc + (rows-peeled_mc >= LhsProgress ? LhsProgress : 0);
559 const Index peeled_kc = (depth/4)*4;
// Default scratch for the unpacked rhs when the caller did not supply one.
562 unpackedB =
const_cast<RhsScalar*
>(blockB - strideB * nr * RhsProgress);
// ---- Main loop over nr-wide column panels ----
565 for(Index j2=0; j2<packet_cols; j2+=nr)
567 traits.unpackRhs(depth*nr,&blockB[j2*strideB+offsetB*nr],unpackedB);
// ---- mr-rows-at-a-time micro-kernel ----
572 for(Index i=0; i<peeled_mc; i+=mr)
574 const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
// C0..C3 accumulate the low lhs packet, C4..C7 the high one;
// C2/C3/C6/C7 are only used when nr==4.
578 AccPacket C0, C1, C2, C3, C4, C5, C6, C7;
581 if(nr==4) traits.initAcc(C2);
582 if(nr==4) traits.initAcc(C3);
585 if(nr==4) traits.initAcc(C6);
586 if(nr==4) traits.initAcc(C7);
588 ResScalar* r0 = &res[(j2+0)*resStride + i];
589 ResScalar* r1 = r0 + resStride;
590 ResScalar* r2 = r1 + resStride;
591 ResScalar* r3 = r2 + resStride;
601 const RhsScalar* blB = unpackedB;
// Peeled depth loop: four fully unrolled k-steps per iteration.
602 for(Index k=0; k<peeled_kc; k+=4)
// First variant: presumably the nr==2 branch (two rhs columns per step).
611 traits.loadLhs(&blA[0*LhsProgress], A0);
612 traits.loadLhs(&blA[1*LhsProgress], A1);
613 traits.loadRhs(&blB[0*RhsProgress], B_0);
614 traits.madd(A0,B_0,C0,T0);
615 traits.madd(A1,B_0,C4,B_0);
616 traits.loadRhs(&blB[1*RhsProgress], B_0);
617 traits.madd(A0,B_0,C1,T0);
618 traits.madd(A1,B_0,C5,B_0);
620 traits.loadLhs(&blA[2*LhsProgress], A0);
621 traits.loadLhs(&blA[3*LhsProgress], A1);
622 traits.loadRhs(&blB[2*RhsProgress], B_0);
623 traits.madd(A0,B_0,C0,T0);
624 traits.madd(A1,B_0,C4,B_0);
625 traits.loadRhs(&blB[3*RhsProgress], B_0);
626 traits.madd(A0,B_0,C1,T0);
627 traits.madd(A1,B_0,C5,B_0);
629 traits.loadLhs(&blA[4*LhsProgress], A0);
630 traits.loadLhs(&blA[5*LhsProgress], A1);
631 traits.loadRhs(&blB[4*RhsProgress], B_0);
632 traits.madd(A0,B_0,C0,T0);
633 traits.madd(A1,B_0,C4,B_0);
634 traits.loadRhs(&blB[5*RhsProgress], B_0);
635 traits.madd(A0,B_0,C1,T0);
636 traits.madd(A1,B_0,C5,B_0);
638 traits.loadLhs(&blA[6*LhsProgress], A0);
639 traits.loadLhs(&blA[7*LhsProgress], A1);
640 traits.loadRhs(&blB[6*RhsProgress], B_0);
641 traits.madd(A0,B_0,C0,T0);
642 traits.madd(A1,B_0,C4,B_0);
643 traits.loadRhs(&blB[7*RhsProgress], B_0);
644 traits.madd(A0,B_0,C1,T0);
645 traits.madd(A1,B_0,C5,B_0);
// Second variant: presumably the nr==4 branch -- loads/madds are manually
// interleaved (software pipelined) to hide load latency.
652 RhsPacket B_0, B1, B2, B3;
655 traits.loadLhs(&blA[0*LhsProgress], A0);
656 traits.loadLhs(&blA[1*LhsProgress], A1);
657 traits.loadRhs(&blB[0*RhsProgress], B_0);
658 traits.loadRhs(&blB[1*RhsProgress], B1);
660 traits.madd(A0,B_0,C0,T0);
661 traits.loadRhs(&blB[2*RhsProgress], B2);
662 traits.madd(A1,B_0,C4,B_0);
663 traits.loadRhs(&blB[3*RhsProgress], B3);
664 traits.loadRhs(&blB[4*RhsProgress], B_0);
665 traits.madd(A0,B1,C1,T0);
666 traits.madd(A1,B1,C5,B1);
667 traits.loadRhs(&blB[5*RhsProgress], B1);
668 traits.madd(A0,B2,C2,T0);
669 traits.madd(A1,B2,C6,B2);
670 traits.loadRhs(&blB[6*RhsProgress], B2);
671 traits.madd(A0,B3,C3,T0);
672 traits.loadLhs(&blA[2*LhsProgress], A0);
673 traits.madd(A1,B3,C7,B3);
674 traits.loadLhs(&blA[3*LhsProgress], A1);
675 traits.loadRhs(&blB[7*RhsProgress], B3);
676 traits.madd(A0,B_0,C0,T0);
677 traits.madd(A1,B_0,C4,B_0);
678 traits.loadRhs(&blB[8*RhsProgress], B_0);
679 traits.madd(A0,B1,C1,T0);
680 traits.madd(A1,B1,C5,B1);
681 traits.loadRhs(&blB[9*RhsProgress], B1);
682 traits.madd(A0,B2,C2,T0);
683 traits.madd(A1,B2,C6,B2);
684 traits.loadRhs(&blB[10*RhsProgress], B2);
685 traits.madd(A0,B3,C3,T0);
686 traits.loadLhs(&blA[4*LhsProgress], A0);
687 traits.madd(A1,B3,C7,B3);
688 traits.loadLhs(&blA[5*LhsProgress], A1);
689 traits.loadRhs(&blB[11*RhsProgress], B3);
691 traits.madd(A0,B_0,C0,T0);
692 traits.madd(A1,B_0,C4,B_0);
693 traits.loadRhs(&blB[12*RhsProgress], B_0);
694 traits.madd(A0,B1,C1,T0);
695 traits.madd(A1,B1,C5,B1);
696 traits.loadRhs(&blB[13*RhsProgress], B1);
697 traits.madd(A0,B2,C2,T0);
698 traits.madd(A1,B2,C6,B2);
699 traits.loadRhs(&blB[14*RhsProgress], B2);
700 traits.madd(A0,B3,C3,T0);
701 traits.loadLhs(&blA[6*LhsProgress], A0);
702 traits.madd(A1,B3,C7,B3);
703 traits.loadLhs(&blA[7*LhsProgress], A1);
704 traits.loadRhs(&blB[15*RhsProgress], B3);
705 traits.madd(A0,B_0,C0,T0);
706 traits.madd(A1,B_0,C4,B_0);
707 traits.madd(A0,B1,C1,T0);
708 traits.madd(A1,B1,C5,B1);
709 traits.madd(A0,B2,C2,T0);
710 traits.madd(A1,B2,C6,B2);
711 traits.madd(A0,B3,C3,T0);
712 traits.madd(A1,B3,C7,B3);
715 blB += 4*nr*RhsProgress;
// Remaining depth iterations, one k-step at a time.
719 for(Index k=peeled_kc; k<depth; k++)
727 traits.loadLhs(&blA[0*LhsProgress], A0);
728 traits.loadLhs(&blA[1*LhsProgress], A1);
729 traits.loadRhs(&blB[0*RhsProgress], B_0);
730 traits.madd(A0,B_0,C0,T0);
731 traits.madd(A1,B_0,C4,B_0);
732 traits.loadRhs(&blB[1*RhsProgress], B_0);
733 traits.madd(A0,B_0,C1,T0);
734 traits.madd(A1,B_0,C5,B_0);
739 RhsPacket B_0, B1, B2, B3;
742 traits.loadLhs(&blA[0*LhsProgress], A0);
743 traits.loadLhs(&blA[1*LhsProgress], A1);
744 traits.loadRhs(&blB[0*RhsProgress], B_0);
745 traits.loadRhs(&blB[1*RhsProgress], B1);
747 traits.madd(A0,B_0,C0,T0);
748 traits.loadRhs(&blB[2*RhsProgress], B2);
749 traits.madd(A1,B_0,C4,B_0);
750 traits.loadRhs(&blB[3*RhsProgress], B3);
751 traits.madd(A0,B1,C1,T0);
752 traits.madd(A1,B1,C5,B1);
753 traits.madd(A0,B2,C2,T0);
754 traits.madd(A1,B2,C6,B2);
755 traits.madd(A0,B3,C3,T0);
756 traits.madd(A1,B3,C7,B3);
759 blB += nr*RhsProgress;
// Write-back (nr==4 path): load destination, r += alpha*C, store.
// R0 is reused for r3's high packet once C0 has been accumulated.
765 ResPacket R0, R1, R2, R3, R4, R5, R6;
766 ResPacket alphav = pset1<ResPacket>(alpha);
768 R0 = ploadu<ResPacket>(r0);
769 R1 = ploadu<ResPacket>(r1);
770 R2 = ploadu<ResPacket>(r2);
771 R3 = ploadu<ResPacket>(r3);
772 R4 = ploadu<ResPacket>(r0 + ResPacketSize);
773 R5 = ploadu<ResPacket>(r1 + ResPacketSize);
774 R6 = ploadu<ResPacket>(r2 + ResPacketSize);
775 traits.acc(C0, alphav, R0);
777 R0 = ploadu<ResPacket>(r3 + ResPacketSize);
779 traits.acc(C1, alphav, R1);
780 traits.acc(C2, alphav, R2);
781 traits.acc(C3, alphav, R3);
782 traits.acc(C4, alphav, R4);
783 traits.acc(C5, alphav, R5);
784 traits.acc(C6, alphav, R6);
785 traits.acc(C7, alphav, R0);
790 pstoreu(r0 + ResPacketSize, R4);
791 pstoreu(r1 + ResPacketSize, R5);
792 pstoreu(r2 + ResPacketSize, R6);
793 pstoreu(r3 + ResPacketSize, R0);
// Write-back (nr==2 path).
797 ResPacket R0, R1, R4;
798 ResPacket alphav = pset1<ResPacket>(alpha);
800 R0 = ploadu<ResPacket>(r0);
801 R1 = ploadu<ResPacket>(r1);
802 R4 = ploadu<ResPacket>(r0 + ResPacketSize);
803 traits.acc(C0, alphav, R0);
805 R0 = ploadu<ResPacket>(r1 + ResPacketSize);
806 traits.acc(C1, alphav, R1);
807 traits.acc(C4, alphav, R4);
808 traits.acc(C5, alphav, R0);
810 pstoreu(r0 + ResPacketSize, R4);
811 pstoreu(r1 + ResPacketSize, R0);
// ---- One extra LhsProgress-wide row strip (single lhs packet) ----
816 if(rows-peeled_mc>=LhsProgress)
819 const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
823 AccPacket C0, C1, C2, C3;
826 if(nr==4) traits.initAcc(C2);
827 if(nr==4) traits.initAcc(C3);
830 const RhsScalar* blB = unpackedB;
831 for(Index k=0; k<peeled_kc; k+=4)
// nr==2 variant of the 4x-unrolled depth loop.
838 traits.loadLhs(&blA[0*LhsProgress], A0);
839 traits.loadRhs(&blB[0*RhsProgress], B_0);
840 traits.loadRhs(&blB[1*RhsProgress], B1);
841 traits.madd(A0,B_0,C0,B_0);
842 traits.loadRhs(&blB[2*RhsProgress], B_0);
843 traits.madd(A0,B1,C1,B1);
844 traits.loadLhs(&blA[1*LhsProgress], A0);
845 traits.loadRhs(&blB[3*RhsProgress], B1);
846 traits.madd(A0,B_0,C0,B_0);
847 traits.loadRhs(&blB[4*RhsProgress], B_0);
848 traits.madd(A0,B1,C1,B1);
849 traits.loadLhs(&blA[2*LhsProgress], A0);
850 traits.loadRhs(&blB[5*RhsProgress], B1);
851 traits.madd(A0,B_0,C0,B_0);
852 traits.loadRhs(&blB[6*RhsProgress], B_0);
853 traits.madd(A0,B1,C1,B1);
854 traits.loadLhs(&blA[3*LhsProgress], A0);
855 traits.loadRhs(&blB[7*RhsProgress], B1);
856 traits.madd(A0,B_0,C0,B_0);
857 traits.madd(A0,B1,C1,B1);
// nr==4 variant, software pipelined as above.
862 RhsPacket B_0, B1, B2, B3;
864 traits.loadLhs(&blA[0*LhsProgress], A0);
865 traits.loadRhs(&blB[0*RhsProgress], B_0);
866 traits.loadRhs(&blB[1*RhsProgress], B1);
868 traits.madd(A0,B_0,C0,B_0);
869 traits.loadRhs(&blB[2*RhsProgress], B2);
870 traits.loadRhs(&blB[3*RhsProgress], B3);
871 traits.loadRhs(&blB[4*RhsProgress], B_0);
872 traits.madd(A0,B1,C1,B1);
873 traits.loadRhs(&blB[5*RhsProgress], B1);
874 traits.madd(A0,B2,C2,B2);
875 traits.loadRhs(&blB[6*RhsProgress], B2);
876 traits.madd(A0,B3,C3,B3);
877 traits.loadLhs(&blA[1*LhsProgress], A0);
878 traits.loadRhs(&blB[7*RhsProgress], B3);
879 traits.madd(A0,B_0,C0,B_0);
880 traits.loadRhs(&blB[8*RhsProgress], B_0);
881 traits.madd(A0,B1,C1,B1);
882 traits.loadRhs(&blB[9*RhsProgress], B1);
883 traits.madd(A0,B2,C2,B2);
884 traits.loadRhs(&blB[10*RhsProgress], B2);
885 traits.madd(A0,B3,C3,B3);
886 traits.loadLhs(&blA[2*LhsProgress], A0);
887 traits.loadRhs(&blB[11*RhsProgress], B3);
889 traits.madd(A0,B_0,C0,B_0);
890 traits.loadRhs(&blB[12*RhsProgress], B_0);
891 traits.madd(A0,B1,C1,B1);
892 traits.loadRhs(&blB[13*RhsProgress], B1);
893 traits.madd(A0,B2,C2,B2);
894 traits.loadRhs(&blB[14*RhsProgress], B2);
895 traits.madd(A0,B3,C3,B3);
897 traits.loadLhs(&blA[3*LhsProgress], A0);
898 traits.loadRhs(&blB[15*RhsProgress], B3);
899 traits.madd(A0,B_0,C0,B_0);
900 traits.madd(A0,B1,C1,B1);
901 traits.madd(A0,B2,C2,B2);
902 traits.madd(A0,B3,C3,B3);
905 blB += nr*4*RhsProgress;
906 blA += 4*LhsProgress;
// Depth remainder for the single-packet strip.
909 for(Index k=peeled_kc; k<depth; k++)
916 traits.loadLhs(&blA[0*LhsProgress], A0);
917 traits.loadRhs(&blB[0*RhsProgress], B_0);
918 traits.loadRhs(&blB[1*RhsProgress], B1);
919 traits.madd(A0,B_0,C0,B_0);
920 traits.madd(A0,B1,C1,B1);
925 RhsPacket B_0, B1, B2, B3;
927 traits.loadLhs(&blA[0*LhsProgress], A0);
928 traits.loadRhs(&blB[0*RhsProgress], B_0);
929 traits.loadRhs(&blB[1*RhsProgress], B1);
930 traits.loadRhs(&blB[2*RhsProgress], B2);
931 traits.loadRhs(&blB[3*RhsProgress], B3);
933 traits.madd(A0,B_0,C0,B_0);
934 traits.madd(A0,B1,C1,B1);
935 traits.madd(A0,B2,C2,B2);
936 traits.madd(A0,B3,C3,B3);
939 blB += nr*RhsProgress;
// Write-back for the single-packet strip.
943 ResPacket R0, R1, R2, R3;
944 ResPacket alphav = pset1<ResPacket>(alpha);
946 ResScalar* r0 = &res[(j2+0)*resStride + i];
947 ResScalar* r1 = r0 + resStride;
948 ResScalar* r2 = r1 + resStride;
949 ResScalar* r3 = r2 + resStride;
951 R0 = ploadu<ResPacket>(r0);
952 R1 = ploadu<ResPacket>(r1);
953 if(nr==4) R2 = ploadu<ResPacket>(r2);
954 if(nr==4) R3 = ploadu<ResPacket>(r3);
956 traits.acc(C0, alphav, R0);
957 traits.acc(C1, alphav, R1);
958 if(nr==4) traits.acc(C2, alphav, R2);
959 if(nr==4) traits.acc(C3, alphav, R3);
// ---- Remaining rows: plain scalar loop using the MADD macro ----
966 for(Index i=peeled_mc2; i<rows; i++)
968 const LhsScalar* blA = &blockA[i*strideA+offsetA];
972 ResScalar C0(0), C1(0), C2(0), C3(0);
974 const RhsScalar* blB = &blockB[j2*strideB+offsetB*nr];
975 for(Index k=0; k<depth; k++)
985 MADD(cj,A0,B_0,C0,B_0);
986 MADD(cj,A0,B1,C1,B1);
991 RhsScalar B_0, B1, B2, B3;
999 MADD(cj,A0,B_0,C0,B_0);
1000 MADD(cj,A0,B1,C1,B1);
1001 MADD(cj,A0,B2,C2,B2);
1002 MADD(cj,A0,B3,C3,B3);
1007 res[(j2+0)*resStride + i] += alpha*C0;
1008 res[(j2+1)*resStride + i] += alpha*C1;
1009 if(nr==4) res[(j2+2)*resStride + i] += alpha*C2;
1010 if(nr==4) res[(j2+3)*resStride + i] += alpha*C3;
// ---- Remaining columns (cols not a multiple of nr), one at a time ----
1015 for(Index j2=packet_cols; j2<cols; j2++)
1018 traits.unpackRhs(depth, &blockB[j2*strideB+offsetB], unpackedB);
// mr-rows strip against a single rhs column.
1020 for(Index i=0; i<peeled_mc; i+=mr)
1022 const LhsScalar* blA = &blockA[i*strideA+offsetA*mr];
1032 const RhsScalar* blB = unpackedB;
1033 for(Index k=0; k<depth; k++)
1039 traits.loadLhs(&blA[0*LhsProgress], A0);
1040 traits.loadLhs(&blA[1*LhsProgress], A1);
1041 traits.loadRhs(&blB[0*RhsProgress], B_0);
1042 traits.madd(A0,B_0,C0,T0);
1043 traits.madd(A1,B_0,C4,B_0);
1046 blA += 2*LhsProgress;
1049 ResPacket alphav = pset1<ResPacket>(alpha);
1051 ResScalar* r0 = &res[(j2+0)*resStride + i];
1053 R0 = ploadu<ResPacket>(r0);
1054 R4 = ploadu<ResPacket>(r0+ResPacketSize);
1056 traits.acc(C0, alphav, R0);
1057 traits.acc(C4, alphav, R4);
1060 pstoreu(r0+ResPacketSize, R4);
// Single LhsProgress-wide strip against one rhs column.
1062 if(rows-peeled_mc>=LhsProgress)
1064 Index i = peeled_mc;
1065 const LhsScalar* blA = &blockA[i*strideA+offsetA*LhsProgress];
1071 const RhsScalar* blB = unpackedB;
1072 for(Index k=0; k<depth; k++)
1076 traits.loadLhs(blA, A0);
1077 traits.loadRhs(blB, B_0);
1078 traits.madd(A0, B_0, C0, B_0);
1083 ResPacket alphav = pset1<ResPacket>(alpha);
1084 ResPacket R0 = ploadu<ResPacket>(&res[(j2+0)*resStride + i]);
1085 traits.acc(C0, alphav, R0);
1086 pstoreu(&res[(j2+0)*resStride + i], R0);
// Scalar tail: remaining rows x remaining column.
1088 for(Index i=peeled_mc2; i<rows; i++)
1090 const LhsScalar* blA = &blockA[i*strideA+offsetA];
1096 const RhsScalar* blB = &blockB[j2*strideB+offsetB];
1097 for(Index k=0; k<depth; k++)
1099 LhsScalar A0 = blA[k];
1100 RhsScalar B_0 = blB[k];
1101 MADD(cj, A0, B_0, C0, B_0);
1103 res[(j2+0)*resStride + i] += alpha*C0;
// gemm_pack_lhs: copy an lhs block into blockA in the layout the gebp kernel
// expects -- Pack1-row strips first, then a Pack2 strip, then single rows --
// applying conjugation (cj) on the fly.  PanelMode reserves stride/offset
// space around each panel for triangular/banded callers.
1125 template<
typename Scalar,
typename Index,
int Pack1,
int Pack2,
int StorageOrder,
bool Conjugate,
bool PanelMode>
1126 struct gemm_pack_lhs
1129 Index stride=0, Index offset=0)
1131 typedef typename packet_traits<Scalar>::type Packet;
1132 enum { PacketSize = packet_traits<Scalar>::size };
// stride/offset are only meaningful in PanelMode.
1135 eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride));
1138 const_blas_data_mapper<Scalar, Index, StorageOrder> lhs(_lhs,lhsStride);
// ---- Pack1-row strips ----
1140 Index peeled_mc = (rows/Pack1)*Pack1;
1141 for(Index i=0; i<peeled_mc; i+=Pack1)
1143 if(PanelMode) count += Pack1 * offset;
// Vectorized copy path: up to four packets per column, depending on Pack1.
1147 for(Index k=0; k<depth; k++)
1150 if(Pack1>=1*PacketSize) A = ploadu<Packet>(&lhs(i+0*PacketSize, k));
1151 if(Pack1>=2*PacketSize) B = ploadu<Packet>(&lhs(i+1*PacketSize, k));
1152 if(Pack1>=3*PacketSize) C = ploadu<Packet>(&lhs(i+2*PacketSize, k));
1153 if(Pack1>=4*PacketSize) D = ploadu<Packet>(&lhs(i+3*PacketSize, k));
1154 if(Pack1>=1*PacketSize) {
pstore(blockA+count, cj.pconj(A)); count+=PacketSize; }
1155 if(Pack1>=2*PacketSize) {
pstore(blockA+count, cj.pconj(B)); count+=PacketSize; }
1156 if(Pack1>=3*PacketSize) {
pstore(blockA+count, cj.pconj(C)); count+=PacketSize; }
1157 if(Pack1>=4*PacketSize) {
pstore(blockA+count, cj.pconj(D)); count+=PacketSize; }
// Scalar copy path, unrolled four rows at a time.
1162 for(Index k=0; k<depth; k++)
1166 for(; w<Pack1-3; w+=4)
1168 Scalar a(cj(lhs(i+w+0, k))),
1169 b(cj(lhs(i+w+1, k))),
1170 c(cj(lhs(i+w+2, k))),
1171 d(cj(lhs(i+w+3, k)));
1172 blockA[count++] = a;
1173 blockA[count++] = b;
1174 blockA[count++] = c;
1175 blockA[count++] = d;
1179 blockA[count++] = cj(lhs(i+w, k));
1182 if(PanelMode) count += Pack1 * (stride-offset-depth);
// ---- One Pack2-row strip, if enough rows remain ----
1184 if(rows-peeled_mc>=Pack2)
1186 if(PanelMode) count += Pack2*offset;
1187 for(Index k=0; k<depth; k++)
1188 for(Index w=0; w<Pack2; w++)
1189 blockA[count++] = cj(lhs(peeled_mc+w, k));
1190 if(PanelMode) count += Pack2 * (stride-offset-depth);
// ---- Remaining rows, one at a time ----
1193 for(Index i=peeled_mc; i<rows; i++)
1195 if(PanelMode) count += offset;
1196 for(Index k=0; k<depth; k++)
1197 blockA[count++] = cj(lhs(i, k));
1198 if(PanelMode) count += (stride-offset-depth);
// gemm_pack_rhs (column-major rhs): interleave nr columns so that the kernel
// reads nr coefficients of the same depth index contiguously; conjugates via
// cj.  Leftover columns are copied one at a time.
1210 template<
typename Scalar,
typename Index,
int nr,
bool Conjugate,
bool PanelMode>
1211 struct gemm_pack_rhs<Scalar, Index, nr,
ColMajor, Conjugate, PanelMode>
1213 typedef typename packet_traits<Scalar>::type Packet;
1214 enum { PacketSize = packet_traits<Scalar>::size };
1215 EIGEN_DONT_INLINE void operator()(Scalar* blockB,
const Scalar* rhs, Index rhsStride, Index depth, Index cols,
1216 Index stride=0, Index offset=0)
// stride/offset are only meaningful in PanelMode.
1219 eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride))
1221 Index packet_cols = (cols/nr) * nr;
// nr columns at a time: b0..b3 point at the column bases (b2/b3 only read
// when nr==4; the compile-time guard keeps nr==2 in bounds).
1223 for(Index j2=0; j2<packet_cols; j2+=nr)
1226 if(PanelMode) count += nr * offset;
1227 const Scalar* b0 = &rhs[(j2+0)*rhsStride];
1228 const Scalar* b1 = &rhs[(j2+1)*rhsStride];
1229 const Scalar* b2 = &rhs[(j2+2)*rhsStride];
1230 const Scalar* b3 = &rhs[(j2+3)*rhsStride];
1231 for(Index k=0; k<depth; k++)
1233 blockB[count+0] = cj(b0[k]);
1234 blockB[count+1] = cj(b1[k]);
1235 if(nr==4) blockB[count+2] = cj(b2[k]);
1236 if(nr==4) blockB[count+3] = cj(b3[k]);
1240 if(PanelMode) count += nr * (stride-offset-depth);
// Remaining columns, one at a time.
1244 for(Index j2=packet_cols; j2<cols; ++j2)
1246 if(PanelMode) count += offset;
1247 const Scalar* b0 = &rhs[(j2+0)*rhsStride];
1248 for(Index k=0; k<depth; k++)
1250 blockB[count] = cj(b0[k]);
1253 if(PanelMode) count += (stride-offset-depth);
// gemm_pack_rhs (row-major rhs): same packed layout as the column-major
// variant, but nr consecutive coefficients of a row are already adjacent, so
// each depth step copies a small contiguous run.
1259 template<
typename Scalar,
typename Index,
int nr,
bool Conjugate,
bool PanelMode>
1260 struct gemm_pack_rhs<Scalar, Index, nr,
RowMajor, Conjugate, PanelMode>
1262 enum { PacketSize = packet_traits<Scalar>::size };
1263 EIGEN_DONT_INLINE void operator()(Scalar* blockB,
const Scalar* rhs, Index rhsStride, Index depth, Index cols,
1264 Index stride=0, Index offset=0)
// stride/offset are only meaningful in PanelMode.
1267 eigen_assert(((!PanelMode) && stride==0 && offset==0) || (PanelMode && stride>=depth && offset<=stride))
1269 Index packet_cols = (cols/nr) * nr;
// nr columns at a time; b0[2]/b0[3] are only read when nr==4.
1271 for(Index j2=0; j2<packet_cols; j2+=nr)
1274 if(PanelMode) count += nr * offset;
1275 for(Index k=0; k<depth; k++)
1277 const Scalar* b0 = &rhs[k*rhsStride + j2];
1278 blockB[count+0] = cj(b0[0]);
1279 blockB[count+1] = cj(b0[1]);
1280 if(nr==4) blockB[count+2] = cj(b0[2]);
1281 if(nr==4) blockB[count+3] = cj(b0[3]);
1285 if(PanelMode) count += nr * (stride-offset-depth);
// Remaining columns, one at a time (strided reads down the column).
1288 for(Index j2=packet_cols; j2<cols; ++j2)
1290 if(PanelMode) count += offset;
1291 const Scalar* b0 = &rhs[j2];
1292 for(Index k=0; k<depth; k++)
1294 blockB[count] = cj(b0[k*rhsStride]);
1297 if(PanelMode) count += stride-offset-depth;
1308 std::ptrdiff_t l1, l2;
1317 std::ptrdiff_t l1, l2;
1334 #endif // EIGEN_GENERAL_BLOCK_PANEL_H