forked from suyu/suyu
shader: Add subgroup masks
This commit is contained in:
parent
fc93bc2abd
commit
da6cf2632c
10 changed files with 169 additions and 45 deletions
|
@ -390,8 +390,16 @@ void EmitContext::DefineInputs(const Info& info) {
|
|||
if (info.uses_local_invocation_id) {
|
||||
local_invocation_id = DefineInput(*this, U32[3], spv::BuiltIn::LocalInvocationId);
|
||||
}
|
||||
if (info.uses_subgroup_mask) {
|
||||
subgroup_mask_eq = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupEqMaskKHR);
|
||||
subgroup_mask_lt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLtMaskKHR);
|
||||
subgroup_mask_le = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupLeMaskKHR);
|
||||
subgroup_mask_gt = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGtMaskKHR);
|
||||
subgroup_mask_ge = DefineInput(*this, U32[4], spv::BuiltIn::SubgroupGeMaskKHR);
|
||||
}
|
||||
if (info.uses_subgroup_invocation_id ||
|
||||
(profile.warp_size_potentially_larger_than_guest && info.uses_subgroup_vote)) {
|
||||
(profile.warp_size_potentially_larger_than_guest &&
|
||||
(info.uses_subgroup_vote || info.uses_subgroup_mask))) {
|
||||
subgroup_local_invocation_id =
|
||||
DefineInput(*this, U32[1], spv::BuiltIn::SubgroupLocalInvocationId);
|
||||
}
|
||||
|
|
|
@ -97,6 +97,11 @@ public:
|
|||
Id workgroup_id{};
|
||||
Id local_invocation_id{};
|
||||
Id subgroup_local_invocation_id{};
|
||||
Id subgroup_mask_eq{};
|
||||
Id subgroup_mask_lt{};
|
||||
Id subgroup_mask_le{};
|
||||
Id subgroup_mask_gt{};
|
||||
Id subgroup_mask_ge{};
|
||||
Id instance_id{};
|
||||
Id instance_index{};
|
||||
Id base_instance{};
|
||||
|
|
|
@ -401,6 +401,11 @@ Id EmitVoteAll(EmitContext& ctx, Id pred);
|
|||
Id EmitVoteAny(EmitContext& ctx, Id pred);
|
||||
Id EmitVoteEqual(EmitContext& ctx, Id pred);
|
||||
Id EmitSubgroupBallot(EmitContext& ctx, Id pred);
|
||||
Id EmitSubgroupEqMask(EmitContext& ctx);
|
||||
Id EmitSubgroupLtMask(EmitContext& ctx);
|
||||
Id EmitSubgroupLeMask(EmitContext& ctx);
|
||||
Id EmitSubgroupGtMask(EmitContext& ctx);
|
||||
Id EmitSubgroupGeMask(EmitContext& ctx);
|
||||
Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
|
||||
Id segmentation_mask);
|
||||
Id EmitShuffleUp(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
|
||||
|
|
|
@ -6,10 +6,18 @@
|
|||
|
||||
namespace Shader::Backend::SPIRV {
|
||||
namespace {
|
||||
Id LargeWarpBallot(EmitContext& ctx, Id ballot) {
|
||||
Id WarpExtract(EmitContext& ctx, Id value) {
|
||||
const Id shift{ctx.Constant(ctx.U32[1], 5)};
|
||||
const Id local_index{ctx.OpLoad(ctx.U32[1], ctx.subgroup_local_invocation_id)};
|
||||
return ctx.OpVectorExtractDynamic(ctx.U32[1], ballot, local_index);
|
||||
return ctx.OpVectorExtractDynamic(ctx.U32[1], value, local_index);
|
||||
}
|
||||
|
||||
Id LoadMask(EmitContext& ctx, Id mask) {
|
||||
const Id value{ctx.OpLoad(ctx.U32[4], mask)};
|
||||
if (!ctx.profile.warp_size_potentially_larger_than_guest) {
|
||||
return ctx.OpCompositeExtract(ctx.U32[1], value, 0U);
|
||||
}
|
||||
return WarpExtract(ctx, value);
|
||||
}
|
||||
|
||||
void SetInBoundsFlag(IR::Inst* inst, Id result) {
|
||||
|
@ -47,8 +55,8 @@ Id EmitVoteAll(EmitContext& ctx, Id pred) {
|
|||
return ctx.OpSubgroupAllKHR(ctx.U1, pred);
|
||||
}
|
||||
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
|
||||
const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
|
||||
const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
|
||||
const Id active_mask{WarpExtract(ctx, mask_ballot)};
|
||||
const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
|
||||
const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
|
||||
return ctx.OpIEqual(ctx.U1, lhs, active_mask);
|
||||
}
|
||||
|
@ -58,8 +66,8 @@ Id EmitVoteAny(EmitContext& ctx, Id pred) {
|
|||
return ctx.OpSubgroupAnyKHR(ctx.U1, pred);
|
||||
}
|
||||
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
|
||||
const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
|
||||
const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
|
||||
const Id active_mask{WarpExtract(ctx, mask_ballot)};
|
||||
const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
|
||||
const Id lhs{ctx.OpBitwiseAnd(ctx.U32[1], ballot, active_mask)};
|
||||
return ctx.OpINotEqual(ctx.U1, lhs, ctx.u32_zero_value);
|
||||
}
|
||||
|
@ -69,8 +77,8 @@ Id EmitVoteEqual(EmitContext& ctx, Id pred) {
|
|||
return ctx.OpSubgroupAllEqualKHR(ctx.U1, pred);
|
||||
}
|
||||
const Id mask_ballot{ctx.OpSubgroupBallotKHR(ctx.U32[4], ctx.true_value)};
|
||||
const Id active_mask{LargeWarpBallot(ctx, mask_ballot)};
|
||||
const Id ballot{LargeWarpBallot(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
|
||||
const Id active_mask{WarpExtract(ctx, mask_ballot)};
|
||||
const Id ballot{WarpExtract(ctx, ctx.OpSubgroupBallotKHR(ctx.U32[4], pred))};
|
||||
const Id lhs{ctx.OpBitwiseXor(ctx.U32[1], ballot, active_mask)};
|
||||
return ctx.OpLogicalOr(ctx.U1, ctx.OpIEqual(ctx.U1, lhs, ctx.u32_zero_value),
|
||||
ctx.OpIEqual(ctx.U1, lhs, active_mask));
|
||||
|
@ -81,7 +89,27 @@ Id EmitSubgroupBallot(EmitContext& ctx, Id pred) {
|
|||
if (!ctx.profile.warp_size_potentially_larger_than_guest) {
|
||||
return ctx.OpCompositeExtract(ctx.U32[1], ballot, 0U);
|
||||
}
|
||||
return LargeWarpBallot(ctx, ballot);
|
||||
return WarpExtract(ctx, ballot);
|
||||
}
|
||||
|
||||
Id EmitSubgroupEqMask(EmitContext& ctx) {
|
||||
return LoadMask(ctx, ctx.subgroup_mask_eq);
|
||||
}
|
||||
|
||||
Id EmitSubgroupLtMask(EmitContext& ctx) {
|
||||
return LoadMask(ctx, ctx.subgroup_mask_lt);
|
||||
}
|
||||
|
||||
Id EmitSubgroupLeMask(EmitContext& ctx) {
|
||||
return LoadMask(ctx, ctx.subgroup_mask_le);
|
||||
}
|
||||
|
||||
Id EmitSubgroupGtMask(EmitContext& ctx) {
|
||||
return LoadMask(ctx, ctx.subgroup_mask_gt);
|
||||
}
|
||||
|
||||
Id EmitSubgroupGeMask(EmitContext& ctx) {
|
||||
return LoadMask(ctx, ctx.subgroup_mask_ge);
|
||||
}
|
||||
|
||||
Id EmitShuffleIndex(EmitContext& ctx, IR::Inst* inst, Id value, Id index, Id clamp,
|
||||
|
|
|
@ -1628,6 +1628,26 @@ U32 IREmitter::SubgroupBallot(const U1& value) {
|
|||
return Inst<U32>(Opcode::SubgroupBallot, value);
|
||||
}
|
||||
|
||||
U32 IREmitter::SubgroupEqMask() {
|
||||
return Inst<U32>(Opcode::SubgroupEqMask);
|
||||
}
|
||||
|
||||
U32 IREmitter::SubgroupLtMask() {
|
||||
return Inst<U32>(Opcode::SubgroupLtMask);
|
||||
}
|
||||
|
||||
U32 IREmitter::SubgroupLeMask() {
|
||||
return Inst<U32>(Opcode::SubgroupLeMask);
|
||||
}
|
||||
|
||||
U32 IREmitter::SubgroupGtMask() {
|
||||
return Inst<U32>(Opcode::SubgroupGtMask);
|
||||
}
|
||||
|
||||
U32 IREmitter::SubgroupGeMask() {
|
||||
return Inst<U32>(Opcode::SubgroupGeMask);
|
||||
}
|
||||
|
||||
U32 IREmitter::ShuffleIndex(const IR::U32& value, const IR::U32& index, const IR::U32& clamp,
|
||||
const IR::U32& seg_mask) {
|
||||
return Inst<U32>(Opcode::ShuffleIndex, value, index, clamp, seg_mask);
|
||||
|
|
|
@ -281,6 +281,11 @@ public:
|
|||
[[nodiscard]] U1 VoteAny(const U1& value);
|
||||
[[nodiscard]] U1 VoteEqual(const U1& value);
|
||||
[[nodiscard]] U32 SubgroupBallot(const U1& value);
|
||||
[[nodiscard]] U32 SubgroupEqMask();
|
||||
[[nodiscard]] U32 SubgroupLtMask();
|
||||
[[nodiscard]] U32 SubgroupLeMask();
|
||||
[[nodiscard]] U32 SubgroupGtMask();
|
||||
[[nodiscard]] U32 SubgroupGeMask();
|
||||
[[nodiscard]] U32 ShuffleIndex(const IR::U32& value, const IR::U32& index, const IR::U32& clamp,
|
||||
const IR::U32& seg_mask);
|
||||
[[nodiscard]] U32 ShuffleUp(const IR::U32& value, const IR::U32& index, const IR::U32& clamp,
|
||||
|
|
|
@ -417,6 +417,11 @@ OPCODE(VoteAll, U1, U1,
|
|||
OPCODE(VoteAny, U1, U1, )
|
||||
OPCODE(VoteEqual, U1, U1, )
|
||||
OPCODE(SubgroupBallot, U32, U1, )
|
||||
OPCODE(SubgroupEqMask, U32, )
|
||||
OPCODE(SubgroupLtMask, U32, )
|
||||
OPCODE(SubgroupLeMask, U32, )
|
||||
OPCODE(SubgroupGtMask, U32, )
|
||||
OPCODE(SubgroupGeMask, U32, )
|
||||
OPCODE(ShuffleIndex, U32, U32, U32, U32, U32, )
|
||||
OPCODE(ShuffleUp, U32, U32, U32, U32, U32, )
|
||||
OPCODE(ShuffleDown, U32, U32, U32, U32, U32, )
|
||||
|
|
|
@ -10,6 +10,7 @@ namespace Shader::Maxwell {
|
|||
namespace {
|
||||
enum class SpecialRegister : u64 {
|
||||
SR_LANEID = 0,
|
||||
SR_CLOCK = 1,
|
||||
SR_VIRTCFG = 2,
|
||||
SR_VIRTID = 3,
|
||||
SR_PM0 = 4,
|
||||
|
@ -20,6 +21,9 @@ enum class SpecialRegister : u64 {
|
|||
SR_PM5 = 9,
|
||||
SR_PM6 = 10,
|
||||
SR_PM7 = 11,
|
||||
SR12 = 12,
|
||||
SR13 = 13,
|
||||
SR14 = 14,
|
||||
SR_ORDERING_TICKET = 15,
|
||||
SR_PRIM_TYPE = 16,
|
||||
SR_INVOCATION_ID = 17,
|
||||
|
@ -41,44 +45,70 @@ enum class SpecialRegister : u64 {
|
|||
SR_TID_X = 33,
|
||||
SR_TID_Y = 34,
|
||||
SR_TID_Z = 35,
|
||||
SR_CTA_PARAM = 36,
|
||||
SR_CTAID_X = 37,
|
||||
SR_CTAID_Y = 38,
|
||||
SR_CTAID_Z = 39,
|
||||
SR_NTID = 49,
|
||||
SR_CirQueueIncrMinusOne = 50,
|
||||
SR_NLATC = 51,
|
||||
SR_SWINLO = 57,
|
||||
SR_SWINSZ = 58,
|
||||
SR_SMEMSZ = 59,
|
||||
SR_SMEMBANKS = 60,
|
||||
SR_LWINLO = 61,
|
||||
SR_LWINSZ = 62,
|
||||
SR_LMEMLOSZ = 63,
|
||||
SR_LMEMHIOFF = 64,
|
||||
SR_EQMASK = 65,
|
||||
SR_LTMASK = 66,
|
||||
SR_LEMASK = 67,
|
||||
SR_GTMASK = 68,
|
||||
SR_GEMASK = 69,
|
||||
SR_REGALLOC = 70,
|
||||
SR_GLOBALERRORSTATUS = 73,
|
||||
SR_WARPERRORSTATUS = 75,
|
||||
SR_PM_HI0 = 81,
|
||||
SR_PM_HI1 = 82,
|
||||
SR_PM_HI2 = 83,
|
||||
SR_PM_HI3 = 84,
|
||||
SR_PM_HI4 = 85,
|
||||
SR_PM_HI5 = 86,
|
||||
SR_PM_HI6 = 87,
|
||||
SR_PM_HI7 = 88,
|
||||
SR_CLOCKLO = 89,
|
||||
SR_CLOCKHI = 90,
|
||||
SR_GLOBALTIMERLO = 91,
|
||||
SR_GLOBALTIMERHI = 92,
|
||||
SR_HWTASKID = 105,
|
||||
SR_CIRCULARQUEUEENTRYINDEX = 106,
|
||||
SR_CIRCULARQUEUEENTRYADDRESSLOW = 107,
|
||||
SR_CIRCULARQUEUEENTRYADDRESSHIGH = 108,
|
||||
SR_NTID = 40,
|
||||
SR_CirQueueIncrMinusOne = 41,
|
||||
SR_NLATC = 42,
|
||||
SR43 = 43,
|
||||
SR_SM_SPA_VERSION = 44,
|
||||
SR_MULTIPASSSHADERINFO = 45,
|
||||
SR_LWINHI = 46,
|
||||
SR_SWINHI = 47,
|
||||
SR_SWINLO = 48,
|
||||
SR_SWINSZ = 49,
|
||||
SR_SMEMSZ = 50,
|
||||
SR_SMEMBANKS = 51,
|
||||
SR_LWINLO = 52,
|
||||
SR_LWINSZ = 53,
|
||||
SR_LMEMLOSZ = 54,
|
||||
SR_LMEMHIOFF = 55,
|
||||
SR_EQMASK = 56,
|
||||
SR_LTMASK = 57,
|
||||
SR_LEMASK = 58,
|
||||
SR_GTMASK = 59,
|
||||
SR_GEMASK = 60,
|
||||
SR_REGALLOC = 61,
|
||||
SR_BARRIERALLOC = 62,
|
||||
SR63 = 63,
|
||||
SR_GLOBALERRORSTATUS = 64,
|
||||
SR65 = 65,
|
||||
SR_WARPERRORSTATUS = 66,
|
||||
SR_WARPERRORSTATUSCLEAR = 67,
|
||||
SR68 = 68,
|
||||
SR69 = 69,
|
||||
SR70 = 70,
|
||||
SR71 = 71,
|
||||
SR_PM_HI0 = 72,
|
||||
SR_PM_HI1 = 73,
|
||||
SR_PM_HI2 = 74,
|
||||
SR_PM_HI3 = 75,
|
||||
SR_PM_HI4 = 76,
|
||||
SR_PM_HI5 = 77,
|
||||
SR_PM_HI6 = 78,
|
||||
SR_PM_HI7 = 79,
|
||||
SR_CLOCKLO = 80,
|
||||
SR_CLOCKHI = 81,
|
||||
SR_GLOBALTIMERLO = 82,
|
||||
SR_GLOBALTIMERHI = 83,
|
||||
SR84 = 84,
|
||||
SR85 = 85,
|
||||
SR86 = 86,
|
||||
SR87 = 87,
|
||||
SR88 = 88,
|
||||
SR89 = 89,
|
||||
SR90 = 90,
|
||||
SR91 = 91,
|
||||
SR92 = 92,
|
||||
SR93 = 93,
|
||||
SR94 = 94,
|
||||
SR95 = 95,
|
||||
SR_HWTASKID = 96,
|
||||
SR_CIRCULARQUEUEENTRYINDEX = 97,
|
||||
SR_CIRCULARQUEUEENTRYADDRESSLOW = 98,
|
||||
SR_CIRCULARQUEUEENTRYADDRESSHIGH = 99,
|
||||
};
|
||||
|
||||
[[nodiscard]] IR::U32 Read(IR::IREmitter& ir, SpecialRegister special_register) {
|
||||
|
@ -103,6 +133,16 @@ enum class SpecialRegister : u64 {
|
|||
return ir.Imm32(Common::BitCast<u32>(1.0f));
|
||||
case SpecialRegister::SR_LANEID:
|
||||
return ir.LaneId();
|
||||
case SpecialRegister::SR_EQMASK:
|
||||
return ir.SubgroupEqMask();
|
||||
case SpecialRegister::SR_LTMASK:
|
||||
return ir.SubgroupLtMask();
|
||||
case SpecialRegister::SR_LEMASK:
|
||||
return ir.SubgroupLeMask();
|
||||
case SpecialRegister::SR_GTMASK:
|
||||
return ir.SubgroupGtMask();
|
||||
case SpecialRegister::SR_GEMASK:
|
||||
return ir.SubgroupGeMask();
|
||||
default:
|
||||
throw NotImplementedException("S2R special register {}", special_register);
|
||||
}
|
||||
|
|
|
@ -414,6 +414,13 @@ void VisitUsages(Info& info, IR::Inst& inst) {
|
|||
inst.GetAssociatedPseudoOperation(IR::Opcode::GetSparseFromOp) != nullptr;
|
||||
break;
|
||||
}
|
||||
case IR::Opcode::SubgroupEqMask:
|
||||
case IR::Opcode::SubgroupLtMask:
|
||||
case IR::Opcode::SubgroupLeMask:
|
||||
case IR::Opcode::SubgroupGtMask:
|
||||
case IR::Opcode::SubgroupGeMask:
|
||||
info.uses_subgroup_mask = true;
|
||||
break;
|
||||
case IR::Opcode::VoteAll:
|
||||
case IR::Opcode::VoteAny:
|
||||
case IR::Opcode::VoteEqual:
|
||||
|
|
|
@ -99,6 +99,7 @@ struct Info {
|
|||
bool uses_sparse_residency{};
|
||||
bool uses_demote_to_helper_invocation{};
|
||||
bool uses_subgroup_vote{};
|
||||
bool uses_subgroup_mask{};
|
||||
bool uses_fswzadd{};
|
||||
|
||||
IR::Type used_constant_buffer_types{};
|
||||
|
|
Loading…
Reference in a new issue