diff --git a/src/backend_x64/emit_x64.cpp b/src/backend_x64/emit_x64.cpp index 3f9ee072..d68de55c 100644 --- a/src/backend_x64/emit_x64.cpp +++ b/src/backend_x64/emit_x64.cpp @@ -1761,25 +1761,39 @@ void EmitX64::EmitPackedHalvingAddU8(RegAlloc& reg_alloc, IR::Block&, IR::Inst* void EmitX64::EmitPackedHalvingAddU16(RegAlloc& reg_alloc, IR::Block&, IR::Inst* inst) { auto args = reg_alloc.GetArgumentInfo(inst); - Xbyak::Reg32 reg_a = reg_alloc.UseScratchGpr(args[0]).cvt32(); - Xbyak::Reg32 reg_b = reg_alloc.UseGpr(args[1]).cvt32(); - Xbyak::Reg32 xor_a_b = reg_alloc.ScratchGpr().cvt32(); - Xbyak::Reg32 and_a_b = reg_a; - Xbyak::Reg32 result = reg_a; + if (args[0].IsInXmm() || args[1].IsInXmm()) { + Xbyak::Xmm xmm_a = reg_alloc.UseScratchXmm(args[0]); + Xbyak::Xmm xmm_b = reg_alloc.UseXmm(args[1]); + Xbyak::Xmm tmp = reg_alloc.ScratchXmm(); - // This relies on the equality x+y == ((x&y) << 1) + (x^y). - // Note that x^y always contains the LSB of the result. - // Since we want to calculate (x+y)/2, we can instead calculate (x&y) + ((x^y)>>1). - // We mask by 0x7FFF to remove the LSB so that it doesn't leak into the field below. + code->movdqa(tmp, xmm_a); + code->pand(xmm_a, xmm_b); + code->pxor(tmp, xmm_b); + code->psrlw(tmp, 1); + code->paddw(xmm_a, tmp); - code->mov(xor_a_b, reg_a); - code->and_(and_a_b, reg_b); - code->xor_(xor_a_b, reg_b); - code->shr(xor_a_b, 1); - code->and_(xor_a_b, 0x7FFF7FFF); - code->add(result, xor_a_b); + reg_alloc.DefineValue(inst, xmm_a); + } else { + Xbyak::Reg32 reg_a = reg_alloc.UseScratchGpr(args[0]).cvt32(); + Xbyak::Reg32 reg_b = reg_alloc.UseGpr(args[1]).cvt32(); + Xbyak::Reg32 xor_a_b = reg_alloc.ScratchGpr().cvt32(); + Xbyak::Reg32 and_a_b = reg_a; + Xbyak::Reg32 result = reg_a; - reg_alloc.DefineValue(inst, result); + // This relies on the equality x+y == ((x&y) << 1) + (x^y). + // Note that x^y always contains the LSB of the result. + // Since we want to calculate (x+y)/2, we can instead calculate (x&y) + ((x^y)>>1). 
+ // We mask by 0x7FFF to remove the LSB so that it doesn't leak into the field below. + + code->mov(xor_a_b, reg_a); + code->and_(and_a_b, reg_b); + code->xor_(xor_a_b, reg_b); + code->shr(xor_a_b, 1); + code->and_(xor_a_b, 0x7FFF7FFF); + code->add(result, xor_a_b); + + reg_alloc.DefineValue(inst, result); + } } void EmitX64::EmitPackedHalvingAddS8(RegAlloc& reg_alloc, IR::Block&, IR::Inst* inst) { @@ -1814,30 +1828,22 @@ void EmitX64::EmitPackedHalvingAddS8(RegAlloc& reg_alloc, IR::Block&, IR::Inst* void EmitX64::EmitPackedHalvingAddS16(RegAlloc& reg_alloc, IR::Block&, IR::Inst* inst) { auto args = reg_alloc.GetArgumentInfo(inst); - Xbyak::Reg32 reg_a = reg_alloc.UseScratchGpr(args[0]).cvt32(); - Xbyak::Reg32 reg_b = reg_alloc.UseGpr(args[1]).cvt32(); - Xbyak::Reg32 xor_a_b = reg_alloc.ScratchGpr().cvt32(); - Xbyak::Reg32 and_a_b = reg_a; - Xbyak::Reg32 result = reg_a; - Xbyak::Reg32 carry = reg_alloc.ScratchGpr().cvt32(); + Xbyak::Xmm xmm_a = reg_alloc.UseScratchXmm(args[0]); + Xbyak::Xmm xmm_b = reg_alloc.UseXmm(args[1]); + Xbyak::Xmm tmp = reg_alloc.ScratchXmm(); // This relies on the equality x+y == ((x&y) << 1) + (x^y). // Note that x^y always contains the LSB of the result. - // Since we want to calculate (x+y)/2, we can instead calculate (x&y) + ((x^y)>>1). - // We mask by 0x7FFF to remove the LSB so that it doesn't leak into the field below. - // carry propagates the sign bit from (x^y)>>1 upwards by one. + // Since we want to calculate (x+y)/2, we can instead calculate (x&y) + ((x^y)>>>1), + // where >>> denotes the per-word arithmetic shift right (psraw); its sign extension makes the result signed. 
- code->mov(xor_a_b, reg_a); - code->and_(and_a_b, reg_b); - code->xor_(xor_a_b, reg_b); - code->mov(carry, xor_a_b); - code->and_(carry, 0x80008000); - code->shr(xor_a_b, 1); - code->and_(xor_a_b, 0x7FFF7FFF); - code->add(result, xor_a_b); - code->xor_(result, carry); + code->movdqa(tmp, xmm_a); + code->pand(xmm_a, xmm_b); + code->pxor(tmp, xmm_b); + code->psraw(tmp, 1); + code->paddw(xmm_a, tmp); - reg_alloc.DefineValue(inst, result); + reg_alloc.DefineValue(inst, xmm_a); } void EmitX64::EmitPackedHalvingSubU8(RegAlloc& reg_alloc, IR::Block&, IR::Inst* inst) {