diff --git a/src/backend_x64/emit_x64_vector.cpp b/src/backend_x64/emit_x64_vector.cpp
index 9c1c8ce9..7f5377ef 100644
--- a/src/backend_x64/emit_x64_vector.cpp
+++ b/src/backend_x64/emit_x64_vector.cpp
@@ -4,6 +4,7 @@
  * General Public License version 2 or any later version.
  */
 
+#include "backend_x64/abi.h"
 #include "backend_x64/block_of_code.h"
 #include "backend_x64/emit_x64.h"
 #include "common/assert.h"
@@ -28,6 +29,31 @@ static void EmitVectorOperation(BlockOfCode& code, EmitContext& ctx, IR::Inst* i
     ctx.reg_alloc.DefineValue(inst, xmm_a);
 }
 
+template <typename Lambda>
+static void EmitTwoArgumentFallback(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, Lambda lambda) {
+    const auto fn = +lambda; // Force decay of lambda to function pointer
+    constexpr u32 stack_space = 3 * 16;
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    const Xbyak::Xmm arg1 = ctx.reg_alloc.UseXmm(args[0]);
+    const Xbyak::Xmm arg2 = ctx.reg_alloc.UseXmm(args[1]);
+    ctx.reg_alloc.EndOfAllocScope();
+
+    ctx.reg_alloc.HostCall(nullptr);
+    code.sub(rsp, stack_space + ABI_SHADOW_SPACE);
+    code.lea(code.ABI_PARAM1, ptr[rsp + ABI_SHADOW_SPACE + 0 * 16]);
+    code.lea(code.ABI_PARAM2, ptr[rsp + ABI_SHADOW_SPACE + 1 * 16]);
+    code.lea(code.ABI_PARAM3, ptr[rsp + ABI_SHADOW_SPACE + 2 * 16]);
+
+    code.movaps(xword[code.ABI_PARAM2], arg1);
+    code.movaps(xword[code.ABI_PARAM3], arg2);
+    code.CallFunction(+fn);
+    code.movaps(xmm0, xword[rsp + ABI_SHADOW_SPACE + 0 * 16]);
+
+    code.add(rsp, stack_space + ABI_SHADOW_SPACE);
+
+    ctx.reg_alloc.DefineValue(inst, xmm0);
+}
+
 void EmitX64::EmitVectorGetElement8(EmitContext& ctx, IR::Inst* inst) {
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
     ASSERT(args[1].IsImmediate());
@@ -575,6 +601,52 @@ void EmitX64::EmitVectorLogicalShiftRight64(EmitContext& ctx, IR::Inst* inst) {
     ctx.reg_alloc.DefineValue(inst, result);
 }
 
+void EmitX64::EmitVectorMultiply8(EmitContext& ctx, IR::Inst* inst) {
+    auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+    Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
+    Xbyak::Xmm b = ctx.reg_alloc.UseScratchXmm(args[1]);
+    Xbyak::Xmm tmp_a = ctx.reg_alloc.ScratchXmm();
+    Xbyak::Xmm tmp_b = ctx.reg_alloc.ScratchXmm();
+
+    // TODO: Optimize
+    code.movdqa(tmp_a, a);
+    code.movdqa(tmp_b, b);
+    code.pmullw(a, b);
+    code.psrlw(tmp_a, 8);
+    code.psrlw(tmp_b, 8);
+    code.pmullw(tmp_a, tmp_b);
+    code.pand(a, code.MConst(0x00FF00FF00FF00FF, 0x00FF00FF00FF00FF));
+    code.psllw(tmp_a, 8);
+    code.por(a, tmp_a);
+
+    ctx.reg_alloc.DefineValue(inst, a);
+}
+
+void EmitX64::EmitVectorMultiply16(EmitContext& ctx, IR::Inst* inst) {
+    EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::pmullw);
+}
+
+void EmitX64::EmitVectorMultiply32(EmitContext& ctx, IR::Inst* inst) {
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tSSE41)) {
+        EmitVectorOperation(code, ctx, inst, &Xbyak::CodeGenerator::pmulld);
+        return;
+    }
+
+    EmitTwoArgumentFallback(code, ctx, inst, [](std::array<u32, 4>& result, const std::array<u32, 4>& a, const std::array<u32, 4>& b){
+        for (size_t i = 0; i < 4; ++i) {
+            result[i] = a[i] * b[i];
+        }
+    });
+}
+
+void EmitX64::EmitVectorMultiply64(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](std::array<u64, 2>& result, const std::array<u64, 2>& a, const std::array<u64, 2>& b){
+        for (size_t i = 0; i < 2; ++i) {
+            result[i] = a[i] * b[i];
+        }
+    });
+}
+
 void EmitX64::EmitVectorNarrow16(EmitContext& ctx, IR::Inst* inst) {
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
     Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
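Note (reviewer sketch, not part of the patch): SSE2 has no packed 8-bit multiply, so EmitVectorMultiply8 synthesises one from two pmullw passes: the low-byte products of each 16-bit lane are kept by the 0x00FF mask, the high bytes are shifted down with psrlw, multiplied, shifted back up with psllw, and the halves are merged with por. The plain C++ below models the same per-lane arithmetic; the helper name MultiplyBytes and the std::array representation are assumptions for illustration only.

#include <array>
#include <cstddef>
#include <cstdint>

// Scalar model of the even/odd-byte trick used by EmitVectorMultiply8 (illustrative only).
std::array<std::uint8_t, 16> MultiplyBytes(const std::array<std::uint8_t, 16>& a,
                                           const std::array<std::uint8_t, 16>& b) {
    std::array<std::uint8_t, 16> result{};
    for (std::size_t i = 0; i < 16; i += 2) {
        // Even byte of each 16-bit lane: pmullw(a, b), then pand with 0x00FF keeps the low product byte.
        const std::uint16_t even = static_cast<std::uint16_t>(a[i]) * b[i];
        // Odd byte: psrlw by 8 brings it to the bottom of the lane, pmullw multiplies, psllw by 8 moves it back.
        const std::uint16_t odd = static_cast<std::uint16_t>(a[i + 1]) * b[i + 1];
        result[i] = static_cast<std::uint8_t>(even);      // survives the 0x00FF mask
        result[i + 1] = static_cast<std::uint8_t>(odd);   // merged back in with por
    }
    return result;
}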
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index c50e4b37..5c59c10c 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -913,6 +913,21 @@ U128 IREmitter::VectorLogicalShiftRight(size_t esize, const U128& a, u8 shift_am
     return {};
 }
 
+U128 IREmitter::VectorMultiply(size_t esize, const U128& a, const U128& b) {
+    switch (esize) {
+    case 8:
+        return Inst<U128>(Opcode::VectorMultiply8, a, b);
+    case 16:
+        return Inst<U128>(Opcode::VectorMultiply16, a, b);
+    case 32:
+        return Inst<U128>(Opcode::VectorMultiply32, a, b);
+    case 64:
+        return Inst<U128>(Opcode::VectorMultiply64, a, b);
+    }
+    UNREACHABLE();
+    return {};
+}
+
 U128 IREmitter::VectorNarrow(size_t original_esize, const U128& a) {
     switch (original_esize) {
     case 16:
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index 8de49c40..94ee4184 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -217,6 +217,7 @@ public:
     U128 VectorInterleaveLower(size_t esize, const U128& a, const U128& b);
     U128 VectorLogicalShiftLeft(size_t esize, const U128& a, u8 shift_amount);
     U128 VectorLogicalShiftRight(size_t esize, const U128& a, u8 shift_amount);
+    U128 VectorMultiply(size_t esize, const U128& a, const U128& b);
     U128 VectorNarrow(size_t original_esize, const U128& a);
     U128 VectorNot(const U128& a);
     U128 VectorOr(const U128& a, const U128& b);
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index 807f315f..467d7420 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -232,6 +232,10 @@ OPCODE(VectorLogicalShiftRight8, T::U128, T::U128, T::U8
 OPCODE(VectorLogicalShiftRight16,  T::U128, T::U128, T::U8   )
 OPCODE(VectorLogicalShiftRight32,  T::U128, T::U128, T::U8   )
 OPCODE(VectorLogicalShiftRight64,  T::U128, T::U128, T::U8   )
+OPCODE(VectorMultiply8,            T::U128, T::U128, T::U128 )
+OPCODE(VectorMultiply16,           T::U128, T::U128, T::U128 )
+OPCODE(VectorMultiply32,           T::U128, T::U128, T::U128 )
+OPCODE(VectorMultiply64,           T::U128, T::U128, T::U128 )
 OPCODE(VectorNarrow16,             T::U128, T::U128          )
 OPCODE(VectorNarrow32,             T::U128, T::U128          )
 OPCODE(VectorNarrow64,             T::U128, T::U128          )
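Note (reviewer sketch, not part of the patch): IREmitter::VectorMultiply is the size-dispatching entry point a frontend is expected to call; it maps an esize of 8, 16, 32 or 64 onto the corresponding new opcode and hits UNREACHABLE() for anything else. A minimal caller using only what this patch declares; the wrapper name EmitVectorMul is made up.

#include <cstddef>

#include "frontend/ir/ir_emitter.h"

// Emit a lane-wise multiply of two 128-bit IR values at the decoded element size.
// Only ir.VectorMultiply comes from this patch; this wrapper is hypothetical.
Dynarmic::IR::U128 EmitVectorMul(Dynarmic::IR::IREmitter& ir, std::size_t esize,
                                 const Dynarmic::IR::U128& a, const Dynarmic::IR::U128& b) {
    // esize must be 8, 16, 32 or 64; other values assert via UNREACHABLE() in VectorMultiply.
    return ir.VectorMultiply(esize, a, b);
}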