diff --git a/src/frontend/A64/decoder/a64.inc b/src/frontend/A64/decoder/a64.inc
index 844fc97c..b513834c 100644
--- a/src/frontend/A64/decoder/a64.inc
+++ b/src/frontend/A64/decoder/a64.inc
@@ -859,7 +859,7 @@ INST(SM3TT2B, "SM3TT2B", "11001
 
 // Data Processing - FP and SIMD - SHA512 three register
 INST(SHA512H, "SHA512H", "11001110011mmmmm100000nnnnnddddd")
-//INST(SHA512H2, "SHA512H2", "11001110011mmmmm100001nnnnnddddd")
+INST(SHA512H2, "SHA512H2", "11001110011mmmmm100001nnnnnddddd")
 INST(SHA512SU1, "SHA512SU1", "11001110011mmmmm100010nnnnnddddd")
 INST(RAX1, "RAX1", "11001110011mmmmm100011nnnnnddddd")
 INST(SM3PARTW1, "SM3PARTW1", "11001110011mmmmm110000nnnnnddddd")
diff --git a/src/frontend/A64/translate/impl/simd_sha512.cpp b/src/frontend/A64/translate/impl/simd_sha512.cpp
index 9189005a..955092ee 100644
--- a/src/frontend/A64/translate/impl/simd_sha512.cpp
+++ b/src/frontend/A64/translate/impl/simd_sha512.cpp
@@ -23,6 +23,94 @@ IR::U64 MakeMNSig(IREmitter& ir, IR::U64 data, u8 first_rot_amount, u8 second_ro
 
     return ir.Eor(tmp1, ir.Eor(tmp2, tmp3));
 }
+
+// Selects which of the two SHA-512 hash-update instructions SHA512Hash emits IR for.
+enum class SHA512HashPart {
+    Part1, // SHA512H
+    Part2, // SHA512H2
+};
+
+// Emits the IR shared by SHA512H (Part1) and SHA512H2 (Part2).
+// Reads Q registers Vn, Vm and Vd and returns the 128-bit result; the caller
+// is responsible for writing it back with SetQ.
+IR::U128 SHA512Hash(IREmitter& ir, Vec Vm, Vec Vn, Vec Vd, SHA512HashPart part) {
+    const IR::U128 x = ir.GetQ(Vn);
+    const IR::U128 y = ir.GetQ(Vm);
+    const IR::U128 w = ir.GetQ(Vd);
+
+    const IR::U64 lower_x = ir.VectorGetElement(64, x, 0);
+    const IR::U64 upper_x = ir.VectorGetElement(64, x, 1);
+
+    const IR::U64 lower_y = ir.VectorGetElement(64, y, 0);
+    const IR::U64 upper_y = ir.VectorGetElement(64, y, 1);
+
+    // Rotation amounts: 14/18/41 for Part1, 28/34/39 for Part2.
+    const auto make_sigma = [&](IR::U64 data) {
+        if (part == SHA512HashPart::Part1) {
+            return MakeMNSig(ir, data, 14, 18, 41);
+        }
+        return MakeMNSig(ir, data, 28, 34, 39);
+    };
+
+    // Part1: (a & b) ^ (~a & c)
+    // Part2: (a & b) ^ (a & c) ^ (b & c)
+    const auto make_partial_half = [&](IR::U64 a, IR::U64 b, IR::U64 c) {
+        const IR::U64 tmp1 = ir.And(a, b);
+
+        if (part == SHA512HashPart::Part1) {
+            const IR::U64 tmp2 = ir.And(ir.Not(a), c);
+            return ir.Eor(tmp1, tmp2);
+        }
+
+        const IR::U64 tmp2 = ir.And(a, c);
+        // Use the arguments rather than the captured halves of y, so the result
+        // always matches what the caller passed in.
+        const IR::U64 tmp3 = ir.And(b, c);
+        return ir.Eor(tmp1, ir.Eor(tmp2, tmp3));
+    };
+
+    // Upper 64 bits of the result; also an input to the lower 64 bits.
+    const IR::U64 Vtmp = [&] {
+        const IR::U64 partial = [&] {
+            if (part == SHA512HashPart::Part1) {
+                return make_partial_half(upper_y, lower_x, upper_x);
+            }
+            return make_partial_half(lower_x, upper_y, lower_y);
+        }();
+        const IR::U64 upper_w = ir.VectorGetElement(64, w, 1);
+        const IR::U64 sig = [&] {
+            if (part == SHA512HashPart::Part1) {
+                return make_sigma(upper_y);
+            }
+            return make_sigma(lower_y);
+        }();
+
+        return ir.Add(partial, ir.Add(sig, upper_w));
+    }();
+
+    const IR::U128 low_result = [&] {
+        // Part1 feeds Vtmp + lower_y into the second step; Part2 feeds Vtmp unchanged.
+        const IR::U64 tmp = [&]() -> IR::U64 {
+            if (part == SHA512HashPart::Part1) {
+                return ir.Add(Vtmp, lower_y);
+            }
+            return Vtmp;
+        }();
+        const IR::U64 partial = [&] {
+            if (part == SHA512HashPart::Part1) {
+                return make_partial_half(tmp, upper_y, lower_x);
+            }
+            return make_partial_half(Vtmp, lower_y, upper_y);
+        }();
+        const IR::U64 sig = make_sigma(tmp);
+        const IR::U64 lower_w = ir.VectorGetElement(64, w, 0);
+
+        return ir.ZeroExtendToQuad(ir.Add(partial, ir.Add(sig, lower_w)));
+    }();
+
+    // Place Vtmp in the upper element above the zero-extended lower half.
+    return ir.VectorSetElement(64, low_result, 1, Vtmp);
+}
 } // Anonymous namespace
 
 bool TranslatorVisitor::SHA512SU0(Vec Vn, Vec Vd) {
@@ -69,45 +157,13 @@ bool TranslatorVisitor::SHA512SU1(Vec Vm, Vec Vn, Vec Vd) {
 }
 
 bool TranslatorVisitor::SHA512H(Vec Vm, Vec Vn, Vec Vd) {
-    const IR::U128 x = ir.GetQ(Vn);
-    const IR::U128 y = ir.GetQ(Vm);
-    const IR::U128 w = ir.GetQ(Vd);
-
-    const IR::U64 lower_x = ir.VectorGetElement(64, x, 0);
-    const IR::U64 upper_x = ir.VectorGetElement(64, x, 1);
-
-    const IR::U64 lower_y = ir.VectorGetElement(64, y, 0);
-    const IR::U64 upper_y = ir.VectorGetElement(64, y, 1);
-
-    const auto make_msigma = [&](IR::U64 data) {
-        return MakeMNSig(ir, data, 14, 18, 41);
-    };
-
-    const auto make_partial_half = [](IREmitter& ir, IR::U64 a, IR::U64 b, IR::U64 c) {
-        const IR::U64 tmp1 = ir.And(a, b);
-        const IR::U64 tmp2 = ir.And(ir.Not(a), c);
-        return ir.Eor(tmp1, tmp2);
-    };
-
-    const IR::U64 Vtmp = [&] {
-        const IR::U64 upper_w = ir.VectorGetElement(64, w, 1);
-        const IR::U64 partial = make_partial_half(ir, upper_y, lower_x, upper_x);
-        const IR::U64 sig = make_msigma(upper_y);
-
-        return ir.Add(partial, ir.Add(sig, upper_w));
-    }();
-    const IR::U64 tmp = ir.Add(Vtmp, lower_y);
-
-    const IR::U128 low_result = [&] {
-        const IR::U64 lower_w = ir.VectorGetElement(64, w, 0);
-        const IR::U64 partial = make_partial_half(ir, tmp, upper_y, lower_x);
-        const IR::U64 sig = make_msigma(tmp);
-
-        return ir.ZeroExtendToQuad(ir.Add(partial, ir.Add(sig, lower_w)));
-    }();
-
-    const IR::U128 result = ir.VectorSetElement(64, low_result, 1, Vtmp);
+    const IR::U128 result = SHA512Hash(ir, Vm, Vn, Vd, SHA512HashPart::Part1);
+    ir.SetQ(Vd, result);
+    return true;
+}
 
+bool TranslatorVisitor::SHA512H2(Vec Vm, Vec Vn, Vec Vd) {
+    const IR::U128 result = SHA512Hash(ir, Vm, Vn, Vd, SHA512HashPart::Part2);
     ir.SetQ(Vd, result);
     return true;
 }