Fix problem with array conversions BF16->F32 in Power.
This commit is contained in:
		
							parent
							
								
									77b48c440e
								
							
						
					
					
						commit
						4a03409569
					
				| @ -189,11 +189,43 @@ void gemmMMAbfloat16(const DataMapper& res, const bfloat16* blockA, const bfloat | ||||
|   ei_declare_aligned_stack_constructed_variable(float, result, cols*rows, 0); | ||||
| 
 | ||||
|   typedef typename DataMapper::LinearMapper LinearMapper; | ||||
|   Packet8us z = pset1<Packet8us>(0); | ||||
|   for(Index j = 0; j < cols; j++){ | ||||
|     const LinearMapper res2 = res.getLinearMapper(0, j); | ||||
|     float *result2 = result + j*rows; | ||||
|     BFLOAT16_UNROLL | ||||
|     for(Index i = 0; i < rows; i++){ | ||||
|     Index i = 0; | ||||
|     for(; i + 32 <= rows; i+=32){ | ||||
|       Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i +  0).m_val; | ||||
|       Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i +  8).m_val; | ||||
|       Packet8us r32_2 = res2.template loadPacket<Packet8bf>(i + 16).m_val; | ||||
|       Packet8us r32_3 = res2.template loadPacket<Packet8bf>(i + 24).m_val; | ||||
|       pstore(result2 + i +  0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0))); | ||||
|       pstore(result2 + i +  4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0))); | ||||
|       pstore(result2 + i +  8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1))); | ||||
|       pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1))); | ||||
|       pstore(result2 + i + 16, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_2))); | ||||
|       pstore(result2 + i + 20, reinterpret_cast<Packet4f>(vec_mergel(z, r32_2))); | ||||
|       pstore(result2 + i + 24, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_3))); | ||||
|       pstore(result2 + i + 28, reinterpret_cast<Packet4f>(vec_mergel(z, r32_3))); | ||||
|     } | ||||
|     for(; i + 16 <= rows; i+=16){ | ||||
|       Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i +  0).m_val; | ||||
|       Packet8us r32_1 = res2.template loadPacket<Packet8bf>(i +  8).m_val; | ||||
|       pstore(result2 + i +  0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0))); | ||||
|       pstore(result2 + i +  4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0))); | ||||
|       pstore(result2 + i +  8, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_1))); | ||||
|       pstore(result2 + i + 12, reinterpret_cast<Packet4f>(vec_mergel(z, r32_1))); | ||||
|     } | ||||
|     for(; i + 8 <= rows; i+=8){ | ||||
|       Packet8us r32_0 = res2.template loadPacket<Packet8bf>(i +  0).m_val; | ||||
|       pstore(result2 + i +  0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0))); | ||||
|       pstore(result2 + i +  4, reinterpret_cast<Packet4f>(vec_mergel(z, r32_0))); | ||||
|     } | ||||
|     for(; i + 4 <= rows; i+=4){ | ||||
|       Packet8us r32_0 = res2.template loadPacketPartial<Packet8bf>(i +  0, 4).m_val; | ||||
|       pstore(result2 + i +  0, reinterpret_cast<Packet4f>(vec_mergeh(z, r32_0))); | ||||
|     } | ||||
|     for(; i < rows; i++){ | ||||
|       result2[i] = res2(i); | ||||
|     } | ||||
|   } | ||||
|  | ||||
		Loading…
	
		Reference in New Issue
	
	Block a user
	 Chip Kerchner
						Chip Kerchner