Use variant instead of creating an object for literals

This commit is contained in:
ReinUsesLisp 2018-10-28 13:44:12 -03:00
parent 8f8115d397
commit 00fc8daf56
12 changed files with 146 additions and 144 deletions

View file

@ -11,6 +11,7 @@
#include <optional>
#include <set>
#include <spirv/unified1/spirv.hpp11>
#include <variant>
#include <vector>
namespace Sirit {
@ -20,7 +21,9 @@ constexpr std::uint32_t GENERATOR_MAGIC_NUMBER = 0;
class Op;
class Operand;
typedef const Op* Ref;
using Literal = std::variant<std::uint32_t, std::uint64_t, std::int32_t,
std::int64_t, float, double>;
using Ref = const Op*;
class Module {
public:
@ -135,7 +138,7 @@ class Module {
Ref ConstantFalse(Ref result_type);
/// Returns a numeric scalar constant.
Ref Constant(Ref result_type, Operand* literal);
Ref Constant(Ref result_type, const Literal& literal);
/// Returns a numeric scalar constant.
Ref ConstantComposite(Ref result_type,
@ -201,18 +204,11 @@ class Module {
/// Add a decoration to target.
Ref Decorate(Ref target, spv::Decoration decoration,
const std::vector<Operand*>& literals = {});
const std::vector<Literal>& literals = {});
Ref MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration,
const std::vector<Operand*>& literals = {});
// Literals
static Operand* Literal(std::uint32_t value);
static Operand* Literal(std::uint64_t value);
static Operand* Literal(std::int32_t value);
static Operand* Literal(std::int64_t value);
static Operand* Literal(float value);
static Operand* Literal(double value);
Ref MemberDecorate(Ref structure_type, Literal member,
spv::Decoration decoration,
const std::vector<Literal>& literals = {});
private:
Ref AddCode(Op* op);

View file

@ -7,7 +7,6 @@ add_library(sirit
stream.h
operand.cpp
operand.h
literal.cpp
literal-number.cpp
literal-number.h
literal-string.cpp

View file

@ -10,21 +10,22 @@
namespace Sirit {
Ref Module::Decorate(Ref target, spv::Decoration decoration,
const std::vector<Operand*>& literals) {
const std::vector<Literal>& literals) {
auto op{new Op(spv::Op::OpDecorate)};
op->Add(target);
AddEnum(op, decoration);
op->Sink(literals);
op->Add(literals);
return AddAnnotation(op);
}
Ref Module::MemberDecorate(Ref structure_type, Operand* member, spv::Decoration decoration,
const std::vector<Operand*>& literals) {
Ref Module::MemberDecorate(Ref structure_type, Literal member,
spv::Decoration decoration,
const std::vector<Literal>& literals) {
auto op{new Op(spv::Op::OpMemberDecorate)};
op->Add(structure_type);
op->Sink(member);
op->Add(member);
AddEnum(op, decoration);
op->Sink(literals);
op->Add(literals);
return AddAnnotation(op);
}

View file

@ -4,9 +4,9 @@
* Lesser General Public License version 2.1 or any later version.
*/
#include <cassert>
#include "sirit/sirit.h"
#include "insts.h"
#include "sirit/sirit.h"
#include <cassert>
namespace Sirit {
@ -18,20 +18,23 @@ Ref Module::ConstantFalse(Ref result_type) {
return AddDeclaration(new Op(spv::Op::OpConstantFalse, bound, result_type));
}
Ref Module::Constant(Ref result_type, Operand* literal) {
Ref Module::Constant(Ref result_type, const Literal& literal) {
auto op{new Op(spv::Op::OpConstant, bound, result_type)};
op->Add(literal);
return AddDeclaration(op);
}
Ref Module::ConstantComposite(Ref result_type, const std::vector<Ref>& constituents) {
Ref Module::ConstantComposite(Ref result_type,
const std::vector<Ref>& constituents) {
auto op{new Op(spv::Op::OpConstantComposite, bound, result_type)};
op->Add(constituents);
return AddDeclaration(op);
}
Ref Module::ConstantSampler(Ref result_type, spv::SamplerAddressingMode addressing_mode,
bool normalized, spv::SamplerFilterMode filter_mode) {
Ref Module::ConstantSampler(Ref result_type,
spv::SamplerAddressingMode addressing_mode,
bool normalized,
spv::SamplerFilterMode filter_mode) {
AddCapability(spv::Capability::LiteralSampler);
AddCapability(spv::Capability::Kernel);
auto op{new Op(spv::Op::OpConstantSampler, bound, result_type)};

View file

@ -4,8 +4,8 @@
* Lesser General Public License version 2.1 or any later version.
*/
#include "sirit/sirit.h"
#include "insts.h"
#include "sirit/sirit.h"
namespace Sirit {

View file

@ -4,12 +4,13 @@
* Lesser General Public License version 2.1 or any later version.
*/
#include "sirit/sirit.h"
#include "insts.h"
#include "sirit/sirit.h"
namespace Sirit {
Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask loop_control,
Ref Module::LoopMerge(Ref merge_block, Ref continue_target,
spv::LoopControlMask loop_control,
const std::vector<Ref>& literals) {
auto op{new Op(spv::Op::OpLoopMerge)};
op->Add(merge_block);
@ -19,16 +20,15 @@ Ref Module::LoopMerge(Ref merge_block, Ref continue_target, spv::LoopControlMask
return AddCode(op);
}
Ref Module::SelectionMerge(Ref merge_block, spv::SelectionControlMask selection_control) {
Ref Module::SelectionMerge(Ref merge_block,
spv::SelectionControlMask selection_control) {
auto op{new Op(spv::Op::OpSelectionMerge)};
op->Add(merge_block);
AddEnum(op, selection_control);
return AddCode(op);
}
Ref Module::Label() {
return AddCode(spv::Op::OpLabel, bound++);
}
Ref Module::Label() { return AddCode(spv::Op::OpLabel, bound++); }
Ref Module::Branch(Ref target_label) {
auto op{new Op(spv::Op::OpBranch)};
@ -37,20 +37,19 @@ Ref Module::Branch(Ref target_label) {
}
Ref Module::BranchConditional(Ref condition, Ref true_label, Ref false_label,
std::uint32_t true_weight, std::uint32_t false_weight) {
std::uint32_t true_weight,
std::uint32_t false_weight) {
auto op{new Op(spv::Op::OpBranchConditional)};
op->Add(condition);
op->Add(true_label);
op->Add(false_label);
if (true_weight != 0 || false_weight != 0) {
op->Add(Literal(true_weight));
op->Add(Literal(false_weight));
op->Add(true_weight);
op->Add(false_weight);
}
return AddCode(op);
}
Ref Module::Return() {
return AddCode(spv::Op::OpReturn);
}
Ref Module::Return() { return AddCode(spv::Op::OpReturn); }
} // namespace Sirit

View file

@ -4,20 +4,19 @@
* Lesser General Public License version 2.1 or any later version.
*/
#include "sirit/sirit.h"
#include "insts.h"
#include "sirit/sirit.h"
namespace Sirit {
Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control, Ref function_type) {
Ref Module::Function(Ref result_type, spv::FunctionControlMask function_control,
Ref function_type) {
auto op{new Op{spv::Op::OpFunction, bound++, result_type}};
op->Add(static_cast<u32>(function_control));
op->Add(function_type);
return AddCode(op);
}
Ref Module::FunctionEnd() {
return AddCode(spv::Op::OpFunctionEnd);
}
Ref Module::FunctionEnd() { return AddCode(spv::Op::OpFunctionEnd); }
} // namespace Sirit

View file

@ -7,8 +7,8 @@
#include <cassert>
#include <optional>
#include "sirit/sirit.h"
#include "insts.h"
#include "sirit/sirit.h"
namespace Sirit {
@ -62,68 +62,68 @@ Ref Module::TypeMatrix(Ref column_type, int column_count) {
return AddDeclaration(op);
}
Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed, bool ms,
int sampled, spv::ImageFormat image_format,
Ref Module::TypeImage(Ref sampled_type, spv::Dim dim, int depth, bool arrayed,
bool ms, int sampled, spv::ImageFormat image_format,
std::optional<spv::AccessQualifier> access_qualifier) {
switch (dim) {
case spv::Dim::Dim1D:
AddCapability(spv::Capability::Sampled1D);
break;
case spv::Dim::Cube:
AddCapability(spv::Capability::Shader);
break;
case spv::Dim::Rect:
AddCapability(spv::Capability::SampledRect);
break;
case spv::Dim::Buffer:
AddCapability(spv::Capability::SampledBuffer);
break;
case spv::Dim::SubpassData:
AddCapability(spv::Capability::InputAttachment);
break;
case spv::Dim::Dim1D:
AddCapability(spv::Capability::Sampled1D);
break;
case spv::Dim::Cube:
AddCapability(spv::Capability::Shader);
break;
case spv::Dim::Rect:
AddCapability(spv::Capability::SampledRect);
break;
case spv::Dim::Buffer:
AddCapability(spv::Capability::SampledBuffer);
break;
case spv::Dim::SubpassData:
AddCapability(spv::Capability::InputAttachment);
break;
}
switch (image_format) {
case spv::ImageFormat::Rgba32f:
case spv::ImageFormat::Rgba16f:
case spv::ImageFormat::R32f:
case spv::ImageFormat::Rgba8:
case spv::ImageFormat::Rgba8Snorm:
case spv::ImageFormat::Rgba32i:
case spv::ImageFormat::Rgba16i:
case spv::ImageFormat::Rgba8i:
case spv::ImageFormat::R32i:
case spv::ImageFormat::Rgba32ui:
case spv::ImageFormat::Rgba16ui:
case spv::ImageFormat::Rgba8ui:
case spv::ImageFormat::R32ui:
AddCapability(spv::Capability::Shader);
break;
case spv::ImageFormat::Rg32f:
case spv::ImageFormat::Rg16f:
case spv::ImageFormat::R11fG11fB10f:
case spv::ImageFormat::R16f:
case spv::ImageFormat::Rgba16:
case spv::ImageFormat::Rgb10A2:
case spv::ImageFormat::Rg16:
case spv::ImageFormat::Rg8:
case spv::ImageFormat::R16:
case spv::ImageFormat::R8:
case spv::ImageFormat::Rgba16Snorm:
case spv::ImageFormat::Rg16Snorm:
case spv::ImageFormat::Rg8Snorm:
case spv::ImageFormat::Rg32i:
case spv::ImageFormat::Rg16i:
case spv::ImageFormat::Rg8i:
case spv::ImageFormat::R16i:
case spv::ImageFormat::R8i:
case spv::ImageFormat::Rgb10a2ui:
case spv::ImageFormat::Rg32ui:
case spv::ImageFormat::Rg16ui:
case spv::ImageFormat::Rg8ui:
case spv::ImageFormat::R16ui:
case spv::ImageFormat::R8ui:
AddCapability(spv::Capability::StorageImageExtendedFormats);
break;
case spv::ImageFormat::Rgba32f:
case spv::ImageFormat::Rgba16f:
case spv::ImageFormat::R32f:
case spv::ImageFormat::Rgba8:
case spv::ImageFormat::Rgba8Snorm:
case spv::ImageFormat::Rgba32i:
case spv::ImageFormat::Rgba16i:
case spv::ImageFormat::Rgba8i:
case spv::ImageFormat::R32i:
case spv::ImageFormat::Rgba32ui:
case spv::ImageFormat::Rgba16ui:
case spv::ImageFormat::Rgba8ui:
case spv::ImageFormat::R32ui:
AddCapability(spv::Capability::Shader);
break;
case spv::ImageFormat::Rg32f:
case spv::ImageFormat::Rg16f:
case spv::ImageFormat::R11fG11fB10f:
case spv::ImageFormat::R16f:
case spv::ImageFormat::Rgba16:
case spv::ImageFormat::Rgb10A2:
case spv::ImageFormat::Rg16:
case spv::ImageFormat::Rg8:
case spv::ImageFormat::R16:
case spv::ImageFormat::R8:
case spv::ImageFormat::Rgba16Snorm:
case spv::ImageFormat::Rg16Snorm:
case spv::ImageFormat::Rg8Snorm:
case spv::ImageFormat::Rg32i:
case spv::ImageFormat::Rg16i:
case spv::ImageFormat::Rg8i:
case spv::ImageFormat::R16i:
case spv::ImageFormat::R8i:
case spv::ImageFormat::Rgb10a2ui:
case spv::ImageFormat::Rg32ui:
case spv::ImageFormat::Rg16ui:
case spv::ImageFormat::Rg8ui:
case spv::ImageFormat::R16ui:
case spv::ImageFormat::R8ui:
AddCapability(spv::Capability::StorageImageExtendedFormats);
break;
}
auto op{new Op(spv::Op::OpTypeImage, bound)};
op->Add(sampled_type);
@ -179,19 +179,19 @@ Ref Module::TypeOpaque(const std::string& name) {
Ref Module::TypePointer(spv::StorageClass storage_class, Ref type) {
switch (storage_class) {
case spv::StorageClass::Uniform:
case spv::StorageClass::Output:
case spv::StorageClass::Private:
case spv::StorageClass::PushConstant:
case spv::StorageClass::StorageBuffer:
AddCapability(spv::Capability::Shader);
break;
case spv::StorageClass::Generic:
AddCapability(spv::Capability::GenericPointer);
break;
case spv::StorageClass::AtomicCounter:
AddCapability(spv::Capability::AtomicStorage);
break;
case spv::StorageClass::Uniform:
case spv::StorageClass::Output:
case spv::StorageClass::Private:
case spv::StorageClass::PushConstant:
case spv::StorageClass::StorageBuffer:
AddCapability(spv::Capability::Shader);
break;
case spv::StorageClass::Generic:
AddCapability(spv::Capability::GenericPointer);
break;
case spv::StorageClass::AtomicCounter:
AddCapability(spv::Capability::AtomicStorage);
break;
}
auto op{new Op(spv::Op::OpTypePointer, bound)};
op->Add(static_cast<u32>(storage_class));

View file

@ -1,26 +0,0 @@
/* This file is part of the sirit project.
* Copyright (c) 2018 ReinUsesLisp
* This software may be used and distributed according to the terms of the GNU
* Lesser General Public License version 2.1 or any later version.
*/
#include "common_types.h"
#include "literal-number.h"
#include "operand.h"
#include "sirit/sirit.h"
namespace Sirit {
#define DEFINE_LITERAL(type) \
Operand* Module::Literal(type value) { \
return LiteralNumber::Create<type>(value); \
}
DEFINE_LITERAL(u32)
DEFINE_LITERAL(u64)
DEFINE_LITERAL(s32)
DEFINE_LITERAL(s64)
DEFINE_LITERAL(f32)
DEFINE_LITERAL(f64)
} // namespace Sirit

View file

@ -71,6 +71,34 @@ void Op::Sink(const std::vector<Operand*>& operands) {
}
}
void Op::Add(const Literal& literal) {
Operand* operand = [&]() {
switch (literal.index()) {
case 0:
return LiteralNumber::Create(std::get<0>(literal));
case 1:
return LiteralNumber::Create(std::get<1>(literal));
case 2:
return LiteralNumber::Create(std::get<2>(literal));
case 3:
return LiteralNumber::Create(std::get<3>(literal));
case 4:
return LiteralNumber::Create(std::get<4>(literal));
case 5:
return LiteralNumber::Create(std::get<5>(literal));
default:
assert(!"invalid literal type");
}
}();
Sink(operand);
}
void Op::Add(const std::vector<Literal>& literals) {
for (const auto& literal : literals) {
Add(literal);
}
}
void Op::Add(const Operand* operand) { operands.push_back(operand); }
void Op::Add(u32 integer) { Sink(LiteralNumber::Create<u32>(integer)); }

View file

@ -31,6 +31,10 @@ class Op : public Operand {
void Sink(const std::vector<Operand*>& operands);
void Add(const Literal& literal);
void Add(const std::vector<Literal>& literals);
void Add(const Operand* operand);
void Add(u32 integer);

View file

@ -20,8 +20,7 @@ static void WriteEnum(Stream& stream, spv::Op opcode, T value) {
op.Write(stream);
}
template <typename T>
static void WriteSet(Stream& stream, const T& set) {
template <typename T> static void WriteSet(Stream& stream, const T& set) {
for (const auto& item : set) {
item->Write(stream);
}