From 771a4fc20bf92bada9ca2f0a7e24266e253032b7 Mon Sep 17 00:00:00 2001
From: MerryMage
Date: Wed, 25 Jul 2018 13:19:48 +0100
Subject: [PATCH] IR: Implement FPVectorMulAdd

---
 .../emit_x64_vector_floating_point.cpp | 222 +++++++++++++++---
 src/frontend/ir/ir_emitter.cpp         |  11 +
 src/frontend/ir/ir_emitter.h           |   3 +-
 src/frontend/ir/opcodes.inc            |   2 +
 4 files changed, 207 insertions(+), 31 deletions(-)

diff --git a/src/backend_x64/emit_x64_vector_floating_point.cpp b/src/backend_x64/emit_x64_vector_floating_point.cpp
index a9a79708..fa49da74 100644
--- a/src/backend_x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend_x64/emit_x64_vector_floating_point.cpp
@@ -41,24 +41,38 @@ static T ChooseOnFsize(T f32, T f64) {
 
 #define FCODE(NAME) (code.*ChooseOnFsize(&Xbyak::CodeGenerator::NAME##s, &Xbyak::CodeGenerator::NAME##d))
 
-template<typename FPT, template<typename> class Indexer, size_t... argi>
-static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) {
-    auto result = [](std::array<VectorArray<FPT>, sizeof...(argi) + 1>& values) {
-        VectorArray<FPT>& result = values[0];
-        for (size_t elementi = 0; elementi < result.size(); ++elementi) {
-            const auto current_values = Indexer<FPT>{}(elementi, values[argi + 1]...);
-            if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) {
-                result[elementi] = *r;
-            } else if (FP::IsNaN(result[elementi])) {
-                result[elementi] = FP::FPInfo<FPT>::DefaultNaN();
-            }
-        }
-    };
-    return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result);
-}
+template<size_t fsize, template<typename> class Indexer, size_t narg>
+struct NaNHandler {
+private:
+    template<size_t... argi>
+    static auto GetDefaultImpl(std::index_sequence<argi...>) {
+        using FPT = mp::unsigned_integer_of_size<fsize>;
 
-template<size_t fsize, template<typename> class Indexer>
-static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, 3> xmms, const Xbyak::Xmm& nan_mask) {
+        auto result = [](std::array<VectorArray<FPT>, sizeof...(argi) + 1>& values) {
+            VectorArray<FPT>& result = values[0];
+            for (size_t elementi = 0; elementi < result.size(); ++elementi) {
+                const auto current_values = Indexer<FPT>{}(elementi, values[argi + 1]...);
+                if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) {
+                    result[elementi] = *r;
+                } else if (FP::IsNaN(result[elementi])) {
+                    result[elementi] = FP::FPInfo<FPT>::DefaultNaN();
+                }
+            }
+        };
+
+        return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result);
+    }
+
+public:
+    static auto GetDefault() {
+        return GetDefaultImpl(std::make_index_sequence<narg>{});
+    }
+
+    using function_type = mp::return_type_t<decltype(GetDefault)>;
+};
+
+template<size_t fsize, size_t nargs, typename NaNHandler>
+static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, nargs + 1> xmms, const Xbyak::Xmm& nan_mask, NaNHandler nan_handler) {
     static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
 
     if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
@@ -91,8 +105,7 @@ static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array
 
-    using FPT = mp::unsigned_integer_of_size<fsize>;
-    code.CallFunction(GetRuntimeNaNFunction<FPT, Indexer>(std::make_index_sequence<2>{}));
+    code.CallFunction(nan_handler);
 
     code.movaps(result, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
     code.add(rsp, stack_space + ABI_SHADOW_SPACE);
@@ -155,13 +168,13 @@ struct PairedLowerIndexer {
 };
 
 template<size_t fsize, template<typename> class Indexer, typename Function>
-static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) {
+static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, typename NaNHandler<fsize, Indexer, 2>::function_type nan_handler = NaNHandler<fsize, Indexer, 2>::GetDefault()) {
     static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
 
     if (!ctx.AccurateNaN() || ctx.FPSCR_DN()) {
         auto args = ctx.reg_alloc.GetArgumentInfo(inst);
-        Xbyak::Xmm xmm_a = ctx.reg_alloc.UseScratchXmm(args[0]);
-        Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
+        const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseScratchXmm(args[0]);
+        const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
 
         if constexpr (std::is_member_function_pointer_v<Function>) {
             (code.*fn)(xmm_a, xmm_b);
@@ -170,8 +183,8 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
         }
 
         if (ctx.FPSCR_DN()) {
-            Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
-            Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
             code.pcmpeqw(tmp, tmp);
             code.movaps(nan_mask, xmm_a);
             FCODE(cmpordp)(nan_mask, nan_mask);
@@ -187,10 +200,10 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
 
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
-    Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-    Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
-    Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
-    Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
+    const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+    const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
+    const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
+    const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
 
     code.movaps(nan_mask, xmm_b);
     code.movaps(result, xmm_a);
@@ -202,7 +215,63 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
     }
     FCODE(cmpunordp)(nan_mask, result);
 
-    HandleNaNs<fsize, Indexer>(code, ctx, {result, xmm_a, xmm_b}, nan_mask);
+    HandleNaNs<fsize, 2>(code, ctx, {result, xmm_a, xmm_b}, nan_mask, nan_handler);
+
+    ctx.reg_alloc.DefineValue(inst, result);
+}
+
+template<size_t fsize, template<typename> class Indexer, typename Function>
+static void EmitFourOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, typename NaNHandler<fsize, Indexer, 3>::function_type nan_handler = NaNHandler<fsize, Indexer, 3>::GetDefault()) {
+    static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
+
+    if (!ctx.AccurateNaN() || ctx.FPSCR_DN()) {
+        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+        const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseScratchXmm(args[0]);
+        const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
+        const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
+
+        if constexpr (std::is_member_function_pointer_v<Function>) {
+            (code.*fn)(xmm_a, xmm_b, xmm_c);
+        } else {
+            fn(xmm_a, xmm_b, xmm_c);
+        }
+
+        if (ctx.FPSCR_DN()) {
+            const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
+            const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+            code.pcmpeqw(tmp, tmp);
+            code.movaps(nan_mask, xmm_a);
+            FCODE(cmpordp)(nan_mask, nan_mask);
+            code.andps(xmm_a, nan_mask);
+            code.xorps(nan_mask, tmp);
+            code.andps(nan_mask, fsize == 32 ? code.MConst(xword, 0x7fc0'0000'7fc0'0000, 0x7fc0'0000'7fc0'0000) : code.MConst(xword, 0x7ff8'0000'0000'0000, 0x7ff8'0000'0000'0000));
+            code.orps(xmm_a, nan_mask);
+        }
+
+        ctx.reg_alloc.DefineValue(inst, xmm_a);
+        return;
+    }
+
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+
+    const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+    const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
+    const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
+    const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
+    const Xbyak::Xmm nan_mask = ctx.reg_alloc.ScratchXmm();
+
+    code.movaps(nan_mask, xmm_b);
+    code.movaps(result, xmm_a);
+    FCODE(cmpunordp)(nan_mask, xmm_a);
+    FCODE(cmpunordp)(nan_mask, xmm_c);
+    if constexpr (std::is_member_function_pointer_v<Function>) {
+        (code.*fn)(result, xmm_b, xmm_c);
+    } else {
+        fn(result, xmm_b, xmm_c);
+    }
+    FCODE(cmpunordp)(nan_mask, result);
+
+    HandleNaNs<fsize, 3>(code, ctx, {result, xmm_a, xmm_b, xmm_c}, nan_mask, nan_handler);
 
     ctx.reg_alloc.DefineValue(inst, result);
 }
@@ -250,7 +319,7 @@ inline void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
     code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]);
     code.mov(code.ABI_PARAM4.cvt32(), ctx.FPCR());
     code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
-    code.mov(qword[rsp + ABI_SHADOW_SPACE + 0 * 16], rax);
+    code.mov(qword[rsp + ABI_SHADOW_SPACE + 0], rax);
 #else
     constexpr u32 stack_space = 3 * 16;
     code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
@@ -276,6 +345,54 @@ inline void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
     ctx.reg_alloc.DefineValue(inst, xmm0);
 }
 
+template<typename Lambda>
+inline void EmitFourOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
+    const auto fn = static_cast<mp::equivalent_function_type_t<decltype(lambda)>*>(lambda);
+
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    const Xbyak::Xmm arg1 = ctx.reg_alloc.UseXmm(args[0]);
+    const Xbyak::Xmm arg2 = ctx.reg_alloc.UseXmm(args[1]);
+    const Xbyak::Xmm arg3 = ctx.reg_alloc.UseXmm(args[2]);
+    ctx.reg_alloc.EndOfAllocScope();
+    ctx.reg_alloc.HostCall(nullptr);
+
+#ifdef _WIN32
+    constexpr u32 stack_space = 5 * 16;
+    code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
+    code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 1 * 16]);
+    code.lea(code.ABI_PARAM2, ptr[rsp + ABI_SHADOW_SPACE + 2 * 16]);
+    code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]);
+    code.lea(code.ABI_PARAM4, ptr[rsp + ABI_SHADOW_SPACE + 4 * 16]);
+    code.mov(qword[rsp + ABI_SHADOW_SPACE + 0], ctx.FPCR());
+    code.lea(rax, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+    code.mov(qword[rsp + ABI_SHADOW_SPACE + 8], rax);
+#else
+    constexpr u32 stack_space = 4 * 16;
+    code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
+    code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
+    code.lea(code.ABI_PARAM2, ptr[rsp + ABI_SHADOW_SPACE + 1 * 16]);
+    code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 2 * 16]);
+    code.lea(code.ABI_PARAM4, ptr[rsp + ABI_SHADOW_SPACE + 3 * 16]);
+    code.mov(code.ABI_PARAM5.cvt32(), ctx.FPCR());
+    code.lea(code.ABI_PARAM6, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
+#endif
+
+    code.movaps(xword[code.ABI_PARAM2], arg1);
+    code.movaps(xword[code.ABI_PARAM3], arg2);
+    code.movaps(xword[code.ABI_PARAM4], arg3);
+    code.CallFunction(fn);
+
+#ifdef _WIN32
+    code.movaps(xmm0, xword[rsp + ABI_SHADOW_SPACE + 1 * 16]);
+#else
+    code.movaps(xmm0, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
+#endif
+
+    code.add(rsp, stack_space + ABI_SHADOW_SPACE);
+
+    ctx.reg_alloc.DefineValue(inst, xmm0);
+}
+
 void EmitX64::EmitFPVectorAbs16(EmitContext& ctx, IR::Inst* inst) {
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
@@ -393,6 +510,51 @@ void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) {
     EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
 }
 
+template<size_t fsize>
+void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
+    using FPT = mp::unsigned_integer_of_size<fsize>;
+
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA)) {
+        auto x64_instruction = fsize == 32 ? &Xbyak::CodeGenerator::vfmadd231ps : &Xbyak::CodeGenerator::vfmadd231pd;
+        EmitFourOpVectorOperation<fsize, DefaultIndexer>(code, ctx, inst, x64_instruction,
+            static_cast<void (*)(std::array<VectorArray<FPT>, 4>& values)>(
+                [](std::array<VectorArray<FPT>, 4>& values) {
+                    VectorArray<FPT>& result = values[0];
+                    const VectorArray<FPT>& a = values[1];
+                    const VectorArray<FPT>& b = values[2];
+                    const VectorArray<FPT>& c = values[3];
+                    for (size_t i = 0; i < result.size(); i++) {
+                        if (FP::IsQNaN(a[i]) && ((FP::IsInf(b[i]) && FP::IsZero(c[i])) || (FP::IsZero(b[i]) && FP::IsInf(c[i])))) {
+                            result[i] = FP::FPInfo<FPT>::DefaultNaN();
+                        } else if (auto r = FP::ProcessNaNs(a[i], b[i], c[i])) {
+                            result[i] = *r;
+                        } else if (FP::IsNaN(result[i])) {
+                            result[i] = FP::FPInfo<FPT>::DefaultNaN();
+                        }
+                    }
+                }
+            )
+        );
+        return;
+    }
+
+    EmitFourOpFallback(code, ctx, inst,
+        [](VectorArray<FPT>& result, const VectorArray<FPT>& addend, const VectorArray<FPT>& op1, const VectorArray<FPT>& op2, FP::FPCR fpcr, FP::FPSR& fpsr) {
+            for (size_t i = 0; i < result.size(); i++) {
+                result[i] = FP::FPMulAdd(addend[i], op1[i], op2[i], fpcr, fpsr);
+            }
+        }
+    );
+}
+
+void EmitX64::EmitFPVectorMulAdd32(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPVectorMulAdd<32>(code, ctx, inst);
+}
+
+void EmitX64::EmitFPVectorMulAdd64(EmitContext& ctx, IR::Inst* inst) {
+    EmitFPVectorMulAdd<64>(code, ctx, inst);
+}
+
 void EmitX64::EmitFPVectorPairedAdd32(EmitContext& ctx, IR::Inst* inst) {
     EmitThreeOpVectorOperation<32, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddps);
 }
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 9234b660..9a94f744 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -1696,6 +1696,17 @@ U128 IREmitter::FPVectorMul(size_t esize, const U128& a, const U128& b) {
     return {};
 }
 
+U128 IREmitter::FPVectorMulAdd(size_t esize, const U128& a, const U128& b, const U128& c) {
+    switch (esize) {
+    case 32:
+        return Inst<U128>(Opcode::FPVectorMulAdd32, a, b, c);
+    case 64:
+        return Inst<U128>(Opcode::FPVectorMulAdd64, a, b, c);
+    }
+    UNREACHABLE();
+    return {};
+}
+
 U128 IREmitter::FPVectorPairedAdd(size_t esize, const U128& a, const U128& b) {
     switch (esize) {
     case 32:
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index e5f73c60..2904d8c2 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -267,7 +267,7 @@ public:
     U32U64 FPMin(const U32U64& a, const U32U64& b, bool fpscr_controlled);
     U32U64 FPMinNumeric(const U32U64& a, const U32U64& b, bool fpscr_controlled);
     U32U64 FPMul(const U32U64& a, const U32U64& b, bool fpscr_controlled);
-    U32U64 FPMulAdd(const U32U64& a, const U32U64& b, const U32U64& c, bool fpscr_controlled);
+    U32U64 FPMulAdd(const U32U64& addend, const U32U64& op1, const U32U64& op2, bool fpscr_controlled);
     U32U64 FPNeg(const U32U64& a);
     U32U64 FPRoundInt(const U32U64& a, FP::RoundingMode rounding, bool exact);
     U32U64 FPRSqrtEstimate(const U32U64& a);
@@ -300,6 +300,7 @@ public:
     U128 FPVectorGreater(size_t esize, const U128& a, const U128& b);
     U128 FPVectorGreaterEqual(size_t esize, const U128& a, const U128& b);
     U128 FPVectorMul(size_t esize, const U128& a, const U128& b);
+    U128 FPVectorMulAdd(size_t esize, const U128& addend, const U128& op1, const U128& op2);
     U128 FPVectorPairedAdd(size_t esize, const U128& a, const U128& b);
     U128 FPVectorPairedAddLower(size_t esize, const U128& a, const U128& b);
     U128 FPVectorRSqrtEstimate(size_t esize, const U128& a);
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index b09f6ba2..ace84304 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -441,6 +441,8 @@ OPCODE(FPVectorGreaterEqual32, T::U128, T::U128,
 OPCODE(FPVectorGreaterEqual64, T::U128, T::U128, T::U128 )
 OPCODE(FPVectorMul32, T::U128, T::U128, T::U128 )
 OPCODE(FPVectorMul64, T::U128, T::U128, T::U128 )
+OPCODE(FPVectorMulAdd32, T::U128, T::U128, T::U128, T::U128 )
+OPCODE(FPVectorMulAdd64, T::U128, T::U128, T::U128, T::U128 )
 OPCODE(FPVectorPairedAddLower32, T::U128, T::U128, T::U128 )
 OPCODE(FPVectorPairedAddLower64, T::U128, T::U128, T::U128 )
 OPCODE(FPVectorPairedAdd32, T::U128, T::U128, T::U128 )
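
Note for context (not part of the patch): the sketch below shows how an A64 frontend translation might consume the new IREmitter::FPVectorMulAdd helper. Only FPVectorMulAdd(esize, addend, op1, op2) is introduced by this change; the visitor name FMLA_vec, the V()/ReservedValue() helpers and the decode fields are assumptions modelled on dynarmic's existing A64 translator conventions.

    // Illustrative sketch only; everything except ir.FPVectorMulAdd is assumed.
    bool TranslatorVisitor::FMLA_vec(bool Q, bool sz, Vec Vm, Vec Vn, Vec Vd) {
        if (sz && !Q) {
            return ReservedValue();
        }
        const size_t esize = sz ? 64 : 32;     // selects FPVectorMulAdd32 or FPVectorMulAdd64
        const size_t datasize = Q ? 128 : 64;

        const IR::U128 operand1 = V(datasize, Vn); // op1
        const IR::U128 operand2 = V(datasize, Vm); // op2
        const IR::U128 addend = V(datasize, Vd);   // destination register doubles as accumulator

        // Element-wise addend + op1 * op2 with fused multiply-add semantics.
        const IR::U128 result = ir.FPVectorMulAdd(esize, addend, operand1, operand2);
        V(datasize, Vd, result);
        return true;
    }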