diff --git a/src/backend_x64/emit_x64_floating_point.cpp b/src/backend_x64/emit_x64_floating_point.cpp index cbd85ccc..29dd100b 100644 --- a/src/backend_x64/emit_x64_floating_point.cpp +++ b/src/backend_x64/emit_x64_floating_point.cpp @@ -207,6 +207,152 @@ Xbyak::Label ProcessNaN(BlockOfCode& code, Xbyak::Xmm a) { return end; } +// This is necessary because x86 and ARM differ in they way they return NaNs from floating point operations +// +// ARM behaviour: +// op1 op2 result +// SNaN SNaN/QNaN op1 +// QNaN SNaN op2 +// QNaN QNaN op1 +// SNaN/QNaN other op1 +// other SNaN/QNaN op2 +// +// x86 behaviour: +// op1 op2 result +// SNaN/QNaN SNaN/QNaN op1 +// SNaN/QNaN other op1 +// other SNaN/QNaN op2 +// +// With ARM: SNaNs take priority. With x86: it doesn't matter. +// +// From the above we can see what differs between the architectures is +// the case when op1 == QNaN and op2 == SNaN. +// +// We assume that registers op1 and op2 are read-only. This function also trashes xmm0. +// We allow for the case where op1 and result are the same register. We do not read from op1 once result is written to. +template +void EmitPostProcessNaNs(BlockOfCode& code, Xbyak::Xmm result, Xbyak::Xmm op1, Xbyak::Xmm op2, Xbyak::Reg64 tmp, Xbyak::Label end) { + using FPT = mp::unsigned_integer_of_size; + constexpr FPT exponent_mask = FP::FPInfo::exponent_mask; + constexpr FPT mantissa_msb = FP::FPInfo::mantissa_msb; + constexpr u8 mantissa_msb_bit = static_cast(FP::FPInfo::explicit_mantissa_width - 1); + + // At this point we know that at least one of op1 and op2 is a NaN. + // Thus in op1 ^ op2 at least one of the two would have all 1 bits in the exponent. + // Keeping in mind xor is commutative, there are only four cases: + // SNaN ^ SNaN/Inf -> exponent == 0, mantissa_msb == 0 + // QNaN ^ QNaN -> exponent == 0, mantissa_msb == 0 + // QNaN ^ SNaN/Inf -> exponent == 0, mantissa_msb == 1 + // SNaN/QNaN ^ Otherwise -> exponent != 0, mantissa_msb == ? + // + // We're only really interested in op1 == QNaN and op2 == SNaN, + // so we filter out everything else. + // + // We do it this way instead of checking that op1 is QNaN because + // op1 == QNaN && op2 == QNaN is the most common case. With this method + // that case would only require one branch. + + if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) { + code.vxorps(xmm0, op1, op2); + } else { + code.movaps(xmm0, op1); + code.xorps(xmm0, op2); + } + + constexpr size_t shift = fsize == 32 ? 0 : 48; + if constexpr (fsize == 32) { + code.movd(tmp.cvt32(), xmm0); + } else { + // We do this to avoid requiring 64-bit immediates + code.pextrw(tmp.cvt32(), xmm0, shift / 16); + } + code.and_(tmp.cvt32(), static_cast((exponent_mask | mantissa_msb) >> shift)); + code.cmp(tmp.cvt32(), static_cast(mantissa_msb >> shift)); + code.jne(end, code.T_NEAR); + + // If we're here there are four cases left: + // op1 == SNaN && op2 == QNaN + // op1 == Inf && op2 == QNaN + // op1 == QNaN && op2 == SNaN <<< The problematic case + // op1 == QNaN && op2 == Inf + + if constexpr (fsize == 32) { + code.movd(tmp.cvt32(), op2); + code.shl(tmp.cvt32(), 32 - mantissa_msb_bit); + } else { + code.movq(tmp, op2); + code.shl(tmp, 64 - mantissa_msb_bit); + } + // If op2 is a SNaN, CF = 0 and ZF = 0. + code.jna(end, code.T_NEAR); + + // Silence the SNaN as required by spec. + if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) { + code.vorps(result, op2, code.MConst(xword, mantissa_msb)); + } else { + code.movaps(result, op2); + code.orps(result, code.MConst(xword, mantissa_msb)); + } + code.jmp(end, code.T_NEAR); +} + +// Do full NaN processing. +template +void EmitProcessNaNs(BlockOfCode& code, Xbyak::Xmm result, Xbyak::Xmm op1, Xbyak::Xmm op2, Xbyak::Reg64 tmp, Xbyak::Label end) { + using FPT = mp::unsigned_integer_of_size; + constexpr FPT exponent_mask = FP::FPInfo::exponent_mask; + constexpr FPT mantissa_msb = FP::FPInfo::mantissa_msb; + constexpr u8 mantissa_msb_bit = static_cast(FP::FPInfo::explicit_mantissa_width - 1); + + Xbyak::Label return_sum; + + if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) { + code.vxorps(xmm0, op1, op2); + } else { + code.movaps(xmm0, op1); + code.xorps(xmm0, op2); + } + + constexpr size_t shift = fsize == 32 ? 0 : 48; + if constexpr (fsize == 32) { + code.movd(tmp.cvt32(), xmm0); + } else { + code.pextrw(tmp.cvt32(), xmm0, shift / 16); + } + code.and_(tmp.cvt32(), static_cast((exponent_mask | mantissa_msb) >> shift)); + code.cmp(tmp.cvt32(), static_cast(mantissa_msb >> shift)); + code.jne(return_sum); + + if constexpr (fsize == 32) { + code.movd(tmp.cvt32(), op2); + code.shl(tmp.cvt32(), 32 - mantissa_msb_bit); + } else { + code.movq(tmp, op2); + code.shl(tmp, 64 - mantissa_msb_bit); + } + code.jna(return_sum); + + if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) { + code.vorps(result, op2, code.MConst(xword, mantissa_msb)); + } else { + code.movaps(result, op2); + code.orps(result, code.MConst(xword, mantissa_msb)); + } + code.jmp(end, code.T_NEAR); + + // x86 behaviour is reliable in this case + code.L(return_sum); + if (result == op1) { + FCODE(adds)(result, op2); + } else if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX)) { + FCODE(vadds)(result, op1, op2); + } else { + code.movaps(result, op1); + FCODE(adds)(result, op2); + } + code.jmp(end, code.T_NEAR); +} + template void FPTwoOp(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) { auto args = ctx.reg_alloc.GetArgumentInfo(inst); @@ -274,8 +420,58 @@ void FPThreeOp(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, [[maybe_unus } template -void FPThreeOp(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn, CallDenormalsAreZero call_denormals_are_zero = CallDenormalsAreZero::No) { - FPThreeOp(code, ctx, inst, nullptr, fn, call_denormals_are_zero); +void FPThreeOp(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) { + using FPT = mp::unsigned_integer_of_size; + + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + if (ctx.FPSCR_DN() || !ctx.AccurateNaN()) { + const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm operand = ctx.reg_alloc.UseScratchXmm(args[1]); + + if constexpr (std::is_member_function_pointer_v) { + (code.*fn)(result, operand); + } else { + fn(result, operand); + } + + if (ctx.AccurateNaN()) { + ForceToDefaultNaN(code, result); + } + + ctx.reg_alloc.DefineValue(inst, result); + return; + } + + const Xbyak::Xmm op1 = ctx.reg_alloc.UseXmm(args[0]); + const Xbyak::Xmm op2 = ctx.reg_alloc.UseXmm(args[1]); + const Xbyak::Xmm result = ctx.reg_alloc.ScratchXmm(); + const Xbyak::Reg64 tmp = ctx.reg_alloc.ScratchGpr(); + + Xbyak::Label end, nan, op_are_nans; + + code.movaps(result, op1); + if constexpr (std::is_member_function_pointer_v) { + (code.*fn)(result, op2); + } else { + fn(result, op2); + } + FCODE(ucomis)(result, result); + code.jp(nan, code.T_NEAR); + code.L(end); + + code.SwitchToFarCode(); + code.L(nan); + FCODE(ucomis)(op1, op2); + code.jp(op_are_nans); + // Here we must return a positive NaN, because the indefinite value on x86 is a negative NaN! + code.movaps(result, code.MConst(xword, FP::FPInfo::DefaultNaN())); + code.jmp(end, code.T_NEAR); + code.L(op_are_nans); + EmitPostProcessNaNs(code, result, op1, op2, tmp, end); + code.SwitchToNearCode(); + + ctx.reg_alloc.DefineValue(inst, result); } } // anonymous namespace @@ -334,53 +530,42 @@ void EmitX64::EmitFPDiv64(EmitContext& ctx, IR::Inst* inst) { template static void EmitFPMax(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { - if (ctx.FPSCR_DN()) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); + auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); - const Xbyak::Xmm operand = ctx.reg_alloc.UseScratchXmm(args[1]); + const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm operand = ctx.reg_alloc.UseScratchXmm(args[1]); + const Xbyak::Reg64 gpr_scratch = ctx.reg_alloc.ScratchGpr(); - if (ctx.FPSCR_FTZ()) { - const Xbyak::Reg64 gpr_scratch = ctx.reg_alloc.ScratchGpr(); - DenormalsAreZero(code, result, gpr_scratch); - DenormalsAreZero(code, operand, gpr_scratch); - } - - Xbyak::Label equal, end, nan; - - FCODE(ucomis)(result, operand); - code.jz(equal, code.T_NEAR); - FCODE(maxs)(result, operand); - code.L(end); - - code.SwitchToFarCode(); - - code.L(equal); - code.jp(nan); - code.andps(result, operand); - code.jmp(end); - - code.L(nan); - code.movaps(result, code.MConst(xword, fsize == 32 ? f32_nan : f64_nan)); - code.jmp(end); - - code.SwitchToNearCode(); - - ctx.reg_alloc.DefineValue(inst, result); - - return; + if (ctx.FPSCR_FTZ()) { + DenormalsAreZero(code, result, gpr_scratch); + DenormalsAreZero(code, operand, gpr_scratch); } - FPThreeOp(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm operand){ - Xbyak::Label equal, end; - FCODE(ucomis)(result, operand); - code.jz(equal); - FCODE(maxs)(result, operand); + Xbyak::Label equal, end, nan; + + FCODE(ucomis)(result, operand); + code.jz(equal, code.T_NEAR); + FCODE(maxs)(result, operand); + code.L(end); + + code.SwitchToFarCode(); + + code.L(equal); + code.jp(nan); + code.andps(result, operand); + code.jmp(end); + + code.L(nan); + if (ctx.FPSCR_DN() || !ctx.AccurateNaN()) { + code.movaps(result, code.MConst(xword, fsize == 32 ? f32_nan : f64_nan)); code.jmp(end); - code.L(equal); - code.andps(result, operand); - code.L(end); - }, CallDenormalsAreZero::Yes); + } else { + EmitProcessNaNs(code, result, result, operand, gpr_scratch, end); + } + + code.SwitchToNearCode(); + + ctx.reg_alloc.DefineValue(inst, result); } void EmitX64::EmitFPMax32(EmitContext& ctx, IR::Inst* inst) { @@ -461,53 +646,42 @@ void EmitX64::EmitFPMaxNumeric64(EmitContext& ctx, IR::Inst* inst) { template static void EmitFPMin(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) { - if (ctx.FPSCR_DN()) { - auto args = ctx.reg_alloc.GetArgumentInfo(inst); + auto args = ctx.reg_alloc.GetArgumentInfo(inst); - const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); - const Xbyak::Xmm operand = ctx.reg_alloc.UseScratchXmm(args[1]); + const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]); + const Xbyak::Xmm operand = ctx.reg_alloc.UseScratchXmm(args[1]); + const Xbyak::Reg64 gpr_scratch = ctx.reg_alloc.ScratchGpr(); - if (ctx.FPSCR_FTZ()) { - const Xbyak::Reg64 gpr_scratch = ctx.reg_alloc.ScratchGpr(); - DenormalsAreZero(code, result, gpr_scratch); - DenormalsAreZero(code, operand, gpr_scratch); - } - - Xbyak::Label equal, end, nan; - - FCODE(ucomis)(result, operand); - code.jz(equal, code.T_NEAR); - FCODE(mins)(result, operand); - code.L(end); - - code.SwitchToFarCode(); - - code.L(equal); - code.jp(nan); - code.orps(result, operand); - code.jmp(end); - - code.L(nan); - code.movaps(result, code.MConst(xword, fsize == 32 ? f32_nan : f64_nan)); - code.jmp(end); - - code.SwitchToNearCode(); - - ctx.reg_alloc.DefineValue(inst, result); - - return; + if (ctx.FPSCR_FTZ()) { + DenormalsAreZero(code, result, gpr_scratch); + DenormalsAreZero(code, operand, gpr_scratch); } - FPThreeOp(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm operand){ - Xbyak::Label equal, end; - FCODE(ucomis)(result, operand); - code.jz(equal); - FCODE(mins)(result, operand); + Xbyak::Label equal, end, nan; + + FCODE(ucomis)(result, operand); + code.jz(equal, code.T_NEAR); + FCODE(mins)(result, operand); + code.L(end); + + code.SwitchToFarCode(); + + code.L(equal); + code.jp(nan); + code.orps(result, operand); + code.jmp(end); + + code.L(nan); + if (ctx.FPSCR_DN()) { + code.movaps(result, code.MConst(xword, fsize == 32 ? f32_nan : f64_nan)); code.jmp(end); - code.L(equal); - code.orps(result, operand); - code.L(end); - }, CallDenormalsAreZero::Yes); + } else { + EmitProcessNaNs(code, result, result, operand, gpr_scratch, end); + } + + code.SwitchToNearCode(); + + ctx.reg_alloc.DefineValue(inst, result); } void EmitX64::EmitFPMin32(EmitContext& ctx, IR::Inst* inst) {