emit_x64_floating_point: AVX512 implementation of EmitFPMinMaxNumeric

This commit is contained in:
Wunkolo 2021-06-13 17:51:43 -07:00 committed by merry
parent a1192a51d8
commit c6125082ea

View file

@ -444,97 +444,112 @@ static void EmitFPMinMax(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
template<size_t fsize, bool is_max>
static void EmitFPMinMaxNumeric(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst) {
using FPT = mp::unsigned_integer_of_size<fsize>;
constexpr FPT default_nan = FP::FPInfo<FPT>::DefaultNaN();
constexpr u8 mantissa_msb_bit = static_cast<u8>(FP::FPInfo<FPT>::explicit_mantissa_width - 1);
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm op1 = ctx.reg_alloc.UseScratchXmm(args[0]);
const Xbyak::Xmm op2 = ctx.reg_alloc.UseScratchXmm(args[1]); // Result stored here!
Xbyak::Reg tmp = ctx.reg_alloc.ScratchGpr();
tmp.setBit(fsize);
const auto move_to_tmp = [&](const Xbyak::Xmm& xmm) {
if constexpr (fsize == 32) {
code.movd(tmp.cvt32(), xmm);
} else {
code.movq(tmp.cvt64(), xmm);
}
};
Xbyak::Label end, z, nan, op2_is_nan, snan, maybe_both_nan, normal;
DenormalsAreZero<fsize>(code, ctx, {op1, op2});
FCODE(ucomis)(op1, op2);
code.jz(z, code.T_NEAR);
code.L(normal);
if constexpr (is_max) {
FCODE(maxs)(op2, op1);
if (code.HasHostFeature(HostFeature::AVX512_OrthoFloat)) {
// vrangep{s,d} will already correctly handle comparing
// signed zeros and propagating NaNs similar to ARM
constexpr FpRangeSelect range_select = is_max ? FpRangeSelect::Max : FpRangeSelect::Min;
FCODE(vranges)(op2, op1, op2, FpRangeLUT(range_select, FpRangeSign::Preserve));
if (ctx.FPCR().DN()) {
FCODE(vcmps)(k1, op2, op2, Cmp::Unordered_Q);
FCODE(vmovs)(op2 | k1, code.MConst(xword, default_nan));
}
} else {
FCODE(mins)(op2, op1);
}
code.L(end);
Xbyak::Reg tmp = ctx.reg_alloc.ScratchGpr();
tmp.setBit(fsize);
code.SwitchToFarCode();
const auto move_to_tmp = [&](const Xbyak::Xmm& xmm) {
if constexpr (fsize == 32) {
code.movd(tmp.cvt32(), xmm);
} else {
code.movq(tmp.cvt64(), xmm);
}
};
code.L(z);
code.jp(nan);
if constexpr (is_max) {
code.andps(op2, op1);
} else {
code.orps(op2, op1);
}
code.jmp(end);
Xbyak::Label end, z, nan, op2_is_nan, snan, maybe_both_nan, normal;
// NaN requirements:
// op1 op2 result
// SNaN anything op1
// !SNaN SNaN op2
// QNaN !NaN op2
// !NaN QNaN op1
// QNaN QNaN op1
FCODE(ucomis)(op1, op2);
code.jz(z, code.T_NEAR);
code.L(normal);
if constexpr (is_max) {
FCODE(maxs)(op2, op1);
} else {
FCODE(mins)(op2, op1);
}
code.L(end);
code.L(nan);
FCODE(ucomis)(op1, op1);
code.jnp(op2_is_nan);
code.SwitchToFarCode();
// op1 is NaN
move_to_tmp(op1);
code.bt(tmp, mantissa_msb_bit);
code.jc(maybe_both_nan);
if (ctx.FPCR().DN()) {
code.L(snan);
code.movaps(op2, code.MConst(xword, FP::FPInfo<FPT>::DefaultNaN()));
code.L(z);
code.jp(nan);
if constexpr (is_max) {
code.andps(op2, op1);
} else {
code.orps(op2, op1);
}
code.jmp(end);
} else {
code.movaps(op2, op1);
code.L(snan);
code.orps(op2, code.MConst(xword, FP::FPInfo<FPT>::mantissa_msb));
code.jmp(end);
}
code.L(maybe_both_nan);
FCODE(ucomis)(op2, op2);
code.jnp(end, code.T_NEAR);
if (ctx.FPCR().DN()) {
code.jmp(snan);
} else {
// NaN requirements:
// op1 op2 result
// SNaN anything op1
// !SNaN SNaN op2
// QNaN !NaN op2
// !NaN QNaN op1
// QNaN QNaN op1
code.L(nan);
FCODE(ucomis)(op1, op1);
code.jnp(op2_is_nan);
// op1 is NaN
move_to_tmp(op1);
code.bt(tmp, mantissa_msb_bit);
code.jc(maybe_both_nan);
if (ctx.FPCR().DN()) {
code.L(snan);
code.movaps(op2, code.MConst(xword, default_nan));
code.jmp(end);
} else {
code.movaps(op2, op1);
code.L(snan);
code.orps(op2, code.MConst(xword, FP::FPInfo<FPT>::mantissa_msb));
code.jmp(end);
}
code.L(maybe_both_nan);
FCODE(ucomis)(op2, op2);
code.jnp(end, code.T_NEAR);
if (ctx.FPCR().DN()) {
code.jmp(snan);
} else {
move_to_tmp(op2);
code.bt(tmp.cvt64(), mantissa_msb_bit);
code.jnc(snan);
code.movaps(op2, op1);
code.jmp(end);
}
// op2 is NaN
code.L(op2_is_nan);
move_to_tmp(op2);
code.bt(tmp.cvt64(), mantissa_msb_bit);
code.bt(tmp, mantissa_msb_bit);
code.jnc(snan);
code.movaps(op2, op1);
code.jmp(end);
code.SwitchToNearCode();
}
// op2 is NaN
code.L(op2_is_nan);
move_to_tmp(op2);
code.bt(tmp, mantissa_msb_bit);
code.jnc(snan);
code.movaps(op2, op1);
code.jmp(end);
code.SwitchToNearCode();
ctx.reg_alloc.DefineValue(inst, op2);
}