Skip to content

Commit f9bcad3

Browse files
authored
Initial pass to add capability declarations to stdlib intrinsics. (shader-slang#3912)
1 parent 2da28c5 commit f9bcad3

22 files changed

+2813
-1214
lines changed

source/slang/glsl.meta.slang

+1,040-635
Large diffs are not rendered by default.

source/slang/hlsl.meta.slang

+986-449
Large diffs are not rendered by default.

source/slang/slang-ast-dump.cpp

+4-1
Original file line numberDiff line numberDiff line change
@@ -298,7 +298,10 @@ struct ASTDumpContext
298298
{
299299
m_writer->emit(v);
300300
}
301-
301+
void dump(CapabilityName v)
302+
{
303+
m_writer->emit(capabilityNameToString(v));
304+
}
302305

303306
void dump(const SemanticVersion& version)
304307
{

source/slang/slang-capabilities.capdef

+355-73
Large diffs are not rendered by default.

source/slang/slang-capability.cpp

+116
Original file line numberDiff line numberDiff line change
@@ -988,6 +988,104 @@ void CapabilitySet::canonicalize()
988988
m_conjunctions.sort();
989989
}
990990

991+
CapabilitySet CapabilitySet::getTargetsThisIsMissingFromOther(const CapabilitySet& other)
992+
{
993+
CapabilitySet conflicts{};
994+
List<CapabilityConjunctionSet> textualTargetsNotHandled;
995+
for (auto conjunction : this->m_conjunctions)
996+
{
997+
textualTargetsNotHandled.add({});
998+
auto& currentList = textualTargetsNotHandled.getLast();
999+
for (auto thatNode : conjunction.getExpandedAtoms())
1000+
{
1001+
// To make this faster we can make an assumption that the nodes are:
1002+
// {textualTarget, targetAbstract(), targetAbstract(), nonTarget}
1003+
// this assumption is not being used since it relies on ordering of .capdef file
1004+
if (_getInfo(thatNode).abstractBase == CapabilityName::target)
1005+
currentList.getExpandedAtoms().add(thatNode);
1006+
}
1007+
}
1008+
for (auto& thatConjunction : other.m_conjunctions)
1009+
{
1010+
// Worth the check to early leave due to ~5*5 elements to loop around
1011+
if (textualTargetsNotHandled.getCount() == 0)
1012+
break;
1013+
1014+
for (int i = 0 ; i < textualTargetsNotHandled.getCount(); i++)
1015+
{
1016+
auto& textualTargets = textualTargetsNotHandled[i];
1017+
1018+
if (textualTargets.countIntersectionWith(thatConjunction) != textualTargets.getExpandedAtoms().getCount())
1019+
continue;
1020+
1021+
textualTargetsNotHandled[i] = textualTargets.makeEmpty();
1022+
}
1023+
}
1024+
CapabilitySet set;
1025+
for (auto& i : textualTargetsNotHandled)
1026+
{
1027+
if (i.isEmpty())
1028+
continue;
1029+
set.unionWith(i);
1030+
}
1031+
return set;
1032+
}
1033+
1034+
// We only run 'join' logic on "this" conjunctions which are compatiable with "other" conjunctions.
1035+
// We only add specific nodes which satisfy the abstractMask.
1036+
// Any non-compatible conjunctions with "other"s cconjunctions will be preserved and unmodified.
1037+
void CapabilitySet::simpleJoinWithSetMask(const CapabilitySet& other, CapabilityName abstractMask)
1038+
{
1039+
CapabilitySet resultSet;
1040+
HashSet<CapabilityConjunctionSet*> setUsed;
1041+
// get used abstract mask nodes per conjunction so we can trivially check
1042+
// if we need to add the abstract mask node to avoid duplicates
1043+
List<HashSet<CapabilityAtom>> abstractMaskNodeInUse;
1044+
abstractMaskNodeInUse.growToCount(m_conjunctions.getCount());
1045+
for (int i = 0; i < m_conjunctions.getCount(); i++)
1046+
{
1047+
auto& thisConjunction = m_conjunctions[i];
1048+
auto& setOfInUseNode = abstractMaskNodeInUse[i];
1049+
1050+
for (auto& atom : thisConjunction.getExpandedAtoms())
1051+
{
1052+
if (_getInfo(atom).abstractBase != abstractMask)
1053+
continue;
1054+
setOfInUseNode.add(atom);
1055+
}
1056+
}
1057+
1058+
for (auto& thatConjunction : other.m_conjunctions)
1059+
{
1060+
for (int i = 0; i < m_conjunctions.getCount(); i++)
1061+
{
1062+
auto& thisConjunction = m_conjunctions[i];
1063+
auto& setOfInUseNode = abstractMaskNodeInUse[i];
1064+
CapabilityConjunctionSet conjunctionToAddToResultSet;
1065+
1066+
if (thisConjunction.isIncompatibleWith(thatConjunction))
1067+
continue;
1068+
conjunctionToAddToResultSet = thisConjunction;
1069+
setUsed.add(&thisConjunction);
1070+
for (auto atom : thatConjunction.getExpandedAtoms())
1071+
{
1072+
if (_getInfo(atom).abstractBase != abstractMask
1073+
|| setOfInUseNode.contains(atom))
1074+
continue;
1075+
conjunctionToAddToResultSet.getExpandedAtoms().add(atom);
1076+
}
1077+
conjunctionToAddToResultSet.getExpandedAtoms().sort();
1078+
resultSet.unionWith(conjunctionToAddToResultSet);
1079+
}
1080+
}
1081+
for (auto& c : m_conjunctions)
1082+
{
1083+
if (!setUsed.contains(&c))
1084+
resultSet.m_conjunctions.add(c);
1085+
}
1086+
m_conjunctions = resultSet.m_conjunctions;
1087+
}
1088+
9911089
void CapabilitySet::join(const CapabilitySet& other)
9921090
{
9931091
if (isEmpty() || other.isInvalid())
@@ -1176,6 +1274,24 @@ bool CapabilitySet::checkCapabilityRequirement(CapabilitySet const& available, C
11761274
return true;
11771275
}
11781276

1277+
bool CapabilitySet::isExactSubset(CapabilitySet const& maybeSuperSet)
1278+
{
1279+
// This should only be used when absolutely required due to the
1280+
// cost for complex sets. Simple sets are fine (glsl|spirv...)
1281+
for (auto& thisCon : m_conjunctions)
1282+
{
1283+
bool foundEqualCon = false;
1284+
for (auto& thatCon : maybeSuperSet.m_conjunctions)
1285+
{
1286+
if (thisCon == thatCon)
1287+
foundEqualCon = true;
1288+
}
1289+
if (foundEqualCon == false)
1290+
return false;
1291+
}
1292+
return true;
1293+
}
1294+
11791295
void printDiagnosticArg(StringBuilder& sb, const CapabilitySet& capSet)
11801296
{
11811297
bool isFirstSet = true;

source/slang/slang-capability.h

+6
Original file line numberDiff line numberDiff line change
@@ -209,6 +209,10 @@ struct CapabilitySet
209209

210210
void unionWith(const CapabilityConjunctionSet& other);
211211

212+
void simpleJoinWithSetMask(const CapabilitySet& other, CapabilityName abstractMask);
213+
214+
CapabilitySet getTargetsThisIsMissingFromOther(const CapabilitySet& other);
215+
212216
void canonicalize();
213217

214218
/// Are these two capability sets equal?
@@ -226,6 +230,8 @@ struct CapabilitySet
226230

227231
static bool checkCapabilityRequirement(CapabilitySet const& available, CapabilitySet const& required, const CapabilityConjunctionSet*& outFailedAvailableSet);
228232

233+
bool isExactSubset(CapabilitySet const& maybeSuperSet);
234+
229235
private:
230236
// The underlying representation we use is a list of conjunctions.
231237
//

source/slang/slang-check-decl.cpp

+64-17
Original file line numberDiff line numberDiff line change
@@ -354,6 +354,8 @@ namespace Slang
354354

355355
virtual void processReferencedDecl(Decl* decl) = 0;
356356

357+
virtual void processDeclModifiers(Decl* decl) = 0;
358+
357359
void dispatchIfNotNull(Stmt* stmt)
358360
{
359361
if (!stmt)
@@ -462,6 +464,7 @@ namespace Slang
462464
{
463465
dispatchIfNotNull(expr->type.type);
464466
dispatchIfNotNull(expr->declRef.declRefBase);
467+
processDeclModifiers(expr->declRef.getDecl());
465468
}
466469
void visitStaticMemberExpr(StaticMemberExpr* expr)
467470
{
@@ -9813,10 +9816,11 @@ namespace Slang
98139816
typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base;
98149817

98159818
const ProcessFunc& handleReferenceFunc;
9816-
9819+
RequireCapabilityAttribute* maybeRequireCapability;
98179820
SemanticsContext& outerContext;
9818-
CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, SemanticsContext& outer)
9821+
CapabilityDeclReferenceVisitor(const ProcessFunc& processFunc, RequireCapabilityAttribute* maybeRequireCapability, SemanticsContext& outer)
98199822
: handleReferenceFunc(processFunc)
9823+
, maybeRequireCapability(maybeRequireCapability)
98209824
, outerContext(outer)
98219825
, SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer)
98229826
{
@@ -9828,16 +9832,54 @@ namespace Slang
98289832
loc = Base::sourceLocStack.getLast();
98299833
handleReferenceFunc(decl, decl->inferredCapabilityRequirements, loc);
98309834
}
9835+
virtual void processDeclModifiers(Decl* decl)
9836+
{
9837+
if (decl)
9838+
handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc);
9839+
}
98319840
void visitDiscardStmt(DiscardStmt* stmt)
98329841
{
98339842
handleReferenceFunc(stmt, CapabilitySet(CapabilityName::fragment), stmt->loc);
98349843
}
98359844
void visitTargetSwitchStmt(TargetSwitchStmt* stmt)
98369845
{
98379846
CapabilitySet set;
9838-
for (auto targetCase : stmt->targetCases)
9847+
auto targetCaseCount = stmt->targetCases.getCount();
9848+
for (Index targetCaseIndex = 0; targetCaseIndex < targetCaseCount; targetCaseIndex++)
98399849
{
9840-
auto targetCap = CapabilitySet(CapabilityName(targetCase->capability));
9850+
// We may recieve a `default:` case for a `__target_switch`. If this is the case,
9851+
// we must resolve the target capability for a non empty set of `calling_functions_targets`:
9852+
// ``` default_target = calling_functions_targets-{other_case_targets} ```
9853+
//
9854+
// * `calling_functions_capability` = `requirement attribute` of the calling function; if missing
9855+
// we can assume it is `any_target`
9856+
//
9857+
// * `{other_case_targets}` = set of all capabilities all `case` statments target inside the `__target_switch`
9858+
9859+
// If we do not handle `default:`, the codegen will fail when trying to find a specific
9860+
// codegen target not handled explicitly by a `case` statment.
9861+
// We must also ensure the `default` case is last so we have priority to hit `case` statments and can preprocess
9862+
// `case` statments before the `default` case.
9863+
CapabilitySet targetCap;
9864+
if (CapabilityName(stmt->targetCases[targetCaseIndex]->capability) == CapabilityName::Invalid)
9865+
{
9866+
if (targetCaseCount - 1 != targetCaseIndex)
9867+
{
9868+
for (Index i = targetCaseIndex; i < targetCaseCount - 1; i++)
9869+
std::swap(stmt->targetCases[i], stmt->targetCases[i + 1]);
9870+
continue;
9871+
}
9872+
9873+
if (!maybeRequireCapability)
9874+
targetCap = (CapabilitySet(CapabilityName::any_target).getTargetsThisIsMissingFromOther(set));
9875+
else
9876+
targetCap = (maybeRequireCapability->capabilitySet.getTargetsThisIsMissingFromOther(set));
9877+
}
9878+
else
9879+
{
9880+
targetCap = CapabilitySet(CapabilityName(stmt->targetCases[targetCaseIndex]->capability));
9881+
}
9882+
auto targetCase = stmt->targetCases[targetCaseIndex];
98419883
auto oldCap = targetCap;
98429884
auto bodyCap = getStatementCapabilityUsage(this, targetCase->body);
98439885
targetCap.join(bodyCap);
@@ -9851,16 +9893,17 @@ namespace Slang
98519893
set.canonicalize();
98529894
handleReferenceFunc(stmt, set, stmt->loc);
98539895
}
9896+
98549897
void visitRequireCapabilityDecl(RequireCapabilityDecl* decl)
98559898
{
98569899
handleReferenceFunc(decl, decl->inferredCapabilityRequirements, decl->loc);
98579900
}
98589901
};
98599902

98609903
template<typename ProcessFunc>
9861-
void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func)
9904+
void visitReferencedDecls(SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, RequireCapabilityAttribute* maybeRequireCapability, const ProcessFunc& func)
98629905
{
9863-
CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, context);
9906+
CapabilityDeclReferenceVisitor<ProcessFunc> visitor(func, maybeRequireCapability, context);
98649907
visitor.sourceLocStack.add(initialLoc);
98659908

98669909
if (auto val = as<Val>(node))
@@ -9879,7 +9922,7 @@ namespace Slang
98799922
return CapabilitySet();
98809923

98819924
CapabilitySet inferredRequirements;
9882-
visitReferencedDecls(*visitor, stmt, stmt->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9925+
visitReferencedDecls(*visitor, stmt, stmt->loc, nullptr, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
98839926
{
98849927
_propagateRequirement(visitor, inferredRequirements, stmt, node, nodeCaps, refLoc);
98859928
});
@@ -9888,11 +9931,7 @@ namespace Slang
98889931

98899932
void SemanticsDeclCapabilityVisitor::checkVarDeclCommon(VarDeclBase* varDecl)
98909933
{
9891-
visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9892-
{
9893-
_propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
9894-
});
9895-
visitReferencedDecls(*this, varDecl->initExpr, varDecl->loc, [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9934+
visitReferencedDecls(*this, varDecl->type.type, varDecl->loc, varDecl->findModifier<RequireCapabilityAttribute>(), [this, varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
98969935
{
98979936
_propagateRequirement(this, varDecl->inferredCapabilityRequirements, varDecl, node, nodeCaps, refLoc);
98989937
});
@@ -9958,7 +9997,7 @@ namespace Slang
99589997
ensureDecl(member, DeclCheckState::CapabilityChecked);
99599998
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, member, member->inferredCapabilityRequirements, member->loc);
99609999
}
9961-
visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
10000+
visitReferencedDecls(*this, funcDecl->body, funcDecl->loc, funcDecl->findModifier<RequireCapabilityAttribute>(), [this, funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
996210001
{
996310002
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, node, nodeCaps, refLoc);
996410003
});
@@ -9972,7 +10011,7 @@ namespace Slang
997210011
_propagateRequirement(this, funcDecl->inferredCapabilityRequirements, funcDecl, parentAggTypeDecl, parentAggTypeDecl->inferredCapabilityRequirements, funcDecl->loc);
997310012
}
997410013
}
9975-
10014+
997610015
auto declaredCaps = getDeclaredCapabilitySet(funcDecl);
997710016

997810017
if (!declaredCaps.isEmpty())
@@ -9996,12 +10035,13 @@ namespace Slang
999610035
if (declaredCaps.isEmpty())
999710036
{
999810037
// If the user has not declared any capabilities,
9999-
// we should diagnose an error if this is a public symbol.
10038+
// we should diagnose a warning if any_target is not
10039+
// a super-set by exact atoms.
1000010040
if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements.isEmpty())
1000110041
{
1000210042
if (!getModuleDecl(funcDecl)->isInLegacyLanguage)
1000310043
{
10004-
if (funcDecl->inferredCapabilityRequirements != getAnyPlatformCapabilitySet())
10044+
if (!funcDecl->inferredCapabilityRequirements.isExactSubset(getAnyPlatformCapabilitySet()))
1000510045
{
1000610046
diagnoseCapabilityErrors(
1000710047
getSink(),
@@ -10019,6 +10059,9 @@ namespace Slang
1001910059
{
1002010060
// For public decls, we need to enforce that the function
1002110061
// only uses capabilities that it declares.
10062+
// At a minimum we will propagate shader requirements to our
10063+
// function from calling children in all cases so the parent
10064+
// can enforce shader targets correctly and propagate to `main`
1002210065
const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr;
1002310066
if (!CapabilitySet::checkCapabilityRequirement(
1002410067
declaredCaps,
@@ -10028,6 +10071,8 @@ namespace Slang
1002810071
diagnoseUndeclaredCapability(funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction);
1002910072
funcDecl->inferredCapabilityRequirements = declaredCaps;
1003010073
}
10074+
else
10075+
funcDecl->inferredCapabilityRequirements.simpleJoinWithSetMask(declaredCaps, CapabilityName::stage);
1003110076
}
1003210077
else
1003310078
{
@@ -10165,7 +10210,7 @@ namespace Slang
1016510210
while (traceLevels > 0)
1016610211
{
1016710212
refDecl = nullptr;
10168-
visitReferencedDecls(*visitor, decl, decl->loc, [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
10213+
visitReferencedDecls(*visitor, decl, decl->loc, decl->findModifier<RequireCapabilityAttribute>(), [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
1016910214
{
1017010215
if (nodeCaps.isIncompatibleWith(incompatibleAtom))
1017110216
{
@@ -10197,6 +10242,8 @@ namespace Slang
1019710242
{
1019810243
if (decl->inferredCapabilityRequirements.getExpandedAtoms().getCount() == 0)
1019910244
return;
10245+
if(!failedAvailableSet)
10246+
return;
1020010247

1020110248
// There are two causes for why type checking failed on failedAvailableSet.
1020210249
// The first scenario is that failedAvailableSet defines a set of capabilities on a

0 commit comments

Comments
 (0)