diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index eee5cff..35b8fd8 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -933,7 +933,8 @@ private: std::vector> execution_modes; std::vector> debug; std::vector> annotations; - std::vector> declarations; + std::unordered_set> declarations; + std::vector sorted_declarations; std::vector global_variables; diff --git a/src/literal_number.cpp b/src/literal_number.cpp index 9a1b4c7..036d87d 100644 --- a/src/literal_number.cpp +++ b/src/literal_number.cpp @@ -35,4 +35,8 @@ bool LiteralNumber::operator==(const Operand& other) const { return false; } +std::size_t LiteralNumber::Hash() const { + return static_cast(raw) ^ Operand::Hash(); +} + } // namespace Sirit diff --git a/src/literal_number.h b/src/literal_number.h index 5809ca5..abc2741 100644 --- a/src/literal_number.h +++ b/src/literal_number.h @@ -6,6 +6,7 @@ #pragma once +#include #include #include #include "operand.h" @@ -23,6 +24,8 @@ public: bool operator==(const Operand& other) const override; + std::size_t Hash() const override; + template static LiteralNumber* Create(T value) { static_assert(sizeof(T) == 4 || sizeof(T) == 8); diff --git a/src/literal_string.cpp b/src/literal_string.cpp index 23fa111..d340590 100644 --- a/src/literal_string.cpp +++ b/src/literal_string.cpp @@ -4,6 +4,7 @@ * Lesser General Public License version 3 or any later version. */ +#include #include #include "common_types.h" #include "literal_string.h" @@ -36,4 +37,8 @@ bool LiteralString::operator==(const Operand& other) const { return false; } +std::size_t LiteralString::Hash() const { + return Operand::Hash() ^ std::hash{}(string); +} + } // namespace Sirit diff --git a/src/literal_string.h b/src/literal_string.h index 92a8a28..048d58a 100644 --- a/src/literal_string.h +++ b/src/literal_string.h @@ -6,6 +6,7 @@ #pragma once +#include #include #include "operand.h" #include "stream.h" @@ -22,8 +23,10 @@ public: bool operator==(const Operand& other) const override; + std::size_t Hash() const override; + private: - const std::string string; + std::string string; }; } // namespace Sirit diff --git a/src/op.cpp b/src/op.cpp index 097a67d..8c181c5 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -5,6 +5,7 @@ */ #include +#include #include "common_types.h" #include "literal_number.h" @@ -47,6 +48,21 @@ bool Op::operator==(const Operand& other) const { return false; } +std::size_t Op::Hash() const { + std::size_t hash = Operand::Hash(); + hash ^= static_cast(opcode) << 20; + if (result_type) { + hash ^= result_type->Hash() << 16; + } + hash ^= static_cast(id.value_or(0)) << 8; + std::size_t wrap = 32; + for (const auto operand : operands) { + wrap = (wrap + 7) % (sizeof(std::size_t) * CHAR_BIT); + hash ^= operand->Hash() << wrap; + } + return hash; +} + void Op::Write(Stream& stream) const { stream.Write(static_cast(opcode)); stream.Write(WordCount()); diff --git a/src/op.h b/src/op.h index 4b63d66..5b010f0 100644 --- a/src/op.h +++ b/src/op.h @@ -6,6 +6,7 @@ #pragma once +#include #include #include "common_types.h" #include "operand.h" @@ -24,6 +25,8 @@ public: bool operator==(const Operand& other) const override; + std::size_t Hash() const override; + void Write(Stream& stream) const; void Sink(Operand* operand); diff --git a/src/operand.cpp b/src/operand.cpp index 954a45d..f45ca3a 100644 --- a/src/operand.cpp +++ b/src/operand.cpp @@ -30,6 +30,10 @@ bool Operand::operator!=(const Operand& other) const { return !(*this == other); } +std::size_t Operand::Hash() const { + return static_cast(operand_type) << 30; +} + OperandType Operand::GetType() const { return operand_type; } diff --git a/src/operand.h b/src/operand.h index 70ba810..f27bc94 100644 --- a/src/operand.h +++ b/src/operand.h @@ -6,6 +6,8 @@ #pragma once +#include +#include "common_types.h" #include "stream.h" namespace Sirit { @@ -23,6 +25,8 @@ public: virtual bool operator==(const Operand& other) const; bool operator!=(const Operand& other) const; + virtual std::size_t Hash() const; + OperandType GetType() const; protected: @@ -30,3 +34,14 @@ protected: }; } // namespace Sirit + +namespace std { + +template <> +struct hash { + std::size_t operator()(const Sirit::Operand& operand) const noexcept { + return operand.Hash(); + } +}; + +} // namespace std diff --git a/src/sirit.cpp b/src/sirit.cpp index dcdfe23..c955c8c 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -59,7 +59,7 @@ std::vector Module::Assemble() const { WriteSet(stream, execution_modes); WriteSet(stream, debug); WriteSet(stream, annotations); - WriteSet(stream, declarations); + WriteSet(stream, sorted_declarations); WriteSet(stream, global_variables); WriteSet(stream, code); @@ -74,13 +74,14 @@ void Module::AddCapability(spv::Capability capability) { capabilities.insert(capability); } -void Module::SetMemoryModel(spv::AddressingModel addressing_model_, spv::MemoryModel memory_model_) { +void Module::SetMemoryModel(spv::AddressingModel addressing_model_, + spv::MemoryModel memory_model_) { this->addressing_model = addressing_model_; this->memory_model = memory_model_; } -void Module::AddEntryPoint(spv::ExecutionModel execution_model, Id entry_point, - std::string name, const std::vector& interfaces) { +void Module::AddEntryPoint(spv::ExecutionModel execution_model, Id entry_point, std::string name, + const std::vector& interfaces) { auto op{std::make_unique(spv::Op::OpEntryPoint)}; op->Add(static_cast(execution_model)); op->Add(entry_point); @@ -121,14 +122,12 @@ Id Module::AddCode(spv::Op opcode, std::optional id) { } Id Module::AddDeclaration(std::unique_ptr op) { - const auto& found{std::find_if(declarations.begin(), declarations.end(), - [&op](const auto& other) { return *other == *op; })}; - if (found != declarations.end()) { - return found->get(); + const auto [it, is_inserted] = declarations.emplace(std::move(op)); + const Id id = it->get(); + if (is_inserted) { + sorted_declarations.push_back(id); + ++bound; } - const auto id = op.get(); - declarations.push_back(std::move(op)); - bound++; return id; }