From e739624296186a20ac1b969447679a191647832d Mon Sep 17 00:00:00 2001 From: Lioncash Date: Sun, 9 Sep 2018 17:06:47 -0400 Subject: [PATCH] ir: Add opcodes for vector CLZ operations We can optimize these cases further with the use of a fair bit of shuffling via pshufb and the use of masks, but given the uncommon use of this instruction, I wouldn't consider it to be beneficial in terms of amount of code to be worth it over a simple manageable naive solution like this. If we ever do hit a case where vectorized CLZ happens to be a bottleneck, then we can revisit this. At least with AVX-512CD, this can be done with a single instruction for the 32-bit word case. --- src/backend/x64/emit_x64_vector.cpp | 37 +++++++++++++++++++++++++++++ src/frontend/ir/ir_emitter.cpp | 13 ++++++++++ src/frontend/ir/ir_emitter.h | 1 + src/frontend/ir/opcodes.inc | 3 +++ 4 files changed, 54 insertions(+) diff --git a/src/backend/x64/emit_x64_vector.cpp b/src/backend/x64/emit_x64_vector.cpp index ae303c4c..838dc6cb 100644 --- a/src/backend/x64/emit_x64_vector.cpp +++ b/src/backend/x64/emit_x64_vector.cpp @@ -616,6 +616,43 @@ void EmitX64::EmitVectorBroadcast64(EmitContext& ctx, IR::Inst* inst) { ctx.reg_alloc.DefineValue(inst, a); } +template +static void EmitVectorCountLeadingZeros(VectorArray& result, const VectorArray& data) { + for (size_t i = 0; i < result.size(); i++) { + T element = data[i]; + + size_t count = Common::BitSize(); + while (element != 0) { + element >>= 1; + --count; + } + + result[i] = static_cast(count); + } +} + +void EmitX64::EmitVectorCountLeadingZeros8(EmitContext& ctx, IR::Inst* inst) { + EmitOneArgumentFallback(code, ctx, inst, EmitVectorCountLeadingZeros); +} + +void EmitX64::EmitVectorCountLeadingZeros16(EmitContext& ctx, IR::Inst* inst) { + EmitOneArgumentFallback(code, ctx, inst, EmitVectorCountLeadingZeros); +} + +void EmitX64::EmitVectorCountLeadingZeros32(EmitContext& ctx, IR::Inst* inst) { + if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX512CD) && 
code.DoesCpuSupport(Xbyak::util::Cpu::tAVX512VL)) { + auto args = ctx.reg_alloc.GetArgumentInfo(inst); + + const Xbyak::Xmm data = ctx.reg_alloc.UseScratchXmm(args[0]); + code.vplzcntd(data, data); + + ctx.reg_alloc.DefineValue(inst, data); + return; + } + + EmitOneArgumentFallback(code, ctx, inst, EmitVectorCountLeadingZeros); +} + void EmitX64::EmitVectorDeinterleaveEven8(EmitContext& ctx, IR::Inst* inst) { auto args = ctx.reg_alloc.GetArgumentInfo(inst); const Xbyak::Xmm lhs = ctx.reg_alloc.UseScratchXmm(args[0]); diff --git a/src/frontend/ir/ir_emitter.cpp b/src/frontend/ir/ir_emitter.cpp index 7c79e42d..daf6e675 100644 --- a/src/frontend/ir/ir_emitter.cpp +++ b/src/frontend/ir/ir_emitter.cpp @@ -916,6 +916,19 @@ U128 IREmitter::VectorBroadcast(size_t esize, const UAny& a) { return {}; } +U128 IREmitter::VectorCountLeadingZeros(size_t esize, const U128& a) { + switch (esize) { + case 8: + return Inst(Opcode::VectorCountLeadingZeros8, a); + case 16: + return Inst(Opcode::VectorCountLeadingZeros16, a); + case 32: + return Inst(Opcode::VectorCountLeadingZeros32, a); + } + UNREACHABLE(); + return {}; +} + U128 IREmitter::VectorDeinterleaveEven(size_t esize, const U128& a, const U128& b) { switch (esize) { case 8: diff --git a/src/frontend/ir/ir_emitter.h b/src/frontend/ir/ir_emitter.h index 79d1d219..9dc58d82 100644 --- a/src/frontend/ir/ir_emitter.h +++ b/src/frontend/ir/ir_emitter.h @@ -209,6 +209,7 @@ public: U128 VectorArithmeticShiftRight(size_t esize, const U128& a, u8 shift_amount); U128 VectorBroadcast(size_t esize, const UAny& a); U128 VectorBroadcastLower(size_t esize, const UAny& a); + U128 VectorCountLeadingZeros(size_t esize, const U128& a); U128 VectorEor(const U128& a, const U128& b); U128 VectorDeinterleaveEven(size_t esize, const U128& a, const U128& b); U128 VectorDeinterleaveOdd(size_t esize, const U128& a, const U128& b); diff --git a/src/frontend/ir/opcodes.inc b/src/frontend/ir/opcodes.inc index d8fbc54a..3977e0dc 100644 --- 
a/src/frontend/ir/opcodes.inc +++ b/src/frontend/ir/opcodes.inc @@ -258,6 +258,9 @@ OPCODE(VectorBroadcast8, U128, U8 OPCODE(VectorBroadcast16, U128, U16 ) OPCODE(VectorBroadcast32, U128, U32 ) OPCODE(VectorBroadcast64, U128, U64 ) +OPCODE(VectorCountLeadingZeros8, U128, U128 ) +OPCODE(VectorCountLeadingZeros16, U128, U128 ) +OPCODE(VectorCountLeadingZeros32, U128, U128 ) OPCODE(VectorDeinterleaveEven8, U128, U128, U128 ) OPCODE(VectorDeinterleaveEven16, U128, U128, U128 ) OPCODE(VectorDeinterleaveEven32, U128, U128, U128 )