diff --git a/src/backend_x64/emit_x64_vector.cpp b/src/backend_x64/emit_x64_vector.cpp
index 26557ea9..8a3e8e94 100644
--- a/src/backend_x64/emit_x64_vector.cpp
+++ b/src/backend_x64/emit_x64_vector.cpp
@@ -2326,6 +2326,88 @@ void EmitX64::EmitVectorRoundingHalvingAddU32(EmitContext& ctx, IR::Inst* inst)
     EmitVectorRoundingHalvingAddUnsigned(32, ctx, inst, code);
 }
 
+template <typename T>
+static void RoundingShiftLeft(VectorArray<T>& out, const VectorArray<T>& lhs, const VectorArray<T>& rhs) {
+    using signed_type = std::make_signed_t<T>;
+    using unsigned_type = std::make_unsigned_t<T>;
+
+    constexpr auto bit_size = static_cast<s64>(Common::BitSize<T>());
+
+    for (size_t i = 0; i < out.size(); i++) {
+        const s64 extended_shift = Common::SignExtend<8>(rhs[i] & 0xFF);
+
+        if (extended_shift >= 0) {
+            if (extended_shift >= bit_size) {
+                out[i] = 0;
+            } else {
+                out[i] = static_cast<T>(static_cast<unsigned_type>(lhs[i]) << extended_shift);
+            }
+        } else {
+            if ((std::is_unsigned_v<T> && extended_shift < -bit_size) ||
+                (std::is_signed_v<T> && extended_shift <= -bit_size)) {
+                out[i] = 0;
+            } else {
+                const s64 shift_value = -extended_shift - 1;
+                const T shifted = (lhs[i] & (static_cast<T>(1) << shift_value)) >> shift_value;
+
+                if (extended_shift == -bit_size) {
+                    out[i] = shifted;
+                } else {
+                    out[i] = (lhs[i] >> -extended_shift) + shifted;
+                }
+            }
+        }
+    }
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftS8(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s8>& result, const VectorArray<s8>& lhs, const VectorArray<s8>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftS16(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s16>& lhs, const VectorArray<s16>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftS32(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s32>& result, const VectorArray<s32>& lhs, const VectorArray<s32>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftS64(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s64>& result, const VectorArray<s64>& lhs, const VectorArray<s64>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftU8(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u8>& result, const VectorArray<u8>& lhs, const VectorArray<u8>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftU16(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u16>& result, const VectorArray<u16>& lhs, const VectorArray<u16>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftU32(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u32>& result, const VectorArray<u32>& lhs, const VectorArray<u32>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
+void EmitX64::EmitVectorRoundingShiftLeftU64(EmitContext& ctx, IR::Inst* inst) {
+    EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u64>& result, const VectorArray<u64>& lhs, const VectorArray<u64>& rhs) {
+        RoundingShiftLeft(result, lhs, rhs);
+    });
+}
+
 static void VectorShuffleImpl(BlockOfCode& code, EmitContext& ctx, IR::Inst* inst, void (Xbyak::CodeGenerator::*fn)(const Xbyak::Mmx&, const Xbyak::Operand&, u8)) {
     auto args = ctx.reg_alloc.GetArgumentInfo(inst);
 
diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp
index 973d13f9..ac2ad63a 100644
--- a/src/frontend/ir/ir_emitter.cpp
+++ b/src/frontend/ir/ir_emitter.cpp
@@ -1424,6 +1424,38 @@ U128 IREmitter::VectorRoundingHalvingAddUnsigned(size_t esize, const U128& a, co
     return {};
 }
 
+U128 IREmitter::VectorRoundingShiftLeftSigned(size_t esize, const U128& a, const U128& b) {
+    switch (esize) {
+    case 8:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftS8, a, b);
+    case 16:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftS16, a, b);
+    case 32:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftS32, a, b);
+    case 64:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftS64, a, b);
+    }
+
+    UNREACHABLE();
+    return {};
+}
+
+U128 IREmitter::VectorRoundingShiftLeftUnsigned(size_t esize, const U128& a, const U128& b) {
+    switch (esize) {
+    case 8:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftU8, a, b);
+    case 16:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftU16, a, b);
+    case 32:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftU32, a, b);
+    case 64:
+        return Inst<U128>(Opcode::VectorRoundingShiftLeftU64, a, b);
+    }
+
+    UNREACHABLE();
+    return {};
+}
+
 U128 IREmitter::VectorShuffleHighHalfwords(const U128& a, u8 mask) {
     return Inst<U128>(Opcode::VectorShuffleHighHalfwords, a, mask);
 }
diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h
index ef87fe79..e33ac743 100644
--- a/src/frontend/ir/ir_emitter.h
+++ b/src/frontend/ir/ir_emitter.h
@@ -256,6 +256,8 @@ public:
     U128 VectorRotateRight(size_t esize, const U128& a, u8 amount);
     U128 VectorRoundingHalvingAddSigned(size_t esize, const U128& a, const U128& b);
     U128 VectorRoundingHalvingAddUnsigned(size_t esize, const U128& a, const U128& b);
+    U128 VectorRoundingShiftLeftSigned(size_t esize, const U128& a, const U128& b);
+    U128 VectorRoundingShiftLeftUnsigned(size_t esize, const U128& a, const U128& b);
     U128 VectorShuffleHighHalfwords(const U128& a, u8 mask);
     U128 VectorShuffleLowHalfwords(const U128& a, u8 mask);
     U128 VectorShuffleWords(const U128& a, u8 mask);
diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc
index d4f012c2..0a8e21fd 100644
--- a/src/frontend/ir/opcodes.inc
+++ b/src/frontend/ir/opcodes.inc
@@ -377,6 +377,14 @@ OPCODE(VectorRoundingHalvingAddS32, T::U128, T::U128,
 OPCODE(VectorRoundingHalvingAddU8,         T::U128,        T::U128,        T::U128         )
 OPCODE(VectorRoundingHalvingAddU16,        T::U128,        T::U128,        T::U128         )
 OPCODE(VectorRoundingHalvingAddU32,        T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftS8,          T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftS16,         T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftS32,         T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftS64,         T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftU8,          T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftU16,         T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftU32,         T::U128,        T::U128,        T::U128         )
+OPCODE(VectorRoundingShiftLeftU64,         T::U128,        T::U128,        T::U128         )
 OPCODE(VectorShuffleHighHalfwords,         T::U128,        T::U128,        T::U8           )
 OPCODE(VectorShuffleLowHalfwords,          T::U128,        T::U128,        T::U8           )
 OPCODE(VectorShuffleWords,                 T::U128,        T::U128,        T::U8           )
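
As a reading aid, and not part of the patch above, here is a minimal standalone sketch of the rounding semantics that the RoundingShiftLeft fallback implements, specialised to the s8 case. The function name, variable names, and the printed examples are illustrative assumptions rather than code from the repository: a non-negative shift amount behaves as a plain left shift, while a negative amount shifts right by its magnitude and adds back the last bit shifted out, rounding the result to nearest with ties toward positive infinity.

#include <cstdint>
#include <cstdio>

// Illustrative only: scalar equivalent of the vector fallback, for 8-bit signed elements.
static int8_t RoundingShiftLeftS8(int8_t value, int8_t shift) {
    constexpr int64_t bit_size = 8;
    const int64_t amount = shift;  // shift amount taken from the low byte, sign-extended

    if (amount >= 0) {
        if (amount >= bit_size) {
            return 0;
        }
        return static_cast<int8_t>(static_cast<uint8_t>(value) << amount);  // unsigned cast avoids signed-shift UB
    }
    if (amount <= -bit_size) {
        return 0;  // signed case: the value and its rounding bit are shifted out entirely
    }
    const int64_t round_pos = -amount - 1;           // position of the last bit shifted out
    const int round_bit = (value >> round_pos) & 1;  // that bit becomes the rounding increment
    return static_cast<int8_t>((value >> -amount) + round_bit);
}

int main() {
    std::printf("%d\n", RoundingShiftLeftS8(7, -1));   // 4: 3.5 rounds up
    std::printf("%d\n", RoundingShiftLeftS8(-7, -1));  // -3: -3.5 rounds toward +infinity
    std::printf("%d\n", RoundingShiftLeftS8(1, 3));    // 8: plain left shift
}

Running this prints 4, -3 and 8, matching the add-the-carried-out-bit rounding performed by the fallback in the patch.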