diff --git a/src/backend/x64/emit_x64_vector.cpp b/src/backend/x64/emit_x64_vector.cpp
index 9aab41ef..c9024d8a 100644
--- a/src/backend/x64/emit_x64_vector.cpp
+++ b/src/backend/x64/emit_x64_vector.cpp
@@ -1371,6 +1371,32 @@ void EmitX64::EmitVectorLogicalVShiftS8(EmitContext& ctx, IR::Inst* inst) {
 }
 
 void EmitX64::EmitVectorLogicalVShiftS16(EmitContext& ctx, IR::Inst* inst) {
+    if (code.DoesCpuSupport(Xbyak::util::Cpu::tAVX512VL) && code.DoesCpuSupport(Xbyak::util::Cpu::tAVX512BW)) {
+        auto args = ctx.reg_alloc.GetArgumentInfo(inst);
+
+        const Xbyak::Xmm result = ctx.reg_alloc.UseScratchXmm(args[0]);
+        const Xbyak::Xmm left_shift = ctx.reg_alloc.UseScratchXmm(args[1]);
+        const Xbyak::Xmm right_shift = ctx.reg_alloc.ScratchXmm();
+        const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
+
+        code.vmovdqa(tmp, code.MConst(xword, 0x00FF00FF00FF00FF, 0x00FF00FF00FF00FF));
+        code.vpxor(right_shift, right_shift, right_shift);
+        code.vpsubw(right_shift, right_shift, left_shift);
+
+        code.vpsllw(xmm0, left_shift, 8);
+        code.vpsraw(xmm0, xmm0, 15);
+
+        code.vpand(right_shift, right_shift, tmp);
+        code.vpand(left_shift, left_shift, tmp);
+
+        code.vpsravw(tmp, result, right_shift);
+        code.vpsllvw(result, result, left_shift);
+        code.pblendvb(result, tmp);
+
+        ctx.reg_alloc.DefineValue(inst, result);
+        return;
+    }
+
     EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s16>& a, const VectorArray<s16>& b) {
         std::transform(a.begin(), a.end(), b.begin(), result.begin(), LogicalVShift<s16>);
     });
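
For review context, here is a minimal scalar sketch of the per-element semantics the new AVX-512 branch appears to implement, i.e. what the LogicalVShift<s16> fallback computes, assuming ARM-style signed variable shifts: the count is the signed low byte of the second operand, positive counts shift left, negative counts shift right arithmetically. VShiftS16Reference is a hypothetical name used only for illustration, not dynarmic's helper.

#include <cstdint>

// Assumed scalar model of the per-element behaviour mirrored by the
// vpsllvw / vpsravw / pblendvb sequence in the patch above.
static int16_t VShiftS16Reference(int16_t value, int16_t shift_elem) {
    // Only the low byte of the shift element matters (the 0x00FF mask in the patch).
    const int8_t shift = static_cast<int8_t>(shift_elem & 0xFF);
    if (shift >= 0) {
        // Positive count: shift left; counts of 16 or more clear the lane, as vpsllvw does.
        return shift >= 16 ? int16_t(0)
                           : static_cast<int16_t>(static_cast<uint16_t>(value) << shift);
    }
    // Negative count: arithmetic shift right; vpsravw saturates counts of 16 or more,
    // so every bit becomes the sign bit. (Assumes '>>' on a signed value is arithmetic,
    // which holds on the compilers targeted here and is guaranteed from C++20.)
    const int amount = (-shift >= 16) ? 15 : -shift;
    return static_cast<int16_t>(value >> amount);
}

In the SIMD version, the sign of that low byte is broadcast across each lane of xmm0 (vpsllw by 8, then vpsraw by 15), and pblendvb, which implicitly reads xmm0 as its mask, selects the vpsravw result for negative counts and the vpsllvw result otherwise.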