diff --git a/src/ir_opt/constant_propagation_pass.cpp b/src/ir_opt/constant_propagation_pass.cpp index 21ab9146..0ff67bfb 100644 --- a/src/ir_opt/constant_propagation_pass.cpp +++ b/src/ir_opt/constant_propagation_pass.cpp @@ -73,6 +73,36 @@ bool FoldCommutative(IR::Inst& inst, bool is_32_bit, ImmFn imm_fn) { return true; } +void FoldAdd(IR::Inst& inst, bool is_32_bit) { + const auto lhs = inst.GetArg(0); + const auto rhs = inst.GetArg(1); + const auto carry = inst.GetArg(2); + + if (lhs.IsImmediate() && !rhs.IsImmediate()) { + // Normalize + inst.SetArg(0, rhs); + inst.SetArg(1, lhs); + FoldAdd(inst, is_32_bit); + return; + } + + if (!lhs.IsImmediate() && rhs.IsImmediate()) { + const IR::Inst* lhs_inst = lhs.GetInstRecursive(); + if (lhs_inst->GetOpcode() == inst.GetOpcode() && lhs_inst->GetArg(1).IsImmediate() && lhs_inst->GetArg(2).IsImmediate()) { + const u64 combined = rhs.GetImmediateAsU64() + lhs_inst->GetArg(1).GetImmediateAsU64() + lhs_inst->GetArg(2).GetU1(); + inst.SetArg(0, lhs_inst->GetArg(0)); + inst.SetArg(1, Value(is_32_bit, combined)); + return; + } + } + + if (inst.AreAllArgsImmediates() && !inst.HasAssociatedPseudoOperation()) { + const u64 result = lhs.GetImmediateAsU64() + rhs.GetImmediateAsU64() + carry.GetU1(); + ReplaceUsesWith(inst, is_32_bit, result); + return; + } +} + // Folds AND operations based on the following: // // 1. imm_x & imm_y -> result @@ -297,6 +327,19 @@ void FoldSignExtendXToLong(IR::Inst& inst) { inst.ReplaceUsesWith(IR::Value{static_cast(value)}); } +void FoldSub(IR::Inst& inst, bool is_32_bit) { + if (!inst.AreAllArgsImmediates() || inst.HasAssociatedPseudoOperation()) { + return; + } + + const auto lhs = inst.GetArg(0); + const auto rhs = inst.GetArg(1); + const auto carry = inst.GetArg(2); + + const u64 result = lhs.GetImmediateAsU64() + (~rhs.GetImmediateAsU64()) + carry.GetU1(); + ReplaceUsesWith(inst, is_32_bit, result); +} + void FoldZeroExtendXToWord(IR::Inst& inst) { if (!inst.AreAllArgsImmediates()) { return; @@ -426,6 +469,14 @@ void ConstantPropagation(IR::Block& block) { ReplaceUsesWith(inst, false, Common::RotateRight(inst.GetArg(0).GetU64(), inst.GetArg(1).GetU64())); } break; + case Op::Add32: + case Op::Add64: + FoldAdd(inst, opcode == Op::Add32); + break; + case Op::Sub32: + case Op::Sub64: + FoldSub(inst, opcode == Op::Sub32); + break; case Op::Mul32: case Op::Mul64: FoldMultiply(inst, opcode == Op::Mul32);