diff --git a/src/frontend/A64/decoder/a64.inc b/src/frontend/A64/decoder/a64.inc index 71604561..0169241a 100644 --- a/src/frontend/A64/decoder/a64.inc +++ b/src/frontend/A64/decoder/a64.inc @@ -829,7 +829,7 @@ INST(MUL_elt, "MUL (by element)", "0Q001 //INST(SQDMULL_elt_2, "SQDMULL, SQDMULL2 (by element)", "0Q001111zzLMmmmm1011H0nnnnnddddd") //INST(SQDMULH_elt_2, "SQDMULH (by element)", "0Q001111zzLMmmmm1100H0nnnnnddddd") //INST(SQRDMULH_elt_2, "SQRDMULH (by element)", "0Q001111zzLMmmmm1101H0nnnnnddddd") -//INST(SDOT_elt, "SDOT (by element)", "0Q001111zzLMmmmm1110H0nnnnnddddd") +INST(SDOT_elt, "SDOT (by element)", "0Q001111zzLMmmmm1110H0nnnnnddddd") //INST(FMLA_elt_3, "FMLA (by element)", "0Q00111100LMmmmm0001H0nnnnnddddd") INST(FMLA_elt_4, "FMLA (by element)", "0Q0011111zLMmmmm0001H0nnnnnddddd") //INST(FMLS_elt_3, "FMLS (by element)", "0Q00111100LMmmmm0101H0nnnnnddddd") @@ -846,7 +846,7 @@ INST(MLS_elt, "MLS (by element)", "0Q101 //INST(UMLSL_elt, "UMLSL, UMLSL2 (by element)", "0Q101111zzLMmmmm0110H0nnnnnddddd") //INST(UMULL_elt, "UMULL, UMULL2 (by element)", "0Q101111zzLMmmmm1010H0nnnnnddddd") //INST(SQRDMLAH_elt_2, "SQRDMLAH (by element)", "0Q101111zzLMmmmm1101H0nnnnnddddd") -//INST(UDOT_elt, "UDOT (by element)", "0Q101111zzLMmmmm1110H0nnnnnddddd") +INST(UDOT_elt, "UDOT (by element)", "0Q101111zzLMmmmm1110H0nnnnnddddd") //INST(SQRDMLSH_elt_2, "SQRDMLSH (by element)", "0Q101111zzLMmmmm1111H0nnnnnddddd") //INST(FMULX_elt_3, "FMULX (by element)", "0Q10111100LMmmmm1001H0nnnnnddddd") //INST(FMULX_elt_4, "FMULX (by element)", "0Q1011111zLMmmmm1001H0nnnnnddddd") diff --git a/src/frontend/A64/translate/impl/impl.h b/src/frontend/A64/translate/impl/impl.h index 0342ede6..59c1861b 100644 --- a/src/frontend/A64/translate/impl/impl.h +++ b/src/frontend/A64/translate/impl/impl.h @@ -900,7 +900,7 @@ struct TranslatorVisitor final { bool SQDMULL_elt_2(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Reg Rn, Vec Vd); bool SQDMULH_elt_2(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); bool SQRDMULH_elt_2(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); - bool SDOT_elt(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); + bool SDOT_elt(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd); bool FMLA_elt_3(bool Q, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); bool FMLA_elt_4(bool Q, bool sz, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd); bool FMLS_elt_3(bool Q, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); @@ -917,7 +917,7 @@ struct TranslatorVisitor final { bool UMLSL_elt(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); bool UMULL_elt(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); bool SQRDMLAH_elt_2(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); - bool UDOT_elt(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); + bool UDOT_elt(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd); bool SQRDMLSH_elt_2(bool Q, Imm<2> size, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); bool FMULX_elt_3(bool Q, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); bool FMULX_elt_4(bool Q, bool sz, bool L, bool M, Vec Vm, bool H, Vec Vn, Vec Vd); diff --git a/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp b/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp index c1aa8347..aef2a807 100644 --- a/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp +++ b/src/frontend/A64/translate/impl/simd_vector_x_indexed_element.cpp @@ -86,6 +86,41 @@ bool FPMultiplyByElement(TranslatorVisitor& v, bool Q, bool sz, Imm<1> L, Imm<1> return true; } +using ExtensionFunction = IR::U32 (IREmitter::*)(const IR::UAny&); + +bool DotProduct(TranslatorVisitor& v, bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, + Vec Vn, Vec Vd, ExtensionFunction extension) { + if (size != 0b10) { + return v.ReservedValue(); + } + + const Vec Vm = concatenate(M, Vmlo).ZeroExtend(); + const size_t esize = 8 << size.ZeroExtend(); + const size_t datasize = Q ? 128 : 64; + const size_t elements = datasize / esize; + const size_t index = concatenate(H, L).ZeroExtend(); + + const IR::U128 operand1 = v.V(datasize, Vn); + const IR::U128 operand2 = v.V(128, Vm); + IR::U128 result = v.V(datasize, Vd); + + for (size_t i = 0; i < elements; i++) { + IR::U32 res_element = v.ir.Imm32(0); + + for (size_t j = 0; j < 4; j++) { + const IR::U32 elem1 = (v.ir.*extension)(v.ir.VectorGetElement(8, operand1, 4 * i + j)); + const IR::U32 elem2 = (v.ir.*extension)(v.ir.VectorGetElement(8, operand2, 4 * index + j)); + + res_element = v.ir.Add(res_element, v.ir.Mul(elem1, elem2)); + } + + res_element = v.ir.Add(v.ir.VectorGetElement(32, result, i), res_element); + result = v.ir.VectorSetElement(32, result, i, res_element); + } + + v.V(datasize, Vd, result); + return true; +} } // Anonymous namespace bool TranslatorVisitor::MLA_elt(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) { @@ -112,4 +147,12 @@ bool TranslatorVisitor::FMUL_elt_4(bool Q, bool sz, Imm<1> L, Imm<1> M, Imm<4> V return FPMultiplyByElement(*this, Q, sz, L, M, Vmlo, H, Vn, Vd, ExtraBehavior::None); } +bool TranslatorVisitor::SDOT_elt(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) { + return DotProduct(*this, Q, size, L, M, Vmlo, H, Vn, Vd, &IREmitter::SignExtendToWord); +} + +bool TranslatorVisitor::UDOT_elt(bool Q, Imm<2> size, Imm<1> L, Imm<1> M, Imm<4> Vmlo, Imm<1> H, Vec Vn, Vec Vd) { + return DotProduct(*this, Q, size, L, M, Vmlo, H, Vn, Vd, &IREmitter::ZeroExtendToWord); +} + } // namespace Dynarmic::A64