@@ -354,6 +354,8 @@ namespace Slang
354
354
355
355
virtual void processReferencedDecl (Decl* decl) = 0;
356
356
357
+ virtual void processDeclModifiers (Decl* decl) = 0;
358
+
357
359
void dispatchIfNotNull (Stmt* stmt)
358
360
{
359
361
if (!stmt)
@@ -462,6 +464,7 @@ namespace Slang
462
464
{
463
465
dispatchIfNotNull (expr->type .type );
464
466
dispatchIfNotNull (expr->declRef .declRefBase );
467
+ processDeclModifiers (expr->declRef .getDecl ());
465
468
}
466
469
void visitStaticMemberExpr (StaticMemberExpr* expr)
467
470
{
@@ -9813,10 +9816,11 @@ namespace Slang
9813
9816
typedef SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>> Base;
9814
9817
9815
9818
const ProcessFunc& handleReferenceFunc;
9816
-
9819
+ RequireCapabilityAttribute* maybeRequireCapability;
9817
9820
SemanticsContext& outerContext;
9818
- CapabilityDeclReferenceVisitor (const ProcessFunc& processFunc, SemanticsContext& outer)
9821
+ CapabilityDeclReferenceVisitor (const ProcessFunc& processFunc, RequireCapabilityAttribute* maybeRequireCapability, SemanticsContext& outer)
9819
9822
: handleReferenceFunc(processFunc)
9823
+ , maybeRequireCapability(maybeRequireCapability)
9820
9824
, outerContext(outer)
9821
9825
, SemanticsDeclReferenceVisitor<CapabilityDeclReferenceVisitor<ProcessFunc>>(outer)
9822
9826
{
@@ -9828,16 +9832,54 @@ namespace Slang
9828
9832
loc = Base::sourceLocStack.getLast ();
9829
9833
handleReferenceFunc (decl, decl->inferredCapabilityRequirements , loc);
9830
9834
}
9835
+ virtual void processDeclModifiers (Decl* decl)
9836
+ {
9837
+ if (decl)
9838
+ handleReferenceFunc (decl, decl->inferredCapabilityRequirements , decl->loc );
9839
+ }
9831
9840
void visitDiscardStmt (DiscardStmt* stmt)
9832
9841
{
9833
9842
handleReferenceFunc (stmt, CapabilitySet (CapabilityName::fragment), stmt->loc );
9834
9843
}
9835
9844
void visitTargetSwitchStmt (TargetSwitchStmt* stmt)
9836
9845
{
9837
9846
CapabilitySet set;
9838
- for (auto targetCase : stmt->targetCases )
9847
+ auto targetCaseCount = stmt->targetCases .getCount ();
9848
+ for (Index targetCaseIndex = 0 ; targetCaseIndex < targetCaseCount; targetCaseIndex++)
9839
9849
{
9840
- auto targetCap = CapabilitySet (CapabilityName (targetCase->capability ));
9850
+ // We may recieve a `default:` case for a `__target_switch`. If this is the case,
9851
+ // we must resolve the target capability for a non empty set of `calling_functions_targets`:
9852
+ // ``` default_target = calling_functions_targets-{other_case_targets} ```
9853
+ //
9854
+ // * `calling_functions_capability` = `requirement attribute` of the calling function; if missing
9855
+ // we can assume it is `any_target`
9856
+ //
9857
+ // * `{other_case_targets}` = set of all capabilities all `case` statments target inside the `__target_switch`
9858
+
9859
+ // If we do not handle `default:`, the codegen will fail when trying to find a specific
9860
+ // codegen target not handled explicitly by a `case` statment.
9861
+ // We must also ensure the `default` case is last so we have priority to hit `case` statments and can preprocess
9862
+ // `case` statments before the `default` case.
9863
+ CapabilitySet targetCap;
9864
+ if (CapabilityName (stmt->targetCases [targetCaseIndex]->capability ) == CapabilityName::Invalid)
9865
+ {
9866
+ if (targetCaseCount - 1 != targetCaseIndex)
9867
+ {
9868
+ for (Index i = targetCaseIndex; i < targetCaseCount - 1 ; i++)
9869
+ std::swap (stmt->targetCases [i], stmt->targetCases [i + 1 ]);
9870
+ continue ;
9871
+ }
9872
+
9873
+ if (!maybeRequireCapability)
9874
+ targetCap = (CapabilitySet (CapabilityName::any_target).getTargetsThisIsMissingFromOther (set));
9875
+ else
9876
+ targetCap = (maybeRequireCapability->capabilitySet .getTargetsThisIsMissingFromOther (set));
9877
+ }
9878
+ else
9879
+ {
9880
+ targetCap = CapabilitySet (CapabilityName (stmt->targetCases [targetCaseIndex]->capability ));
9881
+ }
9882
+ auto targetCase = stmt->targetCases [targetCaseIndex];
9841
9883
auto oldCap = targetCap;
9842
9884
auto bodyCap = getStatementCapabilityUsage (this , targetCase->body );
9843
9885
targetCap.join (bodyCap);
@@ -9851,16 +9893,17 @@ namespace Slang
9851
9893
set.canonicalize ();
9852
9894
handleReferenceFunc (stmt, set, stmt->loc );
9853
9895
}
9896
+
9854
9897
void visitRequireCapabilityDecl (RequireCapabilityDecl* decl)
9855
9898
{
9856
9899
handleReferenceFunc (decl, decl->inferredCapabilityRequirements , decl->loc );
9857
9900
}
9858
9901
};
9859
9902
9860
9903
template <typename ProcessFunc>
9861
- void visitReferencedDecls (SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, const ProcessFunc& func)
9904
+ void visitReferencedDecls (SemanticsContext& context, NodeBase* node, SourceLoc initialLoc, RequireCapabilityAttribute* maybeRequireCapability, const ProcessFunc& func)
9862
9905
{
9863
- CapabilityDeclReferenceVisitor<ProcessFunc> visitor (func, context);
9906
+ CapabilityDeclReferenceVisitor<ProcessFunc> visitor (func, maybeRequireCapability, context);
9864
9907
visitor.sourceLocStack .add (initialLoc);
9865
9908
9866
9909
if (auto val = as<Val>(node))
@@ -9879,7 +9922,7 @@ namespace Slang
9879
9922
return CapabilitySet ();
9880
9923
9881
9924
CapabilitySet inferredRequirements;
9882
- visitReferencedDecls (*visitor, stmt, stmt->loc , [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9925
+ visitReferencedDecls (*visitor, stmt, stmt->loc , nullptr , [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9883
9926
{
9884
9927
_propagateRequirement (visitor, inferredRequirements, stmt, node, nodeCaps, refLoc);
9885
9928
});
@@ -9888,11 +9931,7 @@ namespace Slang
9888
9931
9889
9932
void SemanticsDeclCapabilityVisitor::checkVarDeclCommon (VarDeclBase* varDecl)
9890
9933
{
9891
- visitReferencedDecls (*this , varDecl->type .type , varDecl->loc , [this , varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9892
- {
9893
- _propagateRequirement (this , varDecl->inferredCapabilityRequirements , varDecl, node, nodeCaps, refLoc);
9894
- });
9895
- visitReferencedDecls (*this , varDecl->initExpr , varDecl->loc , [this , varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9934
+ visitReferencedDecls (*this , varDecl->type .type , varDecl->loc , varDecl->findModifier <RequireCapabilityAttribute>(), [this , varDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9896
9935
{
9897
9936
_propagateRequirement (this , varDecl->inferredCapabilityRequirements , varDecl, node, nodeCaps, refLoc);
9898
9937
});
@@ -9958,7 +9997,7 @@ namespace Slang
9958
9997
ensureDecl (member, DeclCheckState::CapabilityChecked);
9959
9998
_propagateRequirement (this , funcDecl->inferredCapabilityRequirements , funcDecl, member, member->inferredCapabilityRequirements , member->loc );
9960
9999
}
9961
- visitReferencedDecls (*this , funcDecl->body , funcDecl->loc , [this , funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
10000
+ visitReferencedDecls (*this , funcDecl->body , funcDecl->loc , funcDecl-> findModifier <RequireCapabilityAttribute>(), [this , funcDecl](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
9962
10001
{
9963
10002
_propagateRequirement (this , funcDecl->inferredCapabilityRequirements , funcDecl, node, nodeCaps, refLoc);
9964
10003
});
@@ -9972,7 +10011,7 @@ namespace Slang
9972
10011
_propagateRequirement (this , funcDecl->inferredCapabilityRequirements , funcDecl, parentAggTypeDecl, parentAggTypeDecl->inferredCapabilityRequirements , funcDecl->loc );
9973
10012
}
9974
10013
}
9975
-
10014
+
9976
10015
auto declaredCaps = getDeclaredCapabilitySet (funcDecl);
9977
10016
9978
10017
if (!declaredCaps.isEmpty ())
@@ -9996,12 +10035,13 @@ namespace Slang
9996
10035
if (declaredCaps.isEmpty ())
9997
10036
{
9998
10037
// If the user has not declared any capabilities,
9999
- // we should diagnose an error if this is a public symbol.
10038
+ // we should diagnose a warning if any_target is not
10039
+ // a super-set by exact atoms.
10000
10040
if (vis == DeclVisibility::Public && !funcDecl->inferredCapabilityRequirements .isEmpty ())
10001
10041
{
10002
10042
if (!getModuleDecl (funcDecl)->isInLegacyLanguage )
10003
10043
{
10004
- if (funcDecl->inferredCapabilityRequirements != getAnyPlatformCapabilitySet ())
10044
+ if (! funcDecl->inferredCapabilityRequirements . isExactSubset ( getAnyPlatformCapabilitySet () ))
10005
10045
{
10006
10046
diagnoseCapabilityErrors (
10007
10047
getSink (),
@@ -10019,6 +10059,9 @@ namespace Slang
10019
10059
{
10020
10060
// For public decls, we need to enforce that the function
10021
10061
// only uses capabilities that it declares.
10062
+ // At a minimum we will propagate shader requirements to our
10063
+ // function from calling children in all cases so the parent
10064
+ // can enforce shader targets correctly and propagate to `main`
10022
10065
const CapabilityConjunctionSet* failedAvailableCapabilityConjunction = nullptr ;
10023
10066
if (!CapabilitySet::checkCapabilityRequirement (
10024
10067
declaredCaps,
@@ -10028,6 +10071,8 @@ namespace Slang
10028
10071
diagnoseUndeclaredCapability (funcDecl, Diagnostics::useOfUndeclaredCapability, failedAvailableCapabilityConjunction);
10029
10072
funcDecl->inferredCapabilityRequirements = declaredCaps;
10030
10073
}
10074
+ else
10075
+ funcDecl->inferredCapabilityRequirements .simpleJoinWithSetMask (declaredCaps, CapabilityName::stage);
10031
10076
}
10032
10077
else
10033
10078
{
@@ -10165,7 +10210,7 @@ namespace Slang
10165
10210
while (traceLevels > 0 )
10166
10211
{
10167
10212
refDecl = nullptr ;
10168
- visitReferencedDecls (*visitor, decl, decl->loc , [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
10213
+ visitReferencedDecls (*visitor, decl, decl->loc , decl-> findModifier <RequireCapabilityAttribute>(), [&](SyntaxNode* node, const CapabilitySet& nodeCaps, SourceLoc refLoc)
10169
10214
{
10170
10215
if (nodeCaps.isIncompatibleWith (incompatibleAtom))
10171
10216
{
@@ -10197,6 +10242,8 @@ namespace Slang
10197
10242
{
10198
10243
if (decl->inferredCapabilityRequirements .getExpandedAtoms ().getCount () == 0 )
10199
10244
return ;
10245
+ if (!failedAvailableSet)
10246
+ return ;
10200
10247
10201
10248
// There are two causes for why type checking failed on failedAvailableSet.
10202
10249
// The first scenario is that failedAvailableSet defines a set of capabilities on a
0 commit comments