ir: Add VectorBroadcastElement{Lower} IR instruction

The lane-splatting variant of `FMUL` and `FMLA` is very
common in instruction streams when implementing things like
matrix multiplication. When used, they are used very densely.

https://community.arm.com/developer/ip-products/processors/b/processors-ip-blog/posts/coding-for-neon---part-3-matrix-multiplication

The way this is currently implemented is by grabbing the particular lane
into a general purpose register and then broadcasting it into a simd
register through `VectorGetElement` and `VectorBroadcast`.

```cpp
    const IR::U128 operand2 = v.ir.VectorBroadcast(esize, v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index));
```

What could be done instead is to keep it within
the vector-register and use a permute/shuffle to "splat" the particular
lane across all other lanes, removing the GPR-round-trip.

This is implemented as the new IR instruction `VectorBroadcastElement`:

```cpp
    const IR::U128 operand2 = v.ir.VectorBroadcastElement(esize, v.V(idxdsize, Vm), index);
```
This commit is contained in:
Wunkolo 2021-07-22 11:05:19 -07:00 committed by merry
parent 46b8cfabc0
commit 1e94acff66
9 changed files with 196 additions and 24 deletions

View file

@ -165,7 +165,7 @@ void EmitX64::EmitVectorGetElement8(EmitContext& ctx, IR::Inst* inst) {
if (code.HasHostFeature(HostFeature::SSE41)) {
code.pextrb(dest, source, index);
} else {
code.pextrw(dest, source, index / 2);
code.pextrw(dest, source, u8(index / 2));
if (index % 2 == 1) {
code.shr(dest, 8);
} else {
@ -752,6 +752,148 @@ void EmitX64::EmitVectorBroadcast64(EmitContext& ctx, IR::Inst* inst) {
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElementLower8(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 16);
if (index > 0) {
code.psrldq(a, index);
}
if (code.HasHostFeature(HostFeature::AVX2)) {
code.vpbroadcastb(a, a);
code.vmovq(a, a);
} else if (code.HasHostFeature(HostFeature::SSSE3)) {
const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
code.pxor(tmp, tmp);
code.pshufb(a, tmp);
code.movq(a, a);
} else {
code.punpcklbw(a, a);
code.pshuflw(a, a, 0);
}
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElementLower16(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 8);
if (index > 0) {
code.psrldq(a, u8(index * 2));
}
code.pshuflw(a, a, 0);
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElementLower32(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 4);
if (index > 0) {
code.psrldq(a, u8(index * 4));
}
code.pshuflw(a, a, 0b01'00'01'00);
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElement8(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 16);
if (index > 0) {
code.psrldq(a, index);
}
if (code.HasHostFeature(HostFeature::AVX2)) {
code.vpbroadcastb(a, a);
} else if (code.HasHostFeature(HostFeature::SSSE3)) {
const Xbyak::Xmm tmp = ctx.reg_alloc.ScratchXmm();
code.pxor(tmp, tmp);
code.pshufb(a, tmp);
} else {
code.punpcklbw(a, a);
code.pshuflw(a, a, 0);
code.punpcklqdq(a, a);
}
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElement16(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 8);
if (index == 0 && code.HasHostFeature(HostFeature::AVX2)) {
code.vpbroadcastw(a, a);
ctx.reg_alloc.DefineValue(inst, a);
return;
}
if (index < 4) {
code.pshuflw(a, a, Common::Replicate<u8>(index, 2));
code.punpcklqdq(a, a);
} else {
code.pshufhw(a, a, Common::Replicate<u8>(u8(index - 4), 2));
code.punpckhqdq(a, a);
}
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElement32(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 4);
code.pshufd(a, a, Common::Replicate<u8>(index, 2));
ctx.reg_alloc.DefineValue(inst, a);
}
void EmitX64::EmitVectorBroadcastElement64(EmitContext& ctx, IR::Inst* inst) {
auto args = ctx.reg_alloc.GetArgumentInfo(inst);
const Xbyak::Xmm a = ctx.reg_alloc.UseScratchXmm(args[0]);
ASSERT(args[1].IsImmediate());
const u8 index = args[1].GetImmediateU8();
ASSERT(index < 2);
if (code.HasHostFeature(HostFeature::AVX)) {
code.vpermilpd(a, a, Common::Replicate<u8>(index, 1));
} else {
if (index == 0) {
code.punpcklqdq(a, a);
} else {
code.punpckhqdq(a, a);
}
}
ctx.reg_alloc.DefineValue(inst, a);
}
template<typename T>
static void EmitVectorCountLeadingZeros(VectorArray<T>& result, const VectorArray<T>& data) {
for (size_t i = 0; i < result.size(); i++) {

View file

@ -80,8 +80,7 @@ bool TranslatorVisitor::asimd_VDUP_scalar(bool D, Imm<4> imm4, size_t Vd, bool Q
const auto m = ToVector(false, Vm, M);
const auto reg_m = ir.GetVector(m);
const auto scalar = ir.VectorGetElement(esize, reg_m, index);
const auto result = ir.VectorBroadcast(esize, scalar);
const auto result = ir.VectorBroadcastElement(esize, reg_m, index);
ir.SetVector(d, result);
return true;

View file

@ -46,9 +46,8 @@ bool ScalarMultiply(TranslatorVisitor& v, bool Q, bool D, size_t sz, size_t Vn,
const auto n = ToVector(Q, Vn, N);
const auto [m, index] = GetScalarLocation(esize, M, Vm);
const auto scalar = v.ir.VectorGetElement(esize, v.ir.GetVector(m), index);
const auto reg_n = v.ir.GetVector(n);
const auto reg_m = v.ir.VectorBroadcast(esize, scalar);
const auto reg_m = v.ir.VectorBroadcastElement(esize, v.ir.GetVector(m), index);
const auto addend = F ? v.ir.FPVectorMul(esize, reg_n, reg_m, false)
: v.ir.VectorMultiply(esize, reg_n, reg_m);
const auto result = [&] {
@ -125,9 +124,8 @@ bool ScalarMultiplyReturnHigh(TranslatorVisitor& v, bool Q, bool D, size_t sz, s
const auto n = ToVector(Q, Vn, N);
const auto [m, index] = GetScalarLocation(esize, M, Vm);
const auto scalar = v.ir.VectorGetElement(esize, v.ir.GetVector(m), index);
const auto reg_n = v.ir.GetVector(n);
const auto reg_m = v.ir.VectorBroadcast(esize, scalar);
const auto reg_m = v.ir.VectorBroadcastElement(esize, v.ir.GetVector(m), index);
const auto result = [&] {
const auto tmp = v.ir.VectorSignedSaturatedDoublingMultiply(esize, reg_n, reg_m);
@ -177,9 +175,8 @@ bool TranslatorVisitor::asimd_VQDMULL_scalar(bool D, size_t sz, size_t Vn, size_
const auto n = ToVector(false, Vn, N);
const auto [m, index] = GetScalarLocation(esize, M, Vm);
const auto scalar = ir.VectorGetElement(esize, ir.GetVector(m), index);
const auto reg_n = ir.GetVector(n);
const auto reg_m = ir.VectorBroadcast(esize, scalar);
const auto reg_m = ir.VectorBroadcastElement(esize, ir.GetVector(m), index);
const auto result = ir.VectorSignedSaturatedDoublingMultiplyLong(esize, reg_n, reg_m);
ir.SetVector(d, result);

View file

@ -41,8 +41,7 @@ bool TranslatorVisitor::DUP_elt_2(bool Q, Imm<5> imm5, Vec Vn, Vec Vd) {
const size_t datasize = Q ? 128 : 64;
const IR::U128 operand = V(idxdsize, Vn);
const IR::UAny element = ir.VectorGetElement(esize, operand, index);
const IR::U128 result = Q ? ir.VectorBroadcast(esize, element) : ir.VectorBroadcastLower(esize, element);
const IR::U128 result = Q ? ir.VectorBroadcastElement(esize, operand, index) : ir.VectorBroadcastElementLower(esize, operand, index);
V(datasize, Vd, result);
return true;
}

View file

@ -143,8 +143,8 @@ bool TranslatorVisitor::SQRDMULH_elt_1(Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> V
const auto [index, Vm] = Combine(size, H, L, M, Vmlo);
const IR::U128 operand1 = ir.ZeroExtendToQuad(ir.VectorGetElement(esize, V(128, Vn), 0));
const IR::UAny operand2 = ir.VectorGetElement(esize, V(128, Vm), index);
const IR::U128 broadcast = ir.VectorBroadcast(esize, operand2);
const IR::U128 operand2 = V(128, Vm);
const IR::U128 broadcast = ir.VectorBroadcastElement(esize, operand2, index);
const IR::UpperAndLower multiply = ir.VectorSignedSaturatedDoublingMultiply(esize, operand1, broadcast);
const IR::U128 result = ir.VectorAdd(esize, multiply.upper, ir.VectorLogicalShiftRight(esize, multiply.lower, static_cast<u8>(esize - 1)));
@ -161,8 +161,8 @@ bool TranslatorVisitor::SQDMULL_elt_1(Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vm
const auto [index, Vm] = Combine(size, H, L, M, Vmlo);
const IR::U128 operand1 = ir.ZeroExtendToQuad(ir.VectorGetElement(esize, V(128, Vn), 0));
const IR::UAny operand2 = ir.VectorGetElement(esize, V(128, Vm), index);
const IR::U128 broadcast = ir.VectorBroadcast(esize, operand2);
const IR::U128 operand2 = V(128, Vm);
const IR::U128 broadcast = ir.VectorBroadcastElement(esize, operand2, index);
const IR::U128 result = ir.VectorSignedSaturatedDoublingMultiplyLong(esize, operand1, broadcast);
V(128, Vd, result);

View file

@ -36,7 +36,7 @@ bool MultiplyByElement(TranslatorVisitor& v, bool Q, Imm<2> size, Imm<1> L, Imm<
const size_t datasize = Q ? 128 : 64;
const IR::U128 operand1 = v.V(datasize, Vn);
const IR::U128 operand2 = v.ir.VectorBroadcast(esize, v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index));
const IR::U128 operand2 = v.ir.VectorBroadcastElement(esize, v.V(idxdsize, Vm), index);
const IR::U128 operand3 = v.V(datasize, Vd);
IR::U128 result = v.ir.VectorMultiply(esize, operand1, operand2);
@ -64,9 +64,8 @@ bool FPMultiplyByElement(TranslatorVisitor& v, bool Q, bool sz, Imm<1> L, Imm<1>
const size_t esize = sz ? 64 : 32;
const size_t datasize = Q ? 128 : 64;
const IR::UAny element2 = v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index);
const IR::U128 operand1 = v.V(datasize, Vn);
const IR::U128 operand2 = Q ? v.ir.VectorBroadcast(esize, element2) : v.ir.VectorBroadcastLower(esize, element2);
const IR::U128 operand2 = Q ? v.ir.VectorBroadcastElement(esize, v.V(idxdsize, Vm), index) : v.ir.VectorBroadcastElementLower(esize, v.V(idxdsize, Vm), index);
const IR::U128 operand3 = v.V(datasize, Vd);
const IR::U128 result = [&] {
@ -93,9 +92,8 @@ bool FPMultiplyByElementHalfPrecision(TranslatorVisitor& v, bool Q, Imm<1> L, Im
const size_t esize = 16;
const size_t datasize = Q ? 128 : 64;
const IR::UAny element2 = v.ir.VectorGetElement(esize, v.V(idxdsize, Vm), index);
const IR::U128 operand1 = v.V(datasize, Vn);
const IR::U128 operand2 = Q ? v.ir.VectorBroadcast(esize, element2) : v.ir.VectorBroadcastLower(esize, element2);
const IR::U128 operand2 = Q ? v.ir.VectorBroadcastElement(esize, v.V(idxdsize, Vm), index) : v.ir.VectorBroadcastElementLower(esize, v.V(idxdsize, Vm), index);
const IR::U128 operand3 = v.V(datasize, Vd);
// TODO: We currently don't implement half-precision paths for
@ -179,7 +177,7 @@ bool MultiplyLong(TranslatorVisitor& v, bool Q, Imm<2> size, Imm<1> L, Imm<1> M,
const IR::U128 operand1 = v.Vpart(datasize, Vn, Q);
const IR::U128 operand2 = v.V(idxsize, Vm);
const IR::U128 index_vector = v.ir.VectorBroadcast(esize, v.ir.VectorGetElement(esize, operand2, index));
const IR::U128 index_vector = v.ir.VectorBroadcastElement(esize, operand2, index);
const IR::U128 result = [&] {
const auto [extended_op1, extended_index] = extend_operands(operand1, index_vector);
@ -349,7 +347,7 @@ bool TranslatorVisitor::SQDMULL_elt_2(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, I
const IR::U128 operand1 = Vpart(datasize, Vn, part);
const IR::U128 operand2 = V(idxsize, Vm);
const IR::U128 index_vector = ir.VectorBroadcast(esize, ir.VectorGetElement(esize, operand2, index));
const IR::U128 index_vector = ir.VectorBroadcastElement(esize, operand2, index);
const IR::U128 result = ir.VectorSignedSaturatedDoublingMultiplyLong(esize, operand1, index_vector);
V(128, Vd, result);
@ -368,7 +366,7 @@ bool TranslatorVisitor::SQDMULH_elt_2(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, I
const IR::U128 operand1 = V(datasize, Vn);
const IR::U128 operand2 = V(idxsize, Vm);
const IR::U128 index_vector = ir.VectorBroadcast(esize, ir.VectorGetElement(esize, operand2, index));
const IR::U128 index_vector = ir.VectorBroadcastElement(esize, operand2, index);
const IR::U128 result = ir.VectorSignedSaturatedDoublingMultiply(esize, operand1, index_vector).upper;
V(datasize, Vd, result);
@ -387,7 +385,7 @@ bool TranslatorVisitor::SQRDMULH_elt_2(bool Q, Imm<2> size, Imm<1> L, Imm<1> M,
const IR::U128 operand1 = V(datasize, Vn);
const IR::U128 operand2 = V(idxsize, Vm);
const IR::U128 index_vector = ir.VectorBroadcast(esize, ir.VectorGetElement(esize, operand2, index));
const IR::U128 index_vector = ir.VectorBroadcastElement(esize, operand2, index);
const IR::UpperAndLower multiply = ir.VectorSignedSaturatedDoublingMultiply(esize, operand1, index_vector);
const IR::U128 result = ir.VectorAdd(esize, multiply.upper, ir.VectorLogicalShiftRight(esize, multiply.lower, static_cast<u8>(esize - 1)));

View file

@ -1025,6 +1025,34 @@ U128 IREmitter::VectorBroadcast(size_t esize, const UAny& a) {
UNREACHABLE();
}
U128 IREmitter::VectorBroadcastElementLower(size_t esize, const U128& a, size_t index) {
ASSERT_MSG(esize * index < 128, "Invalid index");
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorBroadcastElementLower8, a, u8(index));
case 16:
return Inst<U128>(Opcode::VectorBroadcastElementLower16, a, u8(index));
case 32:
return Inst<U128>(Opcode::VectorBroadcastElementLower32, a, u8(index));
}
UNREACHABLE();
}
U128 IREmitter::VectorBroadcastElement(size_t esize, const U128& a, size_t index) {
ASSERT_MSG(esize * index < 128, "Invalid index");
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorBroadcastElement8, a, u8(index));
case 16:
return Inst<U128>(Opcode::VectorBroadcastElement16, a, u8(index));
case 32:
return Inst<U128>(Opcode::VectorBroadcastElement32, a, u8(index));
case 64:
return Inst<U128>(Opcode::VectorBroadcastElement64, a, u8(index));
}
UNREACHABLE();
}
U128 IREmitter::VectorCountLeadingZeros(size_t esize, const U128& a) {
switch (esize) {
case 8:

View file

@ -246,6 +246,8 @@ public:
U128 VectorArithmeticVShift(size_t esize, const U128& a, const U128& b);
U128 VectorBroadcast(size_t esize, const UAny& a);
U128 VectorBroadcastLower(size_t esize, const UAny& a);
U128 VectorBroadcastElement(size_t esize, const U128& a, size_t index);
U128 VectorBroadcastElementLower(size_t esize, const U128& a, size_t index);
U128 VectorCountLeadingZeros(size_t esize, const U128& a);
U128 VectorEor(const U128& a, const U128& b);
U128 VectorDeinterleaveEven(size_t esize, const U128& a, const U128& b);

View file

@ -309,6 +309,13 @@ OPCODE(VectorBroadcast8, U128, U8
OPCODE(VectorBroadcast16, U128, U16 )
OPCODE(VectorBroadcast32, U128, U32 )
OPCODE(VectorBroadcast64, U128, U64 )
OPCODE(VectorBroadcastElementLower8, U128, U128, U8 )
OPCODE(VectorBroadcastElementLower16, U128, U128, U8 )
OPCODE(VectorBroadcastElementLower32, U128, U128, U8 )
OPCODE(VectorBroadcastElement8, U128, U128, U8 )
OPCODE(VectorBroadcastElement16, U128, U128, U8 )
OPCODE(VectorBroadcastElement32, U128, U128, U8 )
OPCODE(VectorBroadcastElement64, U128, U128, U8 )
OPCODE(VectorCountLeadingZeros8, U128, U128 )
OPCODE(VectorCountLeadingZeros16, U128, U128 )
OPCODE(VectorCountLeadingZeros32, U128, U128 )