emit_x64_vector_floating_point: FPVectorMulAdd: Minimize full fallback

Merry 2023-08-28 12:58:09 +01:00
parent ceea80dd59
commit adac93f12e
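
In brief: before this change, any FPVectorMulAdd executing under an FPCR that demands flush-to-zero or accurate NaN propagation took the full soft-float fallback for every lane. Now the product-accumulate is always done with a hardware FMA first; the emitted code then checks the result for lanes that could disagree with ARM semantics, and only such vectors branch to a fallback, which patches the offending lanes in place. A minimal sketch of the per-lane criterion, assuming f32 lanes (FPT = u32) and IEEE-754 encoding; the function name and spelled-out constants below are illustrative, not dynarmic's API:

    #include <cstdint>

    // Does one lane of the hardware FMA result need software recomputation?
    // Rounding: x64 rounds before flushing to zero, AArch64 flushes before
    // rounding, so a result whose magnitude equals the smallest normal number
    // is suspect under FZ. NaN: ARM's FPMulAdd has its own propagation rules,
    // so any NaN result must be re-derived from the operands.
    bool lane_needs_fallback(std::uint32_t bits, bool rounding_correction, bool nan_correction) {
        constexpr std::uint32_t non_sign_mask = 0x7FFF'FFFF;    // exponent | mantissa
        constexpr std::uint32_t smallest_normal = 0x0080'0000;  // 1.0 * 2^-126
        constexpr std::uint32_t exponent_mask = 0x7F80'0000;
        constexpr std::uint32_t mantissa_mask = 0x007F'FFFF;

        const bool is_nan = (bits & exponent_mask) == exponent_mask && (bits & mantissa_mask) != 0;
        return (rounding_correction && (bits & non_sign_mask) == smallest_normal)
            || (nan_correction && is_nan);
    }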

@@ -381,7 +381,7 @@ void EmitTwoOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* ins
     ctx.reg_alloc.DefineValue(inst, result);
 }

-enum CheckInputNaN {
+enum class CheckInputNaN {
     Yes,
     No,
 };
@@ -540,7 +540,12 @@ void EmitThreeOpFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, La
     ctx.reg_alloc.DefineValue(inst, result);
 }

-template<typename Lambda>
+enum class LoadPreviousResult {
+    Yes,
+    No,
+};
+
+template<LoadPreviousResult load_previous_result = LoadPreviousResult::No, typename Lambda>
 void EmitFourOpFallbackWithoutRegAlloc(BlockOfCode& code, EmitContext& ctx, Xbyak::Xmm result, Xbyak::Xmm arg1, Xbyak::Xmm arg2, Xbyak::Xmm arg3, Lambda lambda, bool fpcr_controlled) {
     const auto fn = static_cast<mcl::equivalent_function_type<Lambda>*>(lambda);
@@ -565,6 +570,9 @@ void EmitFourOpFallbackWithoutRegAlloc(BlockOfCode& code, EmitContext& ctx, Xbya
     code.lea(code.ABI_PARAM6, code.ptr[code.r15 + code.GetJitStateInfo().offsetof_fpsr_exc]);
 #endif

+    if constexpr (load_previous_result == LoadPreviousResult::Yes) {
+        code.movaps(xword[code.ABI_PARAM1], result);
+    }
     code.movaps(xword[code.ABI_PARAM2], arg1);
     code.movaps(xword[code.ABI_PARAM3], arg2);
     code.movaps(xword[code.ABI_PARAM4], arg3);
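
With LoadPreviousResult::Yes the thunk now spills the caller's result register into the output buffer (ABI_PARAM1) before invoking the lambda, so a fallback can read the hardware-computed result and correct only the lanes that need it instead of recomputing the whole vector. The default, LoadPreviousResult::No, keeps the old behaviour for existing callers.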
@@ -1290,6 +1298,31 @@ void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) {
     EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
 }

+template<typename FPT, bool needs_rounding_correction, bool needs_nan_correction>
+static void EmitFPVectorMulAddFallback(VectorArray<FPT>& result, const VectorArray<FPT>& addend, const VectorArray<FPT>& op1, const VectorArray<FPT>& op2, FP::FPCR fpcr, [[maybe_unused]] FP::FPSR& fpsr) {
+    for (size_t i = 0; i < result.size(); i++) {
+        if constexpr (needs_rounding_correction) {
+            constexpr FPT non_sign_mask = FP::FPInfo<FPT>::exponent_mask | FP::FPInfo<FPT>::mantissa_mask;
+            constexpr FPT smallest_normal_number = FP::FPValue<FPT, false, FP::FPInfo<FPT>::exponent_min, 1>();
+            if ((result[i] & non_sign_mask) == smallest_normal_number) {
+                result[i] = FP::FPMulAdd<FPT>(addend[i], op1[i], op2[i], fpcr, fpsr);
+                continue;
+            }
+        }
+        if constexpr (needs_nan_correction) {
+            if (FP::IsNaN(result[i])) {
+                if (FP::IsQNaN(addend[i]) && ((FP::IsZero(op1[i], fpcr) && FP::IsInf(op2[i])) || (FP::IsInf(op1[i]) && FP::IsZero(op2[i], fpcr)))) {
+                    result[i] = FP::FPInfo<FPT>::DefaultNaN();
+                } else if (auto r = FP::ProcessNaNs(addend[i], op1[i], op2[i])) {
+                    result[i] = *r;
+                } else {
+                    result[i] = FP::FPInfo<FPT>::DefaultNaN();
+                }
+            }
+        }
+    }
+}
+
 template<size_t fsize>
 void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
     using FPT = mcl::unsigned_integer_of_size<fsize>;
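
The new per-lane fallback assumes the hardware result is already in result and fixes it up. Under flush-to-zero (needs_rounding_correction), a lane whose magnitude equals the smallest normal number is recomputed with the soft-float FP::FPMulAdd, because x64 rounds before flushing to zero while AArch64 flushes before rounding, so only such lanes can differ. Under accurate NaN handling (needs_nan_correction), NaN lanes are rewritten to ARM semantics: a quiet NaN addend alongside a 0 * inf product yields the default NaN, otherwise operand NaNs are propagated via FP::ProcessNaNs, with the default NaN as the last resort. Every other lane keeps the hardware result untouched.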
@@ -1301,9 +1334,12 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
     };

     if constexpr (fsize != 16) {
-        if (code.HasHostFeature(HostFeature::FMA | HostFeature::AVX) && ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN)) {
+        const bool fpcr_controlled = inst->GetArg(3).GetU1();
+        const bool needs_rounding_correction = ctx.FPCR(fpcr_controlled).FZ();
+        const bool needs_nan_correction = !(ctx.FPCR(fpcr_controlled).DN() || ctx.HasOptimization(OptimizationFlag::Unsafe_InaccurateNaN));
+
+        if (code.HasHostFeature(HostFeature::FMA) && !needs_rounding_correction && !needs_nan_correction) {
             auto args = ctx.reg_alloc.GetArgumentInfo(inst);
-            const bool fpcr_controlled = args[3].GetImmediateU1();

             const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]);
             const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
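
The two predicates are hoisted to the top of the if constexpr block, so fpcr_controlled is now read straight from the IR instruction rather than from the register-allocated argument. Rounding correction is only required under FPCR.FZ; NaN correction is waived when FPCR.DN is set (ForceToDefaultNaN handles that case) or when the Unsafe_InaccurateNaN optimization permits inaccurate NaNs.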
@@ -1311,6 +1347,7 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
             MaybeStandardFPSCRValue(code, ctx, fpcr_controlled, [&] {
                 FCODE(vfmadd231p)(result, xmm_b, xmm_c);
+                ForceToDefaultNaN<fsize>(code, ctx.FPCR(fpcr_controlled), result);
             });

             ctx.reg_alloc.DefineValue(inst, result);
             return;
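
The fully-inline path is accordingly no longer gated on Unsafe_InaccurateNaN alone: it is taken whenever neither correction is needed, which now includes FPCR.DN configurations, with ForceToDefaultNaN canonicalising any NaN the hardware FMA produced.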
@@ -1319,12 +1356,11 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
         if (code.HasHostFeature(HostFeature::FMA | HostFeature::AVX)) {
             auto args = ctx.reg_alloc.GetArgumentInfo(inst);
-            const bool fpcr_controlled = args[3].GetImmediateU1();

-            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
             const Xbyak::Xmm xmm_a = ctx.reg_alloc.UseXmm(args[0]);
             const Xbyak::Xmm xmm_b = ctx.reg_alloc.UseXmm(args[1]);
             const Xbyak::Xmm xmm_c = ctx.reg_alloc.UseXmm(args[2]);
+            const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm();
             const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();

             SharedLabel end = GenSharedLabel(), fallback = GenSharedLabel();
@@ -1333,19 +1369,32 @@ void EmitFPVectorMulAdd(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
             MaybeStandardFPSCRValue(code, ctx, fpcr_controlled, [&] {
                 code.movaps(result, xmm_a);
                 FCODE(vfmadd231p)(result, xmm_b, xmm_c);

-                code.movaps(tmp, GetNegativeZeroVector<fsize>(code));
-                code.andnps(tmp, result);
-                FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
+                if (needs_rounding_correction && needs_nan_correction) {
+                    code.vandps(tmp, result, GetNonSignMaskVector<fsize>(code));
+                    FCODE(vcmpeq_uqp)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
+                } else if (needs_rounding_correction) {
+                    code.vandps(tmp, result, GetNonSignMaskVector<fsize>(code));
+                    ICODE(vpcmpeq)(tmp, tmp, GetSmallestNormalVector<fsize>(code));
+                } else if (needs_nan_correction) {
+                    FCODE(vcmpunordp)(tmp, result, result);
+                }
                 code.vptest(tmp, tmp);
                 code.jnz(*fallback, code.T_NEAR);
                 code.L(*end);
+
+                ForceToDefaultNaN<fsize>(code, ctx.FPCR(fpcr_controlled), result);
             });

             ctx.deferred_emits.emplace_back([=, &code, &ctx] {
                 code.L(*fallback);
                 code.sub(rsp, 8);
                 ABI_PushCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
-                EmitFourOpFallbackWithoutRegAlloc(code, ctx, result, xmm_a, xmm_b, xmm_c, fallback_fn, fpcr_controlled);
+                if (needs_rounding_correction && needs_nan_correction) {
+                    EmitFourOpFallbackWithoutRegAlloc<LoadPreviousResult::Yes>(code, ctx, result, xmm_a, xmm_b, xmm_c, EmitFPVectorMulAddFallback<FPT, true, true>, fpcr_controlled);
+                } else if (needs_rounding_correction) {
+                    EmitFourOpFallbackWithoutRegAlloc<LoadPreviousResult::Yes>(code, ctx, result, xmm_a, xmm_b, xmm_c, EmitFPVectorMulAddFallback<FPT, true, false>, fpcr_controlled);
+                } else if (needs_nan_correction) {
+                    EmitFourOpFallbackWithoutRegAlloc<LoadPreviousResult::Yes>(code, ctx, result, xmm_a, xmm_b, xmm_c, EmitFPVectorMulAddFallback<FPT, false, true>, fpcr_controlled);
+                }
                 ABI_PopCallerSaveRegistersAndAdjustStackExcept(code, HostLocXmmIdx(result.getIdx()));
                 code.add(rsp, 8);
                 code.jmp(*end, code.T_NEAR);
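
Note the compare chosen for each case: when both corrections are needed, vcmpeq_uqp (equal-or-unordered) flags smallest-normal and NaN lanes in one instruction, since a NaN masked with the non-sign mask is still a NaN and thus compares unordered; with rounding correction alone, the exact bitwise integer compare vpcmpeq suffices; with NaN correction alone, vcmpunordp of the result against itself marks exactly the NaN lanes. vptest then lets well-behaved vectors skip the deferred fallback entirely, and the fallback dispatches to the matching EmitFPVectorMulAddFallback instantiation so it performs only the corrections actually required.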