IR: Introduce VectorPaired{Min,Max}Lower

This commit is contained in:
Merry 2022-08-06 18:00:14 +01:00 committed by merry
parent 3df0eb30be
commit 8fb37e0e4f
6 changed files with 263 additions and 8 deletions

View file

@ -1150,6 +1150,102 @@ void EmitIR<IR::Opcode::VectorPairedMinU32>(oaknut::CodeGenerator& code, EmitCon
EmitThreeOpArranged<32>(code, ctx, inst, [&](auto Vresult, auto Va, auto Vb) { code.UMINP(Vresult, Va, Vb); }); EmitThreeOpArranged<32>(code, ctx, inst, [&](auto Vresult, auto Va, auto Vb) { code.UMINP(Vresult, Va, Vb); });
} }
template<>
void EmitIR<IR::Opcode::VectorPairedMaxLowerS8>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMaxLowerS16>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMaxLowerS32>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMaxLowerU8>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMaxLowerU16>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMaxLowerU32>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMinLowerS8>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMinLowerS16>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMinLowerS32>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMinLowerU8>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMinLowerU16>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<>
void EmitIR<IR::Opcode::VectorPairedMinLowerU32>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
(void)code;
(void)ctx;
(void)inst;
ASSERT_FALSE("Unimplemented");
}
template<> template<>
void EmitIR<IR::Opcode::VectorPolynomialMultiply8>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) { void EmitIR<IR::Opcode::VectorPolynomialMultiply8>(oaknut::CodeGenerator& code, EmitContext& ctx, IR::Inst* inst) {
EmitThreeOpArranged<8>(code, ctx, inst, [&](auto Vresult, auto Va, auto Vb) { code.PMUL(Vresult, Va, Vb); }); EmitThreeOpArranged<8>(code, ctx, inst, [&](auto Vresult, auto Va, auto Vb) { code.PMUL(Vresult, Va, Vb); });

View file

@ -2646,6 +2646,19 @@ static void PairedOperation(VectorArray<T>& result, const VectorArray<T>& x, con
} }
} }
template<typename T, typename Function>
static void LowerPairedOperation(VectorArray<T>& result, const VectorArray<T>& x, const VectorArray<T>& y, Function fn) {
const size_t range = x.size() / 4;
for (size_t i = 0; i < range; i++) {
result[i] = fn(x[2 * i], x[2 * i + 1]);
}
for (size_t i = 0; i < range; i++) {
result[range + i] = fn(y[2 * i], y[2 * i + 1]);
}
}
template<typename T> template<typename T>
static void PairedMax(VectorArray<T>& result, const VectorArray<T>& x, const VectorArray<T>& y) { static void PairedMax(VectorArray<T>& result, const VectorArray<T>& x, const VectorArray<T>& y) {
PairedOperation(result, x, y, [](auto a, auto b) { return std::max(a, b); }); PairedOperation(result, x, y, [](auto a, auto b) { return std::max(a, b); });
@ -2656,6 +2669,16 @@ static void PairedMin(VectorArray<T>& result, const VectorArray<T>& x, const Vec
PairedOperation(result, x, y, [](auto a, auto b) { return std::min(a, b); }); PairedOperation(result, x, y, [](auto a, auto b) { return std::min(a, b); });
} }
template<typename T>
static void LowerPairedMax(VectorArray<T>& result, const VectorArray<T>& x, const VectorArray<T>& y) {
LowerPairedOperation(result, x, y, [](auto a, auto b) { return std::max(a, b); });
}
template<typename T>
static void LowerPairedMin(VectorArray<T>& result, const VectorArray<T>& x, const VectorArray<T>& y) {
LowerPairedOperation(result, x, y, [](auto a, auto b) { return std::min(a, b); });
}
void EmitX64::EmitVectorPairedMaxS8(EmitContext& ctx, IR::Inst* inst) { void EmitX64::EmitVectorPairedMaxS8(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s8>& result, const VectorArray<s8>& a, const VectorArray<s8>& b) { EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s8>& result, const VectorArray<s8>& a, const VectorArray<s8>& b) {
PairedMax(result, a, b); PairedMax(result, a, b);
@ -2826,6 +2849,78 @@ void EmitX64::EmitVectorPairedMinU32(EmitContext& ctx, IR::Inst* inst) {
} }
} }
void EmitX64::EmitVectorPairedMaxLowerS8(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s8>& result, const VectorArray<s8>& a, const VectorArray<s8>& b) {
LowerPairedMax(result, a, b);
});
}
void EmitX64::EmitVectorPairedMaxLowerS16(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s16>& a, const VectorArray<s16>& b) {
LowerPairedMax(result, a, b);
});
}
void EmitX64::EmitVectorPairedMaxLowerS32(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s32>& a, const VectorArray<s32>& b) {
LowerPairedMax(result, a, b);
});
}
void EmitX64::EmitVectorPairedMaxLowerU8(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u8>& result, const VectorArray<u8>& a, const VectorArray<u8>& b) {
LowerPairedMax(result, a, b);
});
}
void EmitX64::EmitVectorPairedMaxLowerU16(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u16>& result, const VectorArray<u16>& a, const VectorArray<u16>& b) {
LowerPairedMax(result, a, b);
});
}
void EmitX64::EmitVectorPairedMaxLowerU32(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u32>& result, const VectorArray<u32>& a, const VectorArray<u32>& b) {
LowerPairedMax(result, a, b);
});
}
void EmitX64::EmitVectorPairedMinLowerS8(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s8>& result, const VectorArray<s8>& a, const VectorArray<s8>& b) {
LowerPairedMin(result, a, b);
});
}
void EmitX64::EmitVectorPairedMinLowerS16(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s16>& result, const VectorArray<s16>& a, const VectorArray<s16>& b) {
LowerPairedMin(result, a, b);
});
}
void EmitX64::EmitVectorPairedMinLowerS32(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<s32>& result, const VectorArray<s32>& a, const VectorArray<s32>& b) {
LowerPairedMin(result, a, b);
});
}
void EmitX64::EmitVectorPairedMinLowerU8(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u8>& result, const VectorArray<u8>& a, const VectorArray<u8>& b) {
LowerPairedMin(result, a, b);
});
}
void EmitX64::EmitVectorPairedMinLowerU16(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u16>& result, const VectorArray<u16>& a, const VectorArray<u16>& b) {
LowerPairedMin(result, a, b);
});
}
void EmitX64::EmitVectorPairedMinLowerU32(EmitContext& ctx, IR::Inst* inst) {
EmitTwoArgumentFallback(code, ctx, inst, [](VectorArray<u32>& result, const VectorArray<u32>& a, const VectorArray<u32>& b) {
LowerPairedMin(result, a, b);
});
}
template<typename D, typename T> template<typename D, typename T>
static D PolynomialMultiply(T lhs, T rhs) { static D PolynomialMultiply(T lhs, T rhs) {
constexpr size_t bit_size = mcl::bitsizeof<T>; constexpr size_t bit_size = mcl::bitsizeof<T>;

View file

@ -264,25 +264,21 @@ bool PairedMinMaxOperation(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Ve
switch (operation) { switch (operation) {
case MinMaxOperation::Max: case MinMaxOperation::Max:
if (sign == Signedness::Signed) { if (sign == Signedness::Signed) {
return v.ir.VectorPairedMaxSigned(esize, operand1, operand2); return Q ? v.ir.VectorPairedMaxSigned(esize, operand1, operand2) : v.ir.VectorPairedMaxSignedLower(esize, operand1, operand2);
} }
return v.ir.VectorPairedMaxUnsigned(esize, operand1, operand2); return Q ? v.ir.VectorPairedMaxUnsigned(esize, operand1, operand2) : v.ir.VectorPairedMaxUnsignedLower(esize, operand1, operand2);
case MinMaxOperation::Min: case MinMaxOperation::Min:
if (sign == Signedness::Signed) { if (sign == Signedness::Signed) {
return v.ir.VectorPairedMinSigned(esize, operand1, operand2); return Q ? v.ir.VectorPairedMinSigned(esize, operand1, operand2) : v.ir.VectorPairedMinSignedLower(esize, operand1, operand2);
} }
return v.ir.VectorPairedMinUnsigned(esize, operand1, operand2); return Q ? v.ir.VectorPairedMinUnsigned(esize, operand1, operand2) : v.ir.VectorPairedMinUnsignedLower(esize, operand1, operand2);
default: default:
UNREACHABLE(); UNREACHABLE();
} }
}(); }();
if (datasize == 64) {
result = v.ir.VectorShuffleWords(result, 0b11011000);
}
v.V(datasize, Vd, result); v.V(datasize, Vd, result);
return true; return true;
} }

View file

@ -1550,6 +1550,58 @@ U128 IREmitter::VectorPairedMinUnsigned(size_t esize, const U128& a, const U128&
} }
} }
U128 IREmitter::VectorPairedMaxSignedLower(size_t esize, const U128& a, const U128& b) {
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorPairedMaxLowerS8, a, b);
case 16:
return Inst<U128>(Opcode::VectorPairedMaxLowerS16, a, b);
case 32:
return Inst<U128>(Opcode::VectorPairedMaxLowerS32, a, b);
default:
UNREACHABLE();
}
}
U128 IREmitter::VectorPairedMaxUnsignedLower(size_t esize, const U128& a, const U128& b) {
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorPairedMaxLowerU8, a, b);
case 16:
return Inst<U128>(Opcode::VectorPairedMaxLowerU16, a, b);
case 32:
return Inst<U128>(Opcode::VectorPairedMaxLowerU32, a, b);
default:
UNREACHABLE();
}
}
U128 IREmitter::VectorPairedMinSignedLower(size_t esize, const U128& a, const U128& b) {
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorPairedMinLowerS8, a, b);
case 16:
return Inst<U128>(Opcode::VectorPairedMinLowerS16, a, b);
case 32:
return Inst<U128>(Opcode::VectorPairedMinLowerS32, a, b);
default:
UNREACHABLE();
}
}
U128 IREmitter::VectorPairedMinUnsignedLower(size_t esize, const U128& a, const U128& b) {
switch (esize) {
case 8:
return Inst<U128>(Opcode::VectorPairedMinLowerU8, a, b);
case 16:
return Inst<U128>(Opcode::VectorPairedMinLowerU16, a, b);
case 32:
return Inst<U128>(Opcode::VectorPairedMinLowerU32, a, b);
default:
UNREACHABLE();
}
}
U128 IREmitter::VectorPolynomialMultiply(const U128& a, const U128& b) { U128 IREmitter::VectorPolynomialMultiply(const U128& a, const U128& b) {
return Inst<U128>(Opcode::VectorPolynomialMultiply8, a, b); return Inst<U128>(Opcode::VectorPolynomialMultiply8, a, b);
} }

View file

@ -277,6 +277,10 @@ public:
U128 VectorPairedMaxUnsigned(size_t esize, const U128& a, const U128& b); U128 VectorPairedMaxUnsigned(size_t esize, const U128& a, const U128& b);
U128 VectorPairedMinSigned(size_t esize, const U128& a, const U128& b); U128 VectorPairedMinSigned(size_t esize, const U128& a, const U128& b);
U128 VectorPairedMinUnsigned(size_t esize, const U128& a, const U128& b); U128 VectorPairedMinUnsigned(size_t esize, const U128& a, const U128& b);
U128 VectorPairedMaxSignedLower(size_t esize, const U128& a, const U128& b);
U128 VectorPairedMaxUnsignedLower(size_t esize, const U128& a, const U128& b);
U128 VectorPairedMinSignedLower(size_t esize, const U128& a, const U128& b);
U128 VectorPairedMinUnsignedLower(size_t esize, const U128& a, const U128& b);
U128 VectorPolynomialMultiply(const U128& a, const U128& b); U128 VectorPolynomialMultiply(const U128& a, const U128& b);
U128 VectorPolynomialMultiplyLong(size_t esize, const U128& a, const U128& b); U128 VectorPolynomialMultiplyLong(size_t esize, const U128& a, const U128& b);
U128 VectorPopulationCount(const U128& a); U128 VectorPopulationCount(const U128& a);

View file

@ -438,6 +438,18 @@ OPCODE(VectorPairedMinS32, U128, U128
OPCODE(VectorPairedMinU8, U128, U128, U128 ) OPCODE(VectorPairedMinU8, U128, U128, U128 )
OPCODE(VectorPairedMinU16, U128, U128, U128 ) OPCODE(VectorPairedMinU16, U128, U128, U128 )
OPCODE(VectorPairedMinU32, U128, U128, U128 ) OPCODE(VectorPairedMinU32, U128, U128, U128 )
OPCODE(VectorPairedMaxLowerS8, U128, U128, U128 )
OPCODE(VectorPairedMaxLowerS16, U128, U128, U128 )
OPCODE(VectorPairedMaxLowerS32, U128, U128, U128 )
OPCODE(VectorPairedMaxLowerU8, U128, U128, U128 )
OPCODE(VectorPairedMaxLowerU16, U128, U128, U128 )
OPCODE(VectorPairedMaxLowerU32, U128, U128, U128 )
OPCODE(VectorPairedMinLowerS8, U128, U128, U128 )
OPCODE(VectorPairedMinLowerS16, U128, U128, U128 )
OPCODE(VectorPairedMinLowerS32, U128, U128, U128 )
OPCODE(VectorPairedMinLowerU8, U128, U128, U128 )
OPCODE(VectorPairedMinLowerU16, U128, U128, U128 )
OPCODE(VectorPairedMinLowerU32, U128, U128, U128 )
OPCODE(VectorPolynomialMultiply8, U128, U128, U128 ) OPCODE(VectorPolynomialMultiply8, U128, U128, U128 )
OPCODE(VectorPolynomialMultiplyLong8, U128, U128, U128 ) OPCODE(VectorPolynomialMultiplyLong8, U128, U128, U128 )
OPCODE(VectorPolynomialMultiplyLong64, U128, U128, U128 ) OPCODE(VectorPolynomialMultiplyLong64, U128, U128, U128 )