diff --git a/src/backend/x64/emit_x64_vector_floating_point.cpp b/src/backend/x64/emit_x64_vector_floating_point.cpp
index 1cf41322..f1a717b5 100644
--- a/src/backend/x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend/x64/emit_x64_vector_floating_point.cpp
@@ -151,6 +151,12 @@ Xbyak::Address GetSmallestNormalVector(BlockOfCode& code) {
     return GetVectorOf<fsize, smallest_normal_number>(code);
 }
 
+template<size_t fsize, bool sign, int exponent, mp::unsigned_integer_of_size<fsize> value>
+Xbyak::Address GetVectorOf(BlockOfCode& code) {
+    using FPT = mp::unsigned_integer_of_size<fsize>;
+    return GetVectorOf<fsize, FP::FPValue<FPT, sign, exponent, value>()>(code);
+}
+
 template<size_t fsize>
 void ForceToDefaultNaN(BlockOfCode& code, EmitContext& ctx, Xbyak::Xmm result) {
     if (ctx.FPSCR_DN()) {
@@ -357,16 +363,9 @@ void EmitTwoOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lamb
 }
 
 template<typename Lambda>
-void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
+void EmitThreeOpFallbackWithoutRegAlloc(BlockOfCode& code, EmitContext& ctx, Xbyak::Xmm result, Xbyak::Xmm arg1, Xbyak::Xmm arg2, Lambda lambda) {
     const auto fn = static_cast<mp::equivalent_function_type_t<Lambda>*>(lambda);
 
-    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
-    const Xbyak::Xmm arg1 = ctx.reg_alloc.UseXmm(args[0]);
-    const Xbyak::Xmm arg2 = ctx.reg_alloc.UseXmm(args[1]);
-    const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
-    ctx.reg_alloc.EndOfAllocScope();
-    ctx.reg_alloc.HostCall(nullptr);
-
 #ifdef _WIN32
     constexpr u32 stack_space = 4 * 16;
     code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
@@ -397,6 +396,18 @@ void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, La
 #endif
 
     code.add(rsp, stack_space + ABI_SHADOW_SPACE);
+}
+
+template<typename Lambda>
+void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    const Xbyak::Xmm arg1 = ctx.reg_alloc.UseXmm(args[0]);
+    const Xbyak::Xmm arg2 = ctx.reg_alloc.UseXmm(args[1]);
+    const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+    ctx.reg_alloc.EndOfAllocScope();
+    ctx.reg_alloc.HostCall(nullptr);
+
+    EmitThreeOpFallbackWithoutRegAlloc(code, ctx, result, arg1, arg2, lambda);
 
     ctx.reg_alloc.DefineValue(inst, result);
 }
@@ -821,21 +832,57 @@ void EmitX64::EmitFPVectorRecipEstimate64(EmitContext& ctx, IR::Inst* inst) {
     EmitRecipEstimate<u64>(code, ctx, inst);
 }
 
-template<typename FPT>
+template<size_t fsize>
 static void EmitRecipStepFused(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpFallback(code, ctx, inst, [](VectorArray<FPT>& result, const VectorArray<FPT>& op1, const VectorArray<FPT>& op2, FP::FPCR fpcr, FP::FPSR& fpsr) {
+    using FPT = mp::unsigned_integer_of_size<fsize>;
+
+    const auto fallback_fn = [](VectorArray<FPT>& result, const VectorArray<FPT>& op1, const VectorArray<FPT>& op2, FP::FPCR fpcr, FP::FPSR& fpsr) {
         for (size_t i = 0; i < result.size(); i++) {
             result[i] = FP::FPRecipStepFused<FPT>(op1[i], op2[i], fpcr, fpsr);
         }
-    });
+    };
+
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tFMA) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) {
+        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+
+        const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
+        const Xbyak::Xmm operand1 = ctx.reg_alloc.UseXmm(args[0]);
+        const Xbyak::Xmm operand2 = ctx.reg_alloc.UseXmm(args[1]);
+        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+
+        Xbyak::Label end, fallback;
+
+        code.movaps(result, GetVectorOf<fsize, false, 0, 2>(code));
+        FCODE(vfnmadd231p)(result, operand1, operand2);
+
+        FCODE(vcmpunordp)(tmp, result, result);
+        code.vptest(tmp, tmp);
+        code.jnz(fallback, code.T_NEAR);
+        code.L(end);
+
+        code.SwitchToFarCode();
+        code.L(fallback);
+        code.sub(rsp, 8);
+        ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+        EmitThreeOpFallbackWithoutRegAlloc(code, ctx, result, operand1, operand2, fallback_fn);
+        ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
+        code.add(rsp, 8);
+        code.jmp(end, code.T_NEAR);
+        code.SwitchToNearCode();
+
+        ctx.reg_alloc.DefineValue(inst, result);
+        return;
+    }
+
+    EmitThreeOpFallback(code, ctx, inst, fallback_fn);
 }
 
 void EmitX64::EmitFPVectorRecipStepFused32(EmitContext& ctx, IR::Inst* inst) {
-    EmitRecipStepFused<u32>(code, ctx, inst);
+    EmitRecipStepFused<32>(code, ctx, inst);
 }
 
 void EmitX64::EmitFPVectorRecipStepFused64(EmitContext& ctx, IR::Inst* inst) {
-    EmitRecipStepFused<u64>(code, ctx, inst);
+    EmitRecipStepFused<64>(code, ctx, inst);
 }
 
 template