diff --git a/src/backend_x64/emit_x64_vector_floating_point.cpp b/src/backend_x64/emit_x64_vector_floating_point.cpp
index c312c637..53a8ac15 100644
--- a/src/backend_x64/emit_x64_vector_floating_point.cpp
+++ b/src/backend_x64/emit_x64_vector_floating_point.cpp
@@ -5,12 +5,14 @@
  */
 
 #include <array>
+#include <tuple>
 #include <type_traits>
 #include <utility>
 
 #include "backend_x64/abi.h"
 #include "backend_x64/block_of_code.h"
 #include "backend_x64/emit_x64.h"
+#include "common/assert.h"
 #include "common/bit_util.h"
 #include "common/fp/fpcr.h"
 #include "common/fp/info.h"
@@ -26,7 +28,7 @@ namespace Dynarmic::BackendX64 {
 using namespace Xbyak::util;
 namespace mp = Common::mp;
 
-template <size_t fsize, typename T>
+template<size_t fsize, typename T>
 static T ChooseOnFsize(T f32, T f64) {
     static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
 
@@ -39,12 +41,12 @@ static T ChooseOnFsize(T f32, T f64) {
 
 #define FCODE(NAME) (code.*ChooseOnFsize<fsize>(&Xbyak::CodeGenerator::NAME##s, &Xbyak::CodeGenerator::NAME##d))
 
-template <typename FPT, auto IndexFunction, size_t... argi>
+template<typename FPT, template<typename> class Indexer, size_t... argi>
 static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) {
     auto result = [](std::array<VectorArray<FPT>, sizeof...(argi) + 1>& values) {
         VectorArray<FPT>& result = values[0];
         for (size_t elementi = 0; elementi < result.size(); ++elementi) {
-            const auto current_values = IndexFunction(elementi, values[argi + 1]...);
+            const auto current_values = Indexer<FPT>{}(elementi, values[argi + 1]...);
             if (auto r = FP::ProcessNaNs(std::get<argi>(current_values)...)) {
                 result[elementi] = *r;
             } else if (FP::IsNaN(result[elementi])) {
@@ -55,7 +57,7 @@ static auto GetRuntimeNaNFunction(std::index_sequence<argi...>) {
     return static_cast<mp::equivalent_function_type_t<decltype(result)>*>(result);
 }
 
-template <size_t fsize, size_t nargs, auto IndexFunction>
+template<size_t fsize, size_t nargs, template<typename> class Indexer>
 static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::Xmm, nargs + 1> xmms, const Xbyak::Xmm& nan_mask) {
     static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
 
@@ -90,7 +92,7 @@ static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::X
     code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
 
     using FPT = mp::unsigned_integer_of_size<fsize>;
-    code.CallFunction(GetRuntimeNaNFunction<FPT, IndexFunction>(std::make_index_sequence<nargs>{}));
+    code.CallFunction(GetRuntimeNaNFunction<FPT, Indexer>(std::make_index_sequence<nargs>{}));
 
     code.movaps(result, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
     code.add(rsp, stack_space + ABI_SHADOW_SPACE);
@@ -100,41 +102,59 @@ static void HandleNaNs(BlockOfCode& code, EmitContext& ctx, std::array<Xbyak::X
     code.SwitchToNearCode();
 }
 
-static std::tuple<u32, u32> DefaultIndexFunction32(size_t i, const VectorArray<u32>& a, const VectorArray<u32>& b) {
-    return std::make_tuple(a[i], b[i]);
-}
-
-static std::tuple<u64, u64> DefaultIndexFunction64(size_t i, const VectorArray<u64>& a, const VectorArray<u64>& b) {
-    return std::make_tuple(a[i], b[i]);
-}
-
-static std::tuple<u32, u32> PairedIndexFunction32(size_t i, const VectorArray<u32>& a, const VectorArray<u32>& b) {
-    if (i < 2) {
-        return std::make_tuple(a[2 * i], a[2 * i + 1]);
-    }
-    return std::make_tuple(b[2 * (i - 2)], b[2 * (i - 2) + 1]);
-}
-
-static std::tuple<u64, u64> PairedIndexFunction64(size_t i, const VectorArray<u64>& a, const VectorArray<u64>& b) {
-    return i == 0 ? std::make_tuple(a[0], a[1]) : std::make_tuple(b[0], b[1]);
-}
-
-static std::tuple<u32, u32> PairedLowerIndexFunction32(size_t i, const VectorArray<u32>& a, const VectorArray<u32>& b) {
-    switch (i) {
-    case 0:
-        return std::make_tuple(a[0], a[1]);
-    case 1:
-        return std::make_tuple(b[0], b[1]);
-    default:
-        return std::make_tuple(u32(0), u32(0));
-    }
-}
-
-static std::tuple<u64, u64> PairedLowerIndexFunction64(size_t i, const VectorArray<u64>& a, const VectorArray<u64>& b) {
-    return i == 0 ? std::make_tuple(a[0], b[0]) : std::make_tuple(u64(0), u64(0));
-}
+template<typename T>
+struct DefaultIndexer {
+    std::tuple<T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
+        return std::make_tuple(a[i], b[i]);
+    }
+
+    std::tuple<T, T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b, const VectorArray<T>& c) {
+        return std::make_tuple(a[i], b[i], c[i]);
+    }
+};
+
+template<typename T>
+struct PairedIndexer {
+    std::tuple<T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
+        constexpr size_t halfway = std::tuple_size_v<VectorArray<T>> / 2;
+        const size_t which_array = i / halfway;
+        i %= halfway;
+        switch (which_array) {
+        case 0:
+            return std::make_tuple(a[2 * i], a[2 * i + 1]);
+        case 1:
+            return std::make_tuple(b[2 * i], b[2 * i + 1]);
+        }
+        UNREACHABLE();
+        return {};
+    }
+};
+
+template<typename T>
+struct PairedLowerIndexer {
+    std::tuple<T, T> operator()(size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
+        constexpr size_t array_size = std::tuple_size_v<VectorArray<T>>;
+        if constexpr (array_size == 4) {
+            switch (i) {
+            case 0:
+                return std::make_tuple(a[0], a[1]);
+            case 1:
+                return std::make_tuple(b[0], b[1]);
+            default:
+                return std::make_tuple(0, 0);
+            }
+        } else if constexpr (array_size == 2) {
+            if (i == 0) {
+                return std::make_tuple(a[0], b[0]);
+            }
+            return std::make_tuple(0, 0);
+        }
+        UNREACHABLE();
+        return {};
+    }
+};
 
-template <size_t fsize, auto IndexFunction, typename Function>
+template<size_t fsize, template<typename> class Indexer, typename Function>
 static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Function fn) {
     static_assert(fsize == 32 || fsize == 64, "fsize must be either 32 or 64");
 
@@ -182,7 +202,7 @@ static void EmitThreeOpVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::
     }
 
     FCODE(cmpunordp)(nan_mask, result);
-    HandleNaNs<fsize, 2, IndexFunction>(code, ctx, {result, xmm_a, xmm_b}, nan_mask);
+    HandleNaNs<fsize, 2, Indexer>(code, ctx, {result, xmm_a, xmm_b}, nan_mask);
 
     ctx.reg_alloc.DefineValue(inst, result);
 }
@@ -212,7 +232,7 @@ inline void EmitOneArgumentFallback(BlockOfCode& code, EmitContext& ctx, IR::Ins
     ctx.reg_alloc.DefineValue(inst, xmm0);
 }
 
-template <typename Lambda>
+template<typename Lambda>
 inline void EmitTwoArgumentFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
     const auto fn = static_cast<mp::equivalent_function_type_t<Lambda>*>(lambda);
 
@@ -290,19 +310,19 @@ void EmitX64::EmitFPVectorAbs64(EmitContext& ctx, IR::Inst* inst) {
 }
 
 void EmitX64::EmitFPVectorAdd32(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::addps);
+    EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::addps);
 }
 
 void EmitX64::EmitFPVectorAdd64(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::addpd);
+    EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::addpd);
 }
 
 void EmitX64::EmitFPVectorDiv32(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::divps);
+    EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::divps);
 }
 
 void EmitX64::EmitFPVectorDiv64(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::divpd);
+    EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::divpd);
 }
 
 void EmitX64::EmitFPVectorEqual32(EmitContext& ctx, IR::Inst* inst) {
@@ -366,23 +386,23 @@ void EmitX64::EmitFPVectorGreaterEqual64(EmitContext& ctx, IR::Inst* inst) {
 }
 
 void EmitX64::EmitFPVectorMul32(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::mulps);
+    EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulps);
 }
 
 void EmitX64::EmitFPVectorMul64(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
+    EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::mulpd);
 }
 
 void EmitX64::EmitFPVectorPairedAdd32(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<32, PairedIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::haddps);
+    EmitThreeOpVectorOperation<32, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddps);
 }
 
 void EmitX64::EmitFPVectorPairedAdd64(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<64, PairedIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::haddpd);
+    EmitThreeOpVectorOperation<64, PairedIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::haddpd);
 }
 
 void EmitX64::EmitFPVectorPairedAddLower32(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<32, PairedLowerIndexFunction32>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) {
+    EmitThreeOpVectorOperation<32, PairedLowerIndexer>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) {
         const Xbyak::Xmm zero = ctx.reg_alloc.ScratchXmm();
         code.xorps(zero, zero);
         code.punpcklqdq(result, xmm_b);
@@ -391,7 +411,7 @@ void EmitX64::EmitFPVectorPairedAddLower32(EmitContext& ctx, IR::Inst* inst) {
 }
 
 void EmitX64::EmitFPVectorPairedAddLower64(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<64, PairedLowerIndexFunction64>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) {
+    EmitThreeOpVectorOperation<64, PairedLowerIndexer>(code, ctx, inst, [&](Xbyak::Xmm result, Xbyak::Xmm xmm_b) {
         const Xbyak::Xmm zero = ctx.reg_alloc.ScratchXmm();
         code.xorps(zero, zero);
         code.punpcklqdq(result, xmm_b);
@@ -484,11 +504,11 @@ void EmitX64::EmitFPVectorS64ToDouble(EmitContext& ctx, IR::Inst* inst) {
 }
 
 void EmitX64::EmitFPVectorSub32(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<32, DefaultIndexFunction32>(code, ctx, inst, &Xbyak::CodeGenerator::subps);
+    EmitThreeOpVectorOperation<32, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::subps);
 }
 
 void EmitX64::EmitFPVectorSub64(EmitContext& ctx, IR::Inst* inst) {
-    EmitThreeOpVectorOperation<64, DefaultIndexFunction64>(code, ctx, inst, &Xbyak::CodeGenerator::subpd);
+    EmitThreeOpVectorOperation<64, DefaultIndexer>(code, ctx, inst, &Xbyak::CodeGenerator::subpd);
 }
 
 void EmitX64::EmitFPVectorU32ToSingle(EmitContext& ctx, IR::Inst* inst) {
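
Reviewer note, not part of the patch: the stand-alone sketch below only illustrates the element-pairing semantics that the new Indexer structs encode and that the runtime NaN fallback iterates over. It assumes a simplified stand-in for dynarmic's VectorArray alias (a 16-byte std::array) and copies just the two-operand DefaultIndexer/PairedIndexer selection logic, dropping the UNREACHABLE plumbing; it is a sketch under those assumptions, not the emitter's actual code path.

// indexer_sketch.cpp -- hypothetical stand-alone illustration, not part of the patch.
#include <array>
#include <cstddef>
#include <cstdint>
#include <cstdio>
#include <tuple>

// Simplified stand-in for dynarmic's VectorArray alias: one 128-bit vector's worth of lanes.
template<typename T>
using VectorArray = std::array<T, 16 / sizeof(T)>;

// Per-lane selection: result lane i is computed from lane i of each operand.
template<typename T>
struct DefaultIndexer {
    std::tuple<T, T> operator()(std::size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
        return std::make_tuple(a[i], b[i]);
    }
};

// Paired selection (FADDP-style): the low half of the result reads adjacent pairs of `a`,
// the high half reads adjacent pairs of `b`.
template<typename T>
struct PairedIndexer {
    std::tuple<T, T> operator()(std::size_t i, const VectorArray<T>& a, const VectorArray<T>& b) {
        constexpr std::size_t halfway = std::tuple_size_v<VectorArray<T>> / 2;
        const std::size_t which_array = i / halfway;
        i %= halfway;
        return which_array == 0 ? std::make_tuple(a[2 * i], a[2 * i + 1])
                                : std::make_tuple(b[2 * i], b[2 * i + 1]);
    }
};

int main() {
    const VectorArray<std::uint32_t> a{0, 1, 2, 3};
    const VectorArray<std::uint32_t> b{4, 5, 6, 7};

    for (std::size_t i = 0; i < a.size(); ++i) {
        const auto [x, y] = PairedIndexer<std::uint32_t>{}(i, a, b);
        std::printf("paired lane %zu reads (%u, %u)\n", i, x, y);
    }
}

Compiled as C++17, the loop prints the operand pairs (0, 1), (2, 3), (4, 5), (6, 7): the low result lanes are fed by adjacent lanes of the first operand and the high lanes by adjacent lanes of the second, which is the pairing the haddps-based EmitFPVectorPairedAdd32 relies on when the NaN fallback recomputes a lane.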