From fda1373a1581a8324b6ab563c4e6d66db187e306 Mon Sep 17 00:00:00 2001
From: Chip Kerchner
Date: Wed, 3 May 2023 20:12:50 +0000
Subject: [PATCH] Fix ColMajor BF16 GEMV for when vector is RowMajor

---
 .../Core/arch/AltiVec/MatrixProductCommon.h   |  3 +
 .../arch/AltiVec/MatrixProductMMAbfloat16.h   | 49 ++++++++-------
 .../Core/arch/AltiVec/MatrixVectorProduct.h   | 61 ++++++++++++-------
 Eigen/src/Core/arch/AltiVec/PacketMath.h      | 56 ++++++++++-------
 4 files changed, 104 insertions(+), 65 deletions(-)

diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
index fe135dc13..226f425b4 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductCommon.h
@@ -102,6 +102,9 @@ EIGEN_ALWAYS_INLINE void outputVecColResults(Packet4f (&acc)[num_acc][size], flo
 template<Index num_acc, Index size = 4>
 EIGEN_ALWAYS_INLINE void outputVecResults(Packet4f (&acc)[num_acc][size], float *result, Packet4f pAlpha);

+template<typename RhsMapper, bool linear>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j);
+
 template<typename Packet>
 EIGEN_ALWAYS_INLINE Packet ploadLhs(const __UNPACK_TYPE__(Packet)* lhs);

diff --git a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
index 0944c2d16..011d68e69 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixProductMMAbfloat16.h
@@ -383,12 +383,12 @@ EIGEN_ALWAYS_INLINE void multVec(__vector_quad (&quad_acc)[num_acc], Packet8bf (
   }
 }

-template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
 EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __vector_quad (&quad_acc)[num_acc])
 {
   Packet8bf a0[num_acc];
   Packet8bf b1 = pset1<Packet8bf>(Eigen::bfloat16(0));
-  Packet8bf b0 = rhs.template loadPacket<Packet8bf>(j + 0);
+  Packet8bf b0 = loadColData<RhsMapper, linear>(rhs, j);

   if (zero) {
     b0 = vec_mergeh(b0.m_val, b1.m_val);
@@ -405,7 +405,7 @@ EIGEN_ALWAYS_INLINE void vecColLoop(Index j, LhsMapper& lhs, RhsMapper& rhs, __v

 #define MAX_BFLOAT16_VEC_ACC 8

-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   constexpr Index step = (num_acc * 4);
@@ -420,10 +420,10 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa

     LhsMapper lhs2 = lhs.getSubMapper(row, 0);
     for(Index j = 0; j + 2 <= cend; j += 2) {
-      vecColLoop<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, quad_acc);
+      vecColLoop<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, quad_acc);
     }
     if (cend & 1) {
-      vecColLoop<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, quad_acc);
+      vecColLoop<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, quad_acc);
     }

     disassembleAccumulators(quad_acc, acc);
@@ -434,59 +434,59 @@ void colVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMa
   } while(multiIters && (step <= rows - (row += step)));
 }

-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   if (MAX_BFLOAT16_VEC_ACC > num_acc) {
-    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }

-template<typename LhsMapper, typename RhsMapper, bool extraRows>
+template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   switch ((rows - row) >> 2) {
   case 7:
-    colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 6:
-    colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 5:
-    colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 4:
-    colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 3:
-    colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
    break;
   case 2:
-    colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 1:
-    colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   default:
     if (extraRows) {
-      colVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     }
     break;
   }
 }

-template<typename LhsMapper, typename RhsMapper>
+template<typename LhsMapper, typename RhsMapper, bool linear>
 EIGEN_ALWAYS_INLINE void calcVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   Index row = 0;
   if (rows >= (MAX_BFLOAT16_VEC_ACC * 4)) {
-    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBody<MAX_BFLOAT16_VEC_ACC, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     result += row;
   }
   if (rows & 3) {
-    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   } else {
-    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }

@@ -524,8 +524,13 @@ void gemvMMA_bfloat16_col(
     Index jend = numext::mini(j2 + block_cols, cols);

     LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-    calcVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    if (rhs.stride() == 1) {
+      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+      calcVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    } else {
+      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
+      calcVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    }
   }

   convertArrayPointerF32toBF16(result, rows, res);
diff --git a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
index 62840a3e9..480a48abc 100644
--- a/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
+++ b/Eigen/src/Core/arch/AltiVec/MatrixVectorProduct.h
@@ -521,11 +521,23 @@ EIGEN_ALWAYS_INLINE void multVecVSX(Packet4f (&acc)[num_acc][2], Packet4f (&a0)[
   }
 }

-template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero>
+template<typename RhsMapper, bool linear, std::enable_if_t<linear, bool> = true>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
+{
+  return rhs.template loadPacket<Packet8bf>(j + 0);
+}
+
+template<typename RhsMapper, bool linear, std::enable_if_t<!linear, bool> = true>
+EIGEN_ALWAYS_INLINE Packet8bf loadColData(RhsMapper& rhs, Index j)
+{
+  return pgather<bfloat16, Packet8bf>(&rhs(j + 0, 0), rhs.stride());
+}
+
+template<Index num_acc, typename LhsMapper, typename RhsMapper, bool zero, bool linear>
 EIGEN_ALWAYS_INLINE void vecColLoopVSX(Index j, LhsMapper& lhs, RhsMapper& rhs, Packet4f (&acc)[num_acc][2])
 {
   Packet4f a0[num_acc][2], b0[2];
-  Packet8bf b2 = rhs.template loadPacket<Packet8bf>(j + 0);
+  Packet8bf b2 = loadColData<RhsMapper, linear>(rhs, j);

   b0[0] = oneConvertBF16Perm(b2.m_val, p16uc_MERGE16_32_V1);
   if (!zero) {
@@ -551,7 +563,7 @@ EIGEN_ALWAYS_INLINE void addResultsVSX(Packet4f (&acc)[num_acc][2])
 // Uses 2X the accumulators or 4X the number of VSX registers
 #define MAX_BFLOAT16_VEC_ACC_VSX 8

-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   constexpr Index step = (num_acc * 4);
@@ -565,10 +577,10 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh

     LhsMapper lhs2 = lhs.getSubMapper(row, 0);
     for(Index j = 0; j + 2 <= cend; j += 2) {
-      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false>(j, lhs2, rhs, acc);
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, false, linear>(j, lhs2, rhs, acc);
     }
     if (cend & 1) {
-      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true>(cend - 1, lhs2, rhs, acc);
+      vecColLoopVSX<num_acc, LhsMapper, RhsMapper, true, linear>(cend - 1, lhs2, rhs, acc);
     }

     addResultsVSX(acc);
@@ -579,59 +591,59 @@ void colVSXVecColLoopBody(Index& row, Index cend, Index rows, LhsMapper& lhs, Rh
   } while(multiIters && (step <= rows - (row += step)));
 }

-template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows>
+template<const Index num_acc, typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtraN(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   if (MAX_BFLOAT16_VEC_ACC_VSX > num_acc) {
-    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBody<num_acc + (extraRows ? 1 : 0), LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }

-template<typename LhsMapper, typename RhsMapper, bool extraRows>
+template<typename LhsMapper, typename RhsMapper, bool extraRows, bool linear>
 EIGEN_ALWAYS_INLINE void colVSXVecColLoopBodyExtra(Index& row, Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   switch ((rows - row) >> 2) {
   case 7:
-    colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<7, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 6:
-    colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<6, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 5:
-    colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<5, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 4:
-    colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<4, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 3:
-    colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<3, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 2:
-    colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<2, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   case 1:
-    colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtraN<1, LhsMapper, RhsMapper, extraRows, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     break;
   default:
     if (extraRows) {
-      colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+      colVSXVecColLoopBody<1, LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     }
     break;
   }
 }

-template<typename LhsMapper, typename RhsMapper>
+template<typename LhsMapper, typename RhsMapper, bool linear>
 EIGEN_ALWAYS_INLINE void calcVSXVecColLoops(Index cend, Index rows, LhsMapper& lhs, RhsMapper& rhs, const Packet4f pAlpha, float *result)
 {
   Index row = 0;
   if (rows >= (MAX_BFLOAT16_VEC_ACC_VSX * 4)) {
-    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBody<MAX_BFLOAT16_VEC_ACC_VSX, LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
     result += row;
   }
   if (rows & 3) {
-    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, true, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   } else {
-    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false>(row, cend, rows, lhs, rhs, pAlpha, result);
+    colVSXVecColLoopBodyExtra<LhsMapper, RhsMapper, false, linear>(row, cend, rows, lhs, rhs, pAlpha, result);
   }
 }

@@ -724,8 +736,13 @@ void gemv_bfloat16_col(
     Index jend = numext::mini(j2 + block_cols, cols);

     LhsMapper lhs2 = lhs.getSubMapper(0, j2);
-    LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
-    calcVSXVecColLoops<LhsMapper, LinearMapper>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    if (rhs.stride() == 1) {
+      LinearMapper rhs3 = rhs2.getLinearMapper(j2, 0);
+      calcVSXVecColLoops<LhsMapper, LinearMapper, true>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    } else {
+      RhsMapper rhs3 = rhs2.getSubMapper(j2, 0);
+      calcVSXVecColLoops<LhsMapper, RhsMapper, false>(jend - j2, rows, lhs2, rhs3, pAlpha, result);
+    }
   }

   convertArrayPointerF32toBF16VSX(result, rows, res);
diff --git a/Eigen/src/Core/arch/AltiVec/PacketMath.h b/Eigen/src/Core/arch/AltiVec/PacketMath.h
index b24796e4a..7e0c75918 100644
--- a/Eigen/src/Core/arch/AltiVec/PacketMath.h
+++ b/Eigen/src/Core/arch/AltiVec/PacketMath.h
@@ -796,12 +796,20 @@ template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet pgather_c
 {
   EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[unpacket_traits<Packet>::size];
   eigen_internal_assert(n <= unpacket_traits<Packet>::size && "number of elements will gather past end of packet");
-  LOAD_STORE_UNROLL_16
-  for (Index i = 0; i < n; i++) {
-    a[i] = from[i*stride];
+  if (stride == 1) {
+    if (n == unpacket_traits<Packet>::size) {
+      return ploadu<Packet>(from);
+    } else {
+      return ploadu_partial<Packet>(from, n);
+    }
+  } else {
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      a[i] = from[i*stride];
+    }
+    // Leave rest of the array uninitialized
+    return pload_ignore<Packet>(a);
   }
-  // Leave rest of the array uninitialized
-  return pload_ignore<Packet>(a);
 }

 template<> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE Packet4f pgather<float, Packet4f>(const float* from, Index stride)
@@ -878,10 +886,18 @@ template<typename Packet> EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE void pscatter_co
 {
   EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) a[unpacket_traits<Packet>::size];
   eigen_internal_assert(n <= unpacket_traits<Packet>::size && "number of elements will scatter past end of packet");
-  pstore<__UNPACK_TYPE__(Packet)>(a, from);
-  LOAD_STORE_UNROLL_16
-  for (Index i = 0; i < n; i++) {
-    to[i*stride] = a[i];
+  if (stride == 1) {
+    if (n == unpacket_traits<Packet>::size) {
+      return pstoreu(to, from);
+    } else {
+      return pstoreu_partial(to, from, n);
+    }
+  } else {
+    pstore<__UNPACK_TYPE__(Packet)>(a, from);
+    LOAD_STORE_UNROLL_16
+    for (Index i = 0; i < n; i++) {
+      to[i*stride] = a[i];
+    }
   }
 }

@@ -1256,15 +1272,14 @@ template<typename Packet> EIGEN_ALWAYS_INLINE Packet ploadu_partial_common(const
   return vec_xl_len(const_cast<__UNPACK_TYPE__(Packet)*>(from), n * size);
 #else
   if (n) {
+    Index n2 = n * size;
+    if (16 <= n2) {
+      return ploadu<Packet>(from);
+    }
     EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) load[packet_size];
     unsigned char* load2 = reinterpret_cast<unsigned char *>(load);
     unsigned char* from2 = reinterpret_cast<unsigned char *>(const_cast<__UNPACK_TYPE__(Packet)*>(from));
-    Index n2 = n * size;
-    if (16 <= n2) {
-      pstore(load2, ploadu<Packet16uc>(from2));
-    } else {
-      memcpy((void *)load2, (void *)from2, n2);
-    }
+    memcpy((void *)load2, (void *)from2, n2);
     return pload_ignore<Packet>(load);
   } else {
     return Packet(pset1<Packet>(0));
   }
@@ -1432,16 +1447,15 @@ template<typename Packet> EIGEN_ALWAYS_INLINE void pstoreu_partial_common(__UNPA
   vec_xst_len(from, to, n * size);
 #else
   if (n) {
+    Index n2 = n * size;
+    if (16 <= n2) {
+      pstoreu(to, from);
+    }
     EIGEN_ALIGN16 __UNPACK_TYPE__(Packet) store[packet_size];
     pstore(store, from);
     unsigned char* store2 = reinterpret_cast<unsigned char *>(store);
     unsigned char* to2 = reinterpret_cast<unsigned char *>(to);
-    Index n2 = n * size;
-    if (16 <= n2) {
-      pstoreu(to2, pload<Packet16uc>(store2));
-    } else {
-      memcpy((void *)to2, (void *)store2, n2);
-    }
+    memcpy((void *)to2, (void *)store2, n2);
   }
 #endif
 }
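Reviewer note, not part of the patch: the core of the fix is that the ColMajor GEMV kernels previously always used a contiguous `loadPacket` on the RHS vector, which is only valid when the vector's inner stride is 1. A RowMajor-mapped vector has a larger stride and must be gathered element by element. The sketch below only illustrates that dispatch under assumed names (`Packet8x16`, `load_rhs`, etc. are hypothetical stand-ins, not Eigen's internal `loadColData`/`pgather`):

```cpp
#include <cstddef>

// Hypothetical stand-in for a packet of 8 half-width (16-bit) elements.
struct Packet8x16 { unsigned short v[8]; };

// stride == 1: the whole packet can come from one contiguous (vector) load.
static inline Packet8x16 load_contiguous(const unsigned short* p) {
  Packet8x16 r;
  for (int i = 0; i < 8; ++i) r.v[i] = p[i];
  return r;
}

// stride != 1 (e.g. a column of a RowMajor matrix): gather with the runtime stride.
static inline Packet8x16 load_gather(const unsigned short* p, std::ptrdiff_t stride) {
  Packet8x16 r;
  for (int i = 0; i < 8; ++i) r.v[i] = p[i * stride];
  return r;
}

// The compile-time flag mirrors the new `bool linear` template parameter:
// the choice is made once and the kernel is instantiated for each case.
template <bool linear>
Packet8x16 load_rhs(const unsigned short* p, std::ptrdiff_t stride) {
  return linear ? load_contiguous(p) : load_gather(p, stride);
}
```

In the patch itself the decision happens once per GEMV call, in `gemv_bfloat16_col` / `gemvMMA_bfloat16_col`, via `rhs.stride() == 1`, and is then propagated down as the `linear` template parameter so the inner loops carry no per-iteration branch; the `stride == 1` fast paths added to `pgather_common`/`pscatter_common` apply the same idea at the packet level.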
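Also not part of the patch: a minimal way to exercise the fixed path from user code is a bf16 ColMajor matrix times a vector whose inner stride is not 1, for example a column of a RowMajor matrix. This is an illustrative sketch, not a test from this commit; the typedefs are local names, and it assumes a POWER build (VSX or MMA) where the bfloat16 GEMV kernels above are actually selected, while elsewhere it still compiles but runs through Eigen's generic path:

```cpp
#include <Eigen/Dense>
#include <iostream>

using Eigen::bfloat16;
typedef Eigen::Matrix<bfloat16, Eigen::Dynamic, Eigen::Dynamic> MatBF;                      // ColMajor (default)
typedef Eigen::Matrix<bfloat16, Eigen::Dynamic, Eigen::Dynamic, Eigen::RowMajor> MatBFRow;  // RowMajor storage
typedef Eigen::Matrix<bfloat16, Eigen::Dynamic, 1> VecBF;

int main() {
  // ColMajor LHS; the RHS vector is a column of a RowMajor matrix, so its
  // inner stride is B.cols() rather than 1 -- the case this patch fixes.
  MatBF A = Eigen::MatrixXf::Random(33, 17).cast<bfloat16>();
  MatBFRow B = Eigen::MatrixXf::Random(17, 5).cast<bfloat16>();

  VecBF y = A * B.col(2);  // bf16 GEMV with a strided RHS

  // Loose check against a float reference (bf16 keeps roughly 8 mantissa bits).
  Eigen::VectorXf ref = A.cast<float>() * B.col(2).cast<float>();
  std::cout << "max abs error: " << (y.cast<float>() - ref).cwiseAbs().maxCoeff() << std::endl;
  return 0;
}
```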