From 7fdd8b01979ba32ca35c285da81419f2780e099f Mon Sep 17 00:00:00 2001 From: Lioncash Date: Thu, 26 Jul 2018 12:24:47 -0400 Subject: [PATCH] A64: Implement PMULL{2} --- src/backend_x64/emit_x64_vector.cpp | 36 ++++++++++++++++--- src/frontend/A64/decoder/a64.inc | 2 +- .../translate/impl/simd_three_different.cpp | 16 +++++++++ src/frontend/ir/ir_emitter.cpp | 12 +++++++ src/frontend/ir/ir_emitter.h | 1 + src/frontend/ir/opcodes.inc | 2 ++ 6 files changed, 64 insertions(+), 5 deletions(-) diff --git a/src/backend_x64/emit_x64_vector.cpp b/src/backend_x64/emit_x64_vector.cpp index 23c17d23..2b0d4252 100644 --- a/src/backend_x64/emit_x64_vector.cpp +++ b/src/backend_x64/emit_x64_vector.cpp @@ -1866,12 +1866,12 @@ void EmitX64::EmitVectorPairedAddUnsignedWiden32(EmitContext& ctx, IR::Inst* ins ctx.reg_alloc.DefineValue(inst, a); } -template -static T PolynomialMultiply(T lhs, T rhs) { +template +static D PolynomialMultiply(T lhs, T rhs) { constexpr size_t bit_size = Common::BitSize(); const std::bitset operand(lhs); - T res = 0; + D res = 0; for (size_t i = 0; i < bit_size; i++) { if (operand[i]) { res ^= rhs << i; @@ -1883,7 +1883,35 @@ static T PolynomialMultiply(T lhs, T rhs) { void EmitX64::EmitVectorPolynomialMultiply8(EmitContext& ctx, IR::Inst* inst) { EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray& result, const VectorArray& a, const VectorArray& b) { - std::transform(a.begin(), a.end(), b.begin(), result.begin(), PolynomialMultiply); + std::transform(a.begin(), a.end(), b.begin(), result.begin(), PolynomialMultiply); + }); +} + +void EmitX64::EmitVectorPolynomialMultiplyLong8(EmitContext& ctx, IR::Inst* inst) { + EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray& result, const VectorArray& a, const VectorArray& b) { + for (size_t i = 0; i < result.size(); i++) { + result[i] = PolynomialMultiply(a[i], b[i]); + } + }); +} + +void EmitX64::EmitVectorPolynomialMultiplyLong64(EmitContext& ctx, IR::Inst* inst) { + EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray& result, const VectorArray& a, const VectorArray& b) { + const auto handle_high_bits = [](u64 lhs, u64 rhs) { + constexpr size_t bit_size = Common::BitSize(); + u64 result = 0; + + for (size_t i = 1; i < bit_size; i++) { + if (Common::Bit(i, lhs)) { + result ^= rhs >> (bit_size - i); + } + } + + return result; + }; + + result[0] = PolynomialMultiply(a[0], b[0]); + result[1] = handle_high_bits(a[0], b[0]); }); } diff --git a/src/frontend/A64/decoder/a64.inc b/src/frontend/A64/decoder/a64.inc index b07bc723..ac2b50f6 100644 --- a/src/frontend/A64/decoder/a64.inc +++ b/src/frontend/A64/decoder/a64.inc @@ -687,7 +687,7 @@ INST(SABDL, "SABDL, SABDL2", "0Q001 INST(SMLAL_vec, "SMLAL, SMLAL2 (vector)", "0Q001110zz1mmmmm100000nnnnnddddd") INST(SMLSL_vec, "SMLSL, SMLSL2 (vector)", "0Q001110zz1mmmmm101000nnnnnddddd") INST(SMULL_vec, "SMULL, SMULL2 (vector)", "0Q001110zz1mmmmm110000nnnnnddddd") -//INST(PMULL, "PMULL, PMULL2", "0Q001110zz1mmmmm111000nnnnnddddd") +INST(PMULL, "PMULL, PMULL2", "0Q001110zz1mmmmm111000nnnnnddddd") INST(UADDL, "UADDL, UADDL2", "0Q101110zz1mmmmm000000nnnnnddddd") INST(UADDW, "UADDW, UADDW2", "0Q101110zz1mmmmm000100nnnnnddddd") INST(USUBL, "USUBL, USUBL2", "0Q101110zz1mmmmm001000nnnnnddddd") diff --git a/src/frontend/A64/translate/impl/simd_three_different.cpp b/src/frontend/A64/translate/impl/simd_three_different.cpp index 4df39d58..01f8dffa 100644 --- a/src/frontend/A64/translate/impl/simd_three_different.cpp +++ b/src/frontend/A64/translate/impl/simd_three_different.cpp @@ -161,6 +161,22 @@ bool WideOperation(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn, Ve } } // Anonymous namespace +bool TranslatorVisitor::PMULL(bool Q, Imm<2> size, Vec Vm, Vec Vn, Vec Vd) { + if (size == 0b01 || size == 0b10) { + return ReservedValue(); + } + + const size_t esize = 8 << size.ZeroExtend(); + const size_t datasize = 64; + + const IR::U128 operand1 = Vpart(datasize, Vn, Q); + const IR::U128 operand2 = Vpart(datasize, Vm, Q); + const IR::U128 result = ir.VectorPolynomialMultiplyLong(esize, operand1, operand2); + + V(128, Vd, result); + return true; +} + bool TranslatorVisitor::SABAL(bool Q, Imm<2> size, Vec Vm, Vec Vn, Vec Vd) { return AbsoluteDifferenceLong(*this, Q, size, Vm, Vn, Vd, AbsoluteDifferenceBehavior::Accumulate, Signedness::Signed); } diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp index 120775a6..11b9ae64 100644 --- a/src/frontend/ir/ir_emitter.cpp +++ b/src/frontend/ir/ir_emitter.cpp @@ -1198,6 +1198,18 @@ U128 IREmitter::VectorPolynomialMultiply(const U128& a, const U128& b) { return Inst(Opcode::VectorPolynomialMultiply8, a, b); } +U128 IREmitter::VectorPolynomialMultiplyLong(size_t esize, const U128& a, const U128& b) { + switch (esize) { + case 8: + return Inst(Opcode::VectorPolynomialMultiplyLong8, a, b); + case 64: + return Inst(Opcode::VectorPolynomialMultiplyLong64, a, b); + default: + UNREACHABLE(); + return {}; + } +} + U128 IREmitter::VectorPopulationCount(const U128& a) { return Inst(Opcode::VectorPopulationCount, a); } diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h index 5e66712a..27480fa0 100644 --- a/src/frontend/ir/ir_emitter.h +++ b/src/frontend/ir/ir_emitter.h @@ -239,6 +239,7 @@ public: U128 VectorPairedAddSignedWiden(size_t original_esize, const U128& a); U128 VectorPairedAddUnsignedWiden(size_t original_esize, const U128& a); U128 VectorPolynomialMultiply(const U128& a, const U128& b); + U128 VectorPolynomialMultiplyLong(size_t esize, const U128& a, const U128& b); U128 VectorPopulationCount(const U128& a); U128 VectorReverseBits(const U128& a); U128 VectorRotateLeft(size_t esize, const U128& a, u8 amount); diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc index 2e2534b2..c41cf7bd 100644 --- a/src/frontend/ir/opcodes.inc +++ b/src/frontend/ir/opcodes.inc @@ -331,6 +331,8 @@ OPCODE(VectorPairedAdd16, T::U128, T::U128, OPCODE(VectorPairedAdd32, T::U128, T::U128, T::U128 ) OPCODE(VectorPairedAdd64, T::U128, T::U128, T::U128 ) OPCODE(VectorPolynomialMultiply8, T::U128, T::U128, T::U128 ) +OPCODE(VectorPolynomialMultiplyLong8, T::U128, T::U128, T::U128 ) +OPCODE(VectorPolynomialMultiplyLong64, T::U128, T::U128, T::U128 ) OPCODE(VectorPopulationCount, T::U128, T::U128 ) OPCODE(VectorReverseBits, T::U128, T::U128 ) OPCODE(VectorRoundingHalvingAddS8, T::U128, T::U128, T::U128 )