diff --git a/include/sirit/sirit.h b/include/sirit/sirit.h index fadf0a9..ccc06e4 100644 --- a/include/sirit/sirit.h +++ b/include/sirit/sirit.h @@ -37,7 +37,7 @@ public: * externally. * @return A stream of bytes representing a SPIR-V module. */ - std::vector Assemble() const; + std::vector Assemble() const; /// Adds a SPIR-V extension. void AddExtension(std::string extension_name); diff --git a/src/literal_string.cpp b/src/literal_string.cpp index e7f85e6..7a820d3 100644 --- a/src/literal_string.cpp +++ b/src/literal_string.cpp @@ -17,12 +17,7 @@ LiteralString::LiteralString(std::string string) : string{std::move(string)} { LiteralString::~LiteralString() = default; void LiteralString::Fetch(Stream& stream) const { - for (std::size_t i = 0; i < string.size(); i++) { - stream.Write(static_cast(string[i])); - } - for (std::size_t i = 0; i < 4 - (string.size() % 4); i++) { - stream.Write(static_cast(0)); - } + stream.Write(string); } u16 LiteralString::GetWordCount() const { diff --git a/src/op.cpp b/src/op.cpp index 83895ae..974f0c3 100644 --- a/src/op.cpp +++ b/src/op.cpp @@ -49,8 +49,7 @@ bool Op::operator==(const Operand& other) const { } void Op::Write(Stream& stream) const { - stream.Write(static_cast(opcode)); - stream.Write(WordCount()); + stream.Write(static_cast(opcode), WordCount()); if (result_type) { result_type->Fetch(stream); diff --git a/src/sirit.cpp b/src/sirit.cpp index c88111c..6e28bdf 100644 --- a/src/sirit.cpp +++ b/src/sirit.cpp @@ -24,8 +24,8 @@ Module::Module(u32 version) : version(version) {} Module::~Module() = default; -std::vector Module::Assemble() const { - std::vector bytes; +std::vector Module::Assemble() const { + std::vector bytes; Stream stream{bytes}; stream.Write(spv::MagicNumber); diff --git a/src/stream.cpp b/src/stream.cpp index bde6534..2282364 100644 --- a/src/stream.cpp +++ b/src/stream.cpp @@ -8,36 +8,42 @@ namespace Sirit { -Stream::Stream(std::vector& bytes) : bytes(bytes) {} +Stream::Stream(std::vector& words) : words(words) {} Stream::~Stream() = default; void Stream::Write(std::string_view string) { - bytes.insert(bytes.end(), string.begin(), string.end()); + constexpr std::size_t word_size = 4; + const auto size = string.size(); + auto read = [string, size](std::size_t offset) { return offset < size ? string[offset] : 0; }; - const auto size{string.size()}; - for (std::size_t i = 0; i < 4 - size % 4; i++) { - Write(static_cast(0)); + words.reserve(words.size() + size / word_size + 1); + for (std::size_t i = 0; i < size; i += word_size) { + Write(read(i), read(i + 1), read(i + 2), read(i + 3)); + } + if (size % word_size == 0) { + Write(u32(0)); } } void Stream::Write(u64 value) { - const auto* const mem = reinterpret_cast(&value); - bytes.insert(bytes.end(), mem, mem + sizeof(u64)); + const u32 dword[] = {static_cast(value), static_cast(value >> 32)}; + words.insert(std::begin(words), std::cbegin(dword), std::cend(dword)); } void Stream::Write(u32 value) { - const auto* const mem = reinterpret_cast(&value); - bytes.insert(bytes.end(), mem, mem + sizeof(u32)); + words.push_back(value); } -void Stream::Write(u16 value) { - const auto* const mem{reinterpret_cast(&value)}; - bytes.insert(bytes.end(), mem, mem + sizeof(u16)); +void Stream::Write(u16 first, u16 second) { + const u32 word = static_cast(first) | static_cast(second) << 16; + Write(word); } -void Stream::Write(u8 value) { - bytes.push_back(value); +void Stream::Write(u8 first, u8 second, u8 third, u8 fourth) { + const u32 word = static_cast(first) | static_cast(second) << 8 | + static_cast(third) << 16 | static_cast(fourth) << 24; + Write(word); } } // namespace Sirit diff --git a/src/stream.h b/src/stream.h index bb6c293..b7ec872 100644 --- a/src/stream.h +++ b/src/stream.h @@ -14,7 +14,7 @@ namespace Sirit { class Stream { public: - explicit Stream(std::vector& bytes); + explicit Stream(std::vector& words); ~Stream(); void Write(std::string_view string); @@ -23,12 +23,12 @@ public: void Write(u32 value); - void Write(u16 value); + void Write(u16 first, u16 second); - void Write(u8 value); + void Write(u8 first, u8 second, u8 third, u8 fourth); private: - std::vector& bytes; + std::vector& words; }; } // namespace Sirit diff --git a/tests/main.cpp b/tests/main.cpp index bff8359..e004e59 100644 --- a/tests/main.cpp +++ b/tests/main.cpp @@ -79,10 +79,10 @@ int main(int argc, char** argv) { MyModule module; module.Generate(); - std::vector code{module.Assemble()}; + std::vector code{module.Assemble()}; FILE* file = fopen("sirit.spv", "wb"); - fwrite(code.data(), 1, code.size(), file); + fwrite(code.data(), sizeof(std::uint32_t), code.size(), file); fclose(file); return 0;