constant_propagation_pass: Fold shifts

This commit is contained in:
MerryMage 2020-04-21 23:36:55 +01:00
parent 7242388577
commit df1a0eecaf

View file

@ -8,6 +8,7 @@
#include "common/assert.h"
#include "common/bit_util.h"
#include "common/safe_ops.h"
#include "common/common_types.h"
#include "frontend/ir/basic_block.h"
#include "frontend/ir/ir_emitter.h"
@ -254,7 +255,7 @@ void FoldOR(IR::Inst& inst, bool is_32_bit) {
}
}
void FoldShifts(IR::Inst& inst) {
bool FoldShifts(IR::Inst& inst) {
IR::Inst* carry_inst = inst.GetAssociatedPseudoOperation(Op::GetCarryFromOp);
// The 32-bit variants can contain 3 arguments, while the
@ -264,14 +265,19 @@ void FoldShifts(IR::Inst& inst) {
}
const auto shift_amount = inst.GetArg(1);
if (!shift_amount.IsZero()) {
return;
if (shift_amount.IsZero()) {
if (carry_inst) {
carry_inst->ReplaceUsesWith(inst.GetArg(2));
}
inst.ReplaceUsesWith(inst.GetArg(0));
return false;
}
if (carry_inst) {
carry_inst->ReplaceUsesWith(inst.GetArg(2));
if (!inst.AreAllArgsImmediates() || carry_inst) {
return false;
}
inst.ReplaceUsesWith(inst.GetArg(0));
return true;
}
void FoldSignExtendXToWord(IR::Inst& inst) {
@ -332,14 +338,84 @@ void ConstantPropagation(IR::Block& block) {
FoldMostSignificantBit(inst);
break;
case Op::LogicalShiftLeft32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Safe::LogicalShiftLeft<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftLeft64:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Safe::LogicalShiftLeft<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftRight32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Safe::LogicalShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftRight64:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Safe::LogicalShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::ArithmeticShiftRight32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Safe::ArithmeticShiftRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::ArithmeticShiftRight64:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Safe::ArithmeticShiftRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::RotateRight32:
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU8()));
}
break;
case Op::RotateRight64:
FoldShifts(inst);
if (FoldShifts(inst)) {
ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU8()));
}
break;
case Op::LogicalShiftLeftMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() << (inst.GetArg(1).GetU32() & 0x1f));
}
break;
case Op::LogicalShiftLeftMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() << (inst.GetArg(1).GetU64() & 0x3f));
}
break;
case Op::LogicalShiftRightMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, inst.GetArg(0).GetU32() >> (inst.GetArg(1).GetU32() & 0x1f));
}
break;
case Op::LogicalShiftRightMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, inst.GetArg(0).GetU64() >> (inst.GetArg(1).GetU64() & 0x3f));
}
break;
case Op::ArithmeticShiftRightMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, static_cast<s32>(inst.GetArg(0).GetU32()) >> (inst.GetArg(1).GetU32() & 0x1f));
}
break;
case Op::ArithmeticShiftRightMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, static_cast<s64>(inst.GetArg(0).GetU64()) >> (inst.GetArg(1).GetU64() & 0x3f));
}
break;
case Op::RotateRightMasked32:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, true, Common::RotateRight<u32>(inst.GetArg(0).GetU32(), inst.GetArg(1).GetU32()));
}
break;
case Op::RotateRightMasked64:
if (inst.AreAllArgsImmediates()) {
ReplaceUsesWith(inst, false, Common::RotateRight<u64>(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU64()));
}
break;
case Op::Mul32:
case Op::Mul64: