diff --git a/Eigen/src/Core/arch/NEON/Kernels.h b/Eigen/src/Core/arch/NEON/Kernels.h index 4411389e5..6f9273274 100644 --- a/Eigen/src/Core/arch/NEON/Kernels.h +++ b/Eigen/src/Core/arch/NEON/Kernels.h @@ -130,14 +130,15 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 12, 1> _acc.packet[2] *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { PacketBlock block; - block.packet[0] = dest.template loadPacket(row + 0, col) + _acc.packet[0]; + block.packet[0] = dest.template loadPacket(row + 0, col) + pAlpha*_acc.packet[0]; dest.template storePacketBlock(row + 0, col, block); - block.packet[0] = dest.template loadPacket(row + 4, col) + _acc.packet[1]; + block.packet[0] = dest.template loadPacket(row + 4, col) + pAlpha*_acc.packet[1]; dest.template storePacketBlock(row + 4, col, block); - block.packet[0] = dest.template loadPacket(row + 8, col) + _acc.packet[2]; + block.packet[0] = dest.template loadPacket(row + 8, col) + pAlpha*_acc.packet[2]; dest.template storePacketBlock(row + 8, col, block); } }; @@ -166,12 +167,13 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 1> _acc.packet[1] *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { PacketBlock block; - block.packet[0] = dest.template loadPacket(row, col) + _acc.packet[0]; + block.packet[0] = dest.template loadPacket(row, col) + pAlpha*_acc.packet[0]; dest.template storePacketBlock(row, col, block); - block.packet[0] = dest.template loadPacket(row + 4, col) + _acc.packet[1]; + block.packet[0] = dest.template loadPacket(row + 4, col) + pAlpha*_acc.packet[1]; dest.template storePacketBlock(row + 4, col, block); } }; @@ -198,10 +200,11 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 1> _acc *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { PacketBlock block; - block.packet[0] = dest.template loadPacket(row, col) + _acc; + block.packet[0] = dest.template loadPacket(row, col) + pAlpha*_acc; dest.template storePacketBlock(row, col, block); } }; @@ -228,9 +231,10 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 1, 4> _acc *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { - ResPacket r = dest.template gatherPacket(row, col) + _acc; + ResPacket r = dest.template gatherPacket(row, col) + pAlpha*_acc; dest.template scatterPacket(row, col, r); } }; @@ -269,7 +273,8 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 4> _acc.packet[3] *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { constexpr auto PacketSize = unpacket_traits::size; @@ -278,10 +283,10 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 4, 4> LinearMapper r2 = dest.getLinearMapper(row, col + 2); LinearMapper r3 = dest.getLinearMapper(row, col + 3); - r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + _acc.packet[0]); - r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + _acc.packet[1]); - r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + _acc.packet[2]); - r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + _acc.packet[3]); + r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[0]); + r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[1]); + r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[2]); + r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + pAlpha*_acc.packet[3]); } }; @@ -330,7 +335,8 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 4> _acc2.packet[3] *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { constexpr auto PacketSize = unpacket_traits::size; @@ -339,15 +345,15 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 8, 4> LinearMapper r2 = dest.getLinearMapper(row, col + 2); LinearMapper r3 = dest.getLinearMapper(row, col + 3); - r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + _acc1.packet[0]); - r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + _acc1.packet[1]); - r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + _acc1.packet[2]); - r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + _acc1.packet[3]); + r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[0]); + r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[1]); + r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[2]); + r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[3]); - r0.storePacket(1*PacketSize, r0.template loadPacket(1*PacketSize) + _acc2.packet[0]); - r1.storePacket(1*PacketSize, r1.template loadPacket(1*PacketSize) + _acc2.packet[1]); - r2.storePacket(1*PacketSize, r2.template loadPacket(1*PacketSize) + _acc2.packet[2]); - r3.storePacket(1*PacketSize, r3.template loadPacket(1*PacketSize) + _acc2.packet[3]); + r0.storePacket(1*PacketSize, r0.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[0]); + r1.storePacket(1*PacketSize, r1.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[1]); + r2.storePacket(1*PacketSize, r2.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[2]); + r3.storePacket(1*PacketSize, r3.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[3]); } }; @@ -407,7 +413,8 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 12, 4> _acc3.packet[3] *= pAlpha; } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket_& pAlpha) { constexpr auto PacketSize = unpacket_traits::size; @@ -416,20 +423,20 @@ struct Accumulator<0, CPU, Scalar, ResScalar, DataMapper, 12, 4> LinearMapper r2 = dest.getLinearMapper(row, col + 2); LinearMapper r3 = dest.getLinearMapper(row, col + 3); - r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + _acc1.packet[0]); - r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + _acc1.packet[1]); - r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + _acc1.packet[2]); - r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + _acc1.packet[3]); + r0.storePacket(0*PacketSize, r0.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[0]); + r1.storePacket(0*PacketSize, r1.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[1]); + r2.storePacket(0*PacketSize, r2.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[2]); + r3.storePacket(0*PacketSize, r3.template loadPacket(0*PacketSize) + pAlpha*_acc1.packet[3]); - r0.storePacket(1*PacketSize, r0.template loadPacket(1*PacketSize) + _acc2.packet[0]); - r1.storePacket(1*PacketSize, r1.template loadPacket(1*PacketSize) + _acc2.packet[1]); - r2.storePacket(1*PacketSize, r2.template loadPacket(1*PacketSize) + _acc2.packet[2]); - r3.storePacket(1*PacketSize, r3.template loadPacket(1*PacketSize) + _acc2.packet[3]); + r0.storePacket(1*PacketSize, r0.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[0]); + r1.storePacket(1*PacketSize, r1.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[1]); + r2.storePacket(1*PacketSize, r2.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[2]); + r3.storePacket(1*PacketSize, r3.template loadPacket(1*PacketSize) + pAlpha*_acc2.packet[3]); - r0.storePacket(2*PacketSize, r0.template loadPacket(2*PacketSize) + _acc3.packet[0]); - r1.storePacket(2*PacketSize, r1.template loadPacket(2*PacketSize) + _acc3.packet[1]); - r2.storePacket(2*PacketSize, r2.template loadPacket(2*PacketSize) + _acc3.packet[2]); - r3.storePacket(2*PacketSize, r3.template loadPacket(2*PacketSize) + _acc3.packet[3]); + r0.storePacket(2*PacketSize, r0.template loadPacket(2*PacketSize) + pAlpha*_acc3.packet[0]); + r1.storePacket(2*PacketSize, r1.template loadPacket(2*PacketSize) + pAlpha*_acc3.packet[1]); + r2.storePacket(2*PacketSize, r2.template loadPacket(2*PacketSize) + pAlpha*_acc3.packet[2]); + r3.storePacket(2*PacketSize, r3.template loadPacket(2*PacketSize) + pAlpha*_acc3.packet[3]); } }; diff --git a/Eigen/src/Core/arch/NEON/MatrixProduct.h b/Eigen/src/Core/arch/NEON/MatrixProduct.h index 177b7b25d..f68699120 100644 --- a/Eigen/src/Core/arch/NEON/MatrixProduct.h +++ b/Eigen/src/Core/arch/NEON/MatrixProduct.h @@ -250,13 +250,14 @@ struct Accumulator } } - EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col) + template + EIGEN_STRONG_INLINE void store(const DataMapper& dest, Index row, Index col, ResScalar alpha, const ResPacket& pAlpha) { for(auto i = 0; i < M; i++) { for(auto j = 0; j < N; j++) { - dest(row + i, col + j) += dt[i][j]; + dest(row + i, col + j) += alpha*dt[i][j]; } } } @@ -337,8 +338,8 @@ struct DepthLoopStruct { mkt(lhsPackMap, rhsPackMap, rowIdx, colIdx, depthIdx, acc); } - acc.scale(alpha, pAlpha); - acc.store(res, rowIdx, colIdx); + //acc.scale(alpha, pAlpha); + acc.store(res, rowIdx, colIdx, alpha, pAlpha); depthLS(rowIdx, colIdx, depthIdx, res, rows, depth, cols, alpha, pAlpha, lhsPackMap, rhsPackMap); }