simd_permute: Implement TRN{1,2} in terms of VectorTranspose

This commit is contained in:
MerryMage 2020-06-21 11:37:02 +01:00
parent 7d1e103ff5
commit a8b481ab63

View file

@ -13,8 +13,7 @@ enum class Transposition {
TRN2, TRN2,
}; };
bool VectorTranspose(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn, Vec Vd, bool VectorTranspose(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn, Vec Vd, Transposition type) {
Transposition type) {
if (!Q && size == 0b11) { if (!Q && size == 0b11) {
return v.ReservedValue(); return v.ReservedValue();
} }
@ -24,44 +23,7 @@ bool VectorTranspose(TranslatorVisitor& v, bool Q, Imm<2> size, Vec Vm, Vec Vn,
const IR::U128 m = v.V(datasize, Vm); const IR::U128 m = v.V(datasize, Vm);
const IR::U128 n = v.V(datasize, Vn); const IR::U128 n = v.V(datasize, Vn);
const IR::U128 result = v.ir.VectorTranspose(esize, n, m, type == Transposition::TRN2);
const IR::U128 result = [&] {
switch (esize) {
case 8:
case 16:
case 32: {
// Create a mask of elements we care about (e.g. for 8-bit: 0x00FF00FF00FF00FF for TRN1
// and 0xFF00FF00FF00FF00 for TRN2)
const u64 mask_element = [&] {
const size_t shift = type == Transposition::TRN1 ? 0 : esize;
return Common::Ones<u64>(esize) << shift;
}();
const size_t doubled_esize = esize * 2;
const u64 mask_value = Common::Replicate<u64>(mask_element, doubled_esize);
const IR::U128 mask = v.ir.VectorBroadcast(64, v.I(64, mask_value));
const IR::U128 anded_m = v.ir.VectorAnd(m, mask);
const IR::U128 anded_n = v.ir.VectorAnd(n, mask);
if (type == Transposition::TRN1) {
return v.ir.VectorOr(v.ir.VectorLogicalShiftLeft(doubled_esize, anded_m, esize), anded_n);
}
return v.ir.VectorOr(v.ir.VectorLogicalShiftRight(doubled_esize, anded_n, esize), anded_m);
}
case 64: {
default:
const auto [src, src_idx, dst, dst_idx] = [type, m, n] {
if (type == Transposition::TRN1) {
return std::make_tuple(m, 0, n, 1);
}
return std::make_tuple(n, 1, m, 0);
}();
return v.ir.VectorSetElement(esize, dst, dst_idx, v.ir.VectorGetElement(esize, src, src_idx));
}
}
}();
v.V(datasize, Vd, result); v.V(datasize, Vd, result);
return true; return true;