Skip to content

Commit 1476489

Browse files
authored
Capability type checking. (shader-slang#3530)
* Capability type checking. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent c15e7ad commit 1476489

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

51 files changed

+1869
-489
lines changed

source/slang/core.meta.slang

+3
Original file line numberDiff line numberDiff line change
@@ -2475,6 +2475,9 @@ attribute_syntax [vk_image_format(format : String)] : FormatAttribute;
24752475
__attributeTarget(Decl)
24762476
attribute_syntax [allow(diagnostic: String)] : AllowAttribute;
24772477

2478+
__attributeTarget(Decl)
2479+
attribute_syntax[require(capability)] : RequireCapabilityAttribute;
2480+
24782481
// Linking
24792482
__attributeTarget(Decl)
24802483
attribute_syntax [__extern] : ExternAttribute;

source/slang/hlsl.meta.slang

+10-27
Original file line numberDiff line numberDiff line change
@@ -4716,7 +4716,6 @@ T GetAttributeAtVertex(T attribute, uint vertexIndex)
47164716
{
47174717
case hlsl:
47184718
__intrinsic_asm "GetAttributeAtVertex";
4719-
case _GL_NV_fragment_shader_barycentric:
47204719
case _GL_EXT_fragment_shader_barycentric:
47214720
__intrinsic_asm "$0[$1]";
47224721
case spirv:
@@ -4749,7 +4748,6 @@ vector<T,N> GetAttributeAtVertex(vector<T,N> attribute, uint vertexIndex)
47494748
{
47504749
case hlsl:
47514750
__intrinsic_asm "GetAttributeAtVertex";
4752-
case _GL_NV_fragment_shader_barycentric:
47534751
case _GL_EXT_fragment_shader_barycentric:
47544752
__intrinsic_asm "$0[$1]";
47554753
case spirv:
@@ -4782,7 +4780,6 @@ matrix<T,N,M> GetAttributeAtVertex(matrix<T,N,M> attribute, uint vertexIndex)
47824780
{
47834781
case hlsl:
47844782
__intrinsic_asm "GetAttributeAtVertex";
4785-
case _GL_NV_fragment_shader_barycentric:
47864783
case _GL_EXT_fragment_shader_barycentric:
47874784
__intrinsic_asm "$0[$1]";
47884785
case spirv:
@@ -9288,8 +9285,7 @@ struct BuiltInTriangleIntersectionAttributes
92889285
// `executeCallableNV` is the GLSL intrinsic that will be used to implement
92899286
// `CallShader()` for GLSL-based targets.
92909287
//
9291-
__target_intrinsic(GL_NV_ray_tracing, "executeCallableNV")
9292-
__target_intrinsic(GL_EXT_ray_tracing, "executeCallableEXT")
9288+
__target_intrinsic(_GL_EXT_ray_tracing, "executeCallableEXT")
92939289
void __executeCallable(uint shaderIndex, int payloadLocation);
92949290

92959291
// Next is the custom intrinsic that will compute the payload location
@@ -9335,8 +9331,7 @@ void CallShader(uint shaderIndex, inout Payload payload)
93359331

93369332
// 10.3.2
93379333

9338-
__target_intrinsic(GL_NV_ray_tracing, "traceNV")
9339-
__target_intrinsic(GL_EXT_ray_tracing, "traceRayEXT")
9334+
__target_intrinsic(_GL_EXT_ray_tracing, "traceRayEXT")
93409335
void __traceRay(
93419336
RaytracingAccelerationStructure AccelerationStructure,
93429337
uint RayFlags,
@@ -9528,7 +9523,6 @@ bool __reportIntersection(float tHit, uint hitKind)
95289523
__target_switch
95299524
{
95309525
case _GL_EXT_ray_tracing: __intrinsic_asm "reportIntersectionEXT";
9531-
case _GL_NV_ray_tracing: __intrinsic_asm "reportIntersectionNV";
95329526
case spirv:
95339527
return spirv_asm {
95349528
result:$$bool = OpReportIntersectionKHR $tHit $hitKind;
@@ -9555,7 +9549,6 @@ void IgnoreHit()
95559549
{
95569550
case hlsl: __intrinsic_asm "IgnoreHit";
95579551
case _GL_EXT_ray_tracing: __intrinsic_asm "ignoreIntersectionEXT;";
9558-
case _GL_NV_ray_tracing: __intrinsic_asm "ignoreIntersectionNV";
95599552
case cuda: __intrinsic_asm "optixIgnoreIntersection";
95609553
case spirv: spirv_asm { OpIgnoreIntersectionKHR; %_ = OpLabel };
95619554
}
@@ -9568,7 +9561,6 @@ void AcceptHitAndEndSearch()
95689561
{
95699562
case hlsl: __intrinsic_asm "AcceptHitAndEndSearch";
95709563
case _GL_EXT_ray_tracing: __intrinsic_asm "terminateRayEXT;";
9571-
case _GL_NV_ray_tracing: __intrinsic_asm "terminateRayNV";
95729564
case cuda: __intrinsic_asm "optixTerminateRay";
95739565
case spirv: spirv_asm { OpTerminateRayKHR; %_ = OpLabel };
95749566
}
@@ -9587,7 +9579,6 @@ uint3 DispatchRaysIndex()
95879579
{
95889580
case hlsl: __intrinsic_asm "DispatchRaysIndex";
95899581
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchIDEXT)";
9590-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchIDNV)";
95919582
case cuda: __intrinsic_asm "optixGetLaunchIndex";
95929583
case spirv:
95939584
return spirv_asm {
@@ -9602,7 +9593,6 @@ uint3 DispatchRaysDimensions()
96029593
{
96039594
case hlsl: __intrinsic_asm "DispatchRaysDimensions";
96049595
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_LaunchSizeEXT)";
9605-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_LaunchSizeNV)";
96069596
case cuda: __intrinsic_asm "optixGetLaunchDimensions";
96079597
case spirv:
96089598
return spirv_asm {
@@ -9619,7 +9609,6 @@ float3 WorldRayOrigin()
96199609
{
96209610
case hlsl: __intrinsic_asm "WorldRayOrigin";
96219611
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginEXT)";
9622-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayOriginNV)";
96239612
case cuda: __intrinsic_asm "optixGetWorldRayOrigin";
96249613
case spirv:
96259614
return spirv_asm {
@@ -9634,7 +9623,6 @@ float3 WorldRayDirection()
96349623
{
96359624
case hlsl: __intrinsic_asm "WorldRayDirection";
96369625
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionEXT)";
9637-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldRayDirectionNV)";
96389626
case cuda: __intrinsic_asm "optixGetWorldRayDirection";
96399627
case spirv:
96409628
return spirv_asm {
@@ -9649,7 +9637,6 @@ float RayTMin()
96499637
{
96509638
case hlsl: __intrinsic_asm "RayTMin";
96519639
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTminEXT)";
9652-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTminNV)";
96539640
case cuda: __intrinsic_asm "optixGetRayTmin";
96549641
case spirv:
96559642
return spirv_asm {
@@ -9674,7 +9661,6 @@ float RayTCurrent()
96749661
{
96759662
case hlsl: __intrinsic_asm "RayTCurrent";
96769663
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_RayTmaxEXT)";
9677-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_RayTmaxNV)";
96789664
case cuda: __intrinsic_asm "optixGetRayTmax";
96799665
case spirv:
96809666
return spirv_asm {
@@ -9689,7 +9675,6 @@ uint RayFlags()
96899675
{
96909676
case hlsl: __intrinsic_asm "RayFlags";
96919677
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsEXT)";
9692-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_IncomingRayFlagsNV)";
96939678
case cuda: __intrinsic_asm "optixGetRayFlags";
96949679
case spirv:
96959680
return spirv_asm {
@@ -9720,7 +9705,6 @@ uint InstanceID()
97209705
{
97219706
case hlsl: __intrinsic_asm "InstanceID";
97229707
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexEXT)";
9723-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_InstanceCustomIndexNV)";
97249708
case cuda: __intrinsic_asm "optixGetInstanceId";
97259709
case spirv:
97269710
return spirv_asm {
@@ -9749,7 +9733,6 @@ float3 ObjectRayOrigin()
97499733
{
97509734
case hlsl: __intrinsic_asm "ObjectRayOrigin";
97519735
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginEXT)";
9752-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayOriginNV)";
97539736
case cuda: __intrinsic_asm "optixGetObjectRayOrigin";
97549737
case spirv:
97559738
return spirv_asm {
@@ -9764,7 +9747,6 @@ float3 ObjectRayDirection()
97649747
{
97659748
case hlsl: __intrinsic_asm "ObjectRayDirection";
97669749
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionEXT)";
9767-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectRayDirectionNV)";
97689750
case cuda: __intrinsic_asm "optixGetObjectRayDirection";
97699751
case spirv:
97709752
return spirv_asm {
@@ -9781,7 +9763,6 @@ float3x4 ObjectToWorld3x4()
97819763
{
97829764
case hlsl: __intrinsic_asm "ObjectToWorld3x4";
97839765
case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldEXT)";
9784-
case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_ObjectToWorldNV)";
97859766
case spirv:
97869767
return spirv_asm {
97879768
%mat:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
@@ -9796,7 +9777,6 @@ float3x4 WorldToObject3x4()
97969777
{
97979778
case hlsl: __intrinsic_asm "WorldToObject3x4";
97989779
case _GL_EXT_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectEXT)";
9799-
case _GL_NV_ray_tracing: __intrinsic_asm "transpose(gl_WorldToObjectNV)";
98009780
case spirv:
98019781
return spirv_asm {
98029782
%mat:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
@@ -9811,7 +9791,6 @@ float4x3 ObjectToWorld4x3()
98119791
{
98129792
case hlsl: __intrinsic_asm "ObjectToWorld4x3";
98139793
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldEXT)";
9814-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_ObjectToWorldNV)";
98159794
case spirv:
98169795
return spirv_asm {
98179796
result:$$float4x3 = OpLoad builtin(ObjectToWorldKHR:float4x3);
@@ -9825,7 +9804,6 @@ float4x3 WorldToObject4x3()
98259804
{
98269805
case hlsl: __intrinsic_asm "WorldToObject4x3";
98279806
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_WorldToObjectEXT)";
9828-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_WorldToObjectNV)";
98299807
case spirv:
98309808
return spirv_asm {
98319809
result:$$float4x3 = OpLoad builtin(WorldToObjectKHR:float4x3);
@@ -9872,7 +9850,6 @@ uint HitKind()
98729850
{
98739851
case hlsl: __intrinsic_asm "HitKind";
98749852
case _GL_EXT_ray_tracing: __intrinsic_asm "(gl_HitKindEXT)";
9875-
case _GL_NV_ray_tracing: __intrinsic_asm "(gl_HitKindNV)";
98769853
case cuda: __intrinsic_asm "optixGetHitKind";
98779854
case spirv:
98789855
return spirv_asm {
@@ -11874,6 +11851,7 @@ void debugBreak();
1187411851

1187511852
[__requiresNVAPI]
1187611853
__glsl_extension(GL_EXT_shader_realtime_clock)
11854+
[require(shaderclock)]
1187711855
uint getRealtimeClockLow()
1187811856
{
1187911857
__target_switch
@@ -11886,14 +11864,18 @@ uint getRealtimeClockLow()
1188611864
__intrinsic_asm "clock";
1188711865
case spirv:
1188811866
return getRealtimeClock().x;
11867+
case cpp:
11868+
__intrinsic_asm "(uint32_t)std::chrono::high_resolution_clock::now().time_since_epoch().count()";
1188911869
}
1189011870
}
1189111871

11872+
__target_intrinsic(cpp, "std::chrono::high_resolution_clock::now().time_since_epoch().count()")
1189211873
__target_intrinsic(cuda, "clock64")
11893-
int64_t __cudaGetRealtimeClock();
11874+
int64_t __cudaCppGetRealtimeClock();
1189411875

1189511876
[__requiresNVAPI]
1189611877
__glsl_extension(GL_EXT_shader_realtime_clock)
11878+
[require(shaderclock)]
1189711879
uint2 getRealtimeClock()
1189811880
{
1189911881
__target_switch
@@ -11903,7 +11885,8 @@ uint2 getRealtimeClock()
1190311885
case glsl:
1190411886
__intrinsic_asm "clockRealtime2x32EXT()";
1190511887
case cuda:
11906-
int64_t ticks = __cudaGetRealtimeClock();
11888+
case cpp:
11889+
int64_t ticks = __cudaCppGetRealtimeClock();
1190711890
return uint2(uint(ticks), uint(uint64_t(ticks) >> 32));
1190811891
case spirv:
1190911892
return spirv_asm

source/slang/slang-ast-base.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
#include "slang-generated-ast.h"
88
#include "slang-ast-reflect.h"
9-
9+
#include "slang-capability.h"
1010
#include "slang-serialize-reflection.h"
1111

1212
// This file defines the primary base classes for the hierarchy of
@@ -695,6 +695,11 @@ class ModifiableSyntaxNode : public SyntaxNode
695695
bool hasModifier() { return findModifier<T>() != nullptr; }
696696
};
697697

698+
struct DeclReferenceWithLoc
699+
{
700+
Decl* referencedDecl;
701+
SourceLoc referenceLoc;
702+
};
698703

699704
// An intermediate type to represent either a single declaration, or a group of declarations
700705
class DeclBase : public ModifiableSyntaxNode
@@ -716,6 +721,7 @@ class Decl : public DeclBase
716721
DeclRefBase* getDefaultDeclRef();
717722

718723
NameLoc nameAndLoc;
724+
CapabilitySet inferredCapabilityRequirements;
719725

720726
RefPtr<MarkupEntry> markup;
721727

@@ -736,6 +742,8 @@ class Decl : public DeclBase
736742
}
737743
bool isChildOf(Decl* other) const;
738744

745+
// Track the decl reference that caused the requirement of a capability atom.
746+
SLANG_UNREFLECTED Dictionary<CapabilityAtom, DeclReferenceWithLoc> capabilityRequirementProvenance;
739747
private:
740748
SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
741749
SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1;

source/slang/slang-ast-dump.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -692,6 +692,31 @@ struct ASTDumpContext
692692
m_writer->emit("}");
693693
}
694694

695+
void dump(const CapabilitySet& capSet)
696+
{
697+
m_writer->emit("capability_set(");
698+
bool isFirstSet = true;
699+
for (auto& set : capSet.getExpandedAtoms())
700+
{
701+
if (!isFirstSet)
702+
{
703+
m_writer->emit(" | ");
704+
}
705+
bool isFirst = true;
706+
for (auto atom : set.getExpandedAtoms())
707+
{
708+
if (!isFirst)
709+
{
710+
m_writer->emit("+");
711+
}
712+
dump(capabilityNameToString((CapabilityName)atom));
713+
isFirst = false;
714+
}
715+
isFirstSet = false;
716+
}
717+
m_writer->emit(")");
718+
}
719+
695720
void dumpObjectFull(NodeBase* node);
696721

697722
ASTDumpContext(SourceWriter* writer, ASTDumpUtil::Flags flags, ASTDumpUtil::Style dumpStyle):

source/slang/slang-ast-iterator.h

+28-19
Original file line numberDiff line numberDiff line change
@@ -3,16 +3,14 @@
33

44
namespace Slang
55
{
6-
template <typename Callback>
6+
template <typename Callback, typename Filter>
77
struct ASTIterator
88
{
99
const Callback& callback;
10-
UnownedStringSlice fileName;
11-
SourceManager* sourceManager;
12-
ASTIterator(const Callback& func, SourceManager* manager, UnownedStringSlice sourceFileName)
10+
const Filter& filter;
11+
ASTIterator(const Callback& func, const Filter& filterFunc)
1312
: callback(func)
14-
, fileName(sourceFileName)
15-
, sourceManager(manager)
13+
, filter(filterFunc)
1614
{}
1715

1816
void visitDecl(DeclBase* decl);
@@ -429,13 +427,11 @@ struct ASTIterator
429427
};
430428
};
431429

432-
template <typename CallbackFunc>
433-
void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl)
430+
template <typename CallbackFunc, typename FilterFunc>
431+
void ASTIterator<CallbackFunc, FilterFunc>::visitDecl(DeclBase* decl)
434432
{
435433
// Don't look at the decl if it is defined in a different file.
436-
if (!as<NamespaceDeclBase>(decl) && !sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual)
437-
.pathInfo.foundPath.getUnownedSlice()
438-
.endsWithCaseInsensitive(fileName))
434+
if (!filter(decl))
439435
return;
440436

441437
maybeDispatchCallback(decl);
@@ -490,24 +486,23 @@ void ASTIterator<CallbackFunc>::visitDecl(DeclBase* decl)
490486
}
491487
}
492488
}
493-
template <typename CallbackFunc>
494-
void ASTIterator<CallbackFunc>::visitExpr(Expr* expr)
489+
template <typename CallbackFunc, typename FilterFunc>
490+
void ASTIterator<CallbackFunc, FilterFunc>::visitExpr(Expr* expr)
495491
{
496492
ASTIteratorExprVisitor visitor(this);
497493
visitor.dispatchIfNotNull(expr);
498494
}
499-
template <typename CallbackFunc>
500-
void ASTIterator<CallbackFunc>::visitStmt(Stmt* stmt)
495+
template <typename CallbackFunc, typename FilterFunc>
496+
void ASTIterator<CallbackFunc, FilterFunc>::visitStmt(Stmt* stmt)
501497
{
502498
ASTIteratorStmtVisitor visitor(this);
503499
visitor.dispatchIfNotNull(stmt);
504500
}
505501

506-
template <typename Func>
507-
void iterateAST(
508-
UnownedStringSlice fileName, SourceManager* manager, SyntaxNode* node, const Func& f)
502+
template <typename Func, typename FilterFunc>
503+
void iterateAST(SyntaxNode* node, const FilterFunc& filterFunc, const Func& f)
509504
{
510-
ASTIterator<Func> iter(f, manager, fileName);
505+
ASTIterator<Func, FilterFunc> iter(f, filterFunc);
511506
if (auto decl = as<Decl>(node))
512507
{
513508
iter.visitDecl(decl);
@@ -521,4 +516,18 @@ void iterateAST(
521516
iter.visitStmt(stmt);
522517
}
523518
}
519+
520+
template <typename Func>
521+
void iterateASTWithLanguageServerFilter(
522+
UnownedStringSlice fileName, SourceManager* sourceManager, SyntaxNode* node, const Func& f)
523+
{
524+
auto filter = [&](DeclBase* decl)
525+
{
526+
return as<NamespaceDeclBase>(decl) ||
527+
sourceManager->getHumaneLoc(decl->loc, SourceLocType::Actual)
528+
.pathInfo.foundPath.getUnownedSlice()
529+
.endsWithCaseInsensitive(fileName);
530+
};
531+
iterateAST(node, filter, f);
532+
}
524533
} // namespace Slang

0 commit comments

Comments
 (0)