Skip to content

Commit 51ad07d

Browse files
authored
Improve performance when compiling small shaders. (#6396)
Improve performance when compiling small shaders. Avoid copying witness table entries that are not used during linking. Avoid copying auto-diff related decorations and derivative functions during linking if the user module doesn't use autodiff. Cache operator overload resolution results on the global session, so each new Session doesn't need to repeatedly run through overload resolution from scratch.
1 parent 0101e5a commit 51ad07d

25 files changed

+766
-122
lines changed

docs/design/stdlib-intrinsics.md

-6
Original file line numberDiff line numberDiff line change
@@ -114,12 +114,6 @@ Sections of the `expansion` string that are to be replaced are prefixed by the `
114114
* $XH - Ray tracing hit object attribute
115115
* $P - Type-based prefix as used for CUDA and C++ targets (I8 for int8_t, F32 - float etc)
116116

117-
## __specialized_for_target(target)
118-
119-
Specialized for target allows defining an implementation *body* for a particular target. The target is the same as is used for [__target_intrinsic](#target-intrinsic).
120-
121-
A declaration can consist of multiple definitions with bodies (for each target) using, `specialized_for_target`, as well as having `target_intrinsic` if that is applicable for a target.
122-
123117
## __attributeTarget(astClassName)
124118

125119
For an attribute, specifies the AST class (and derived class) the attribute can be applied to.

source/core/slang-string.h

+51
Original file line numberDiff line numberDiff line change
@@ -790,6 +790,57 @@ class SLANG_RT_API String
790790
UnownedStringSlice getUnownedSlice() const { return StringRepresentation::asSlice(m_buffer); }
791791
};
792792

793+
class ImmutableHashedString
794+
{
795+
public:
796+
String slice;
797+
HashCode64 hashCode;
798+
ImmutableHashedString()
799+
: hashCode(0)
800+
{
801+
}
802+
ImmutableHashedString(const UnownedStringSlice& slice)
803+
: slice(slice), hashCode(slice.getHashCode())
804+
{
805+
}
806+
ImmutableHashedString(const char* begin, const char* end)
807+
: slice(begin, end), hashCode(slice.getHashCode())
808+
{
809+
}
810+
ImmutableHashedString(const char* begin, size_t len)
811+
: slice(UnownedStringSlice(begin, len)), hashCode(slice.getHashCode())
812+
{
813+
}
814+
ImmutableHashedString(const char* begin)
815+
: slice(begin), hashCode(slice.getHashCode())
816+
{
817+
}
818+
ImmutableHashedString(const String& str)
819+
: slice(str), hashCode(str.getHashCode())
820+
{
821+
}
822+
ImmutableHashedString(String&& str)
823+
: slice(_Move(str)), hashCode(str.getHashCode())
824+
{
825+
}
826+
ImmutableHashedString(const ImmutableHashedString& other) = default;
827+
ImmutableHashedString& operator=(const ImmutableHashedString& other) = default;
828+
bool operator==(const ImmutableHashedString& other) const
829+
{
830+
return hashCode == other.hashCode && slice == other.slice;
831+
}
832+
bool operator!=(const ImmutableHashedString& other) const
833+
{
834+
return hashCode != other.hashCode || slice != other.slice;
835+
}
836+
bool operator==(const UnownedStringSlice& other) const { return slice == other; }
837+
bool operator!=(const UnownedStringSlice& other) const { return slice != other; }
838+
bool operator==(const String& other) const { return slice == other.getUnownedSlice(); }
839+
bool operator!=(const String& other) const { return slice != other.getUnownedSlice(); }
840+
bool operator==(const char* other) const { return slice == UnownedStringSlice(other); }
841+
HashCode64 getHashCode() const { return hashCode; }
842+
};
843+
793844
class SLANG_RT_API StringBuilder : public String
794845
{
795846
private:

source/slang/core.meta.slang

+6-4
Original file line numberDiff line numberDiff line change
@@ -603,6 +603,7 @@ void static_assert(constexpr bool condition, NativeString errorMessage);
603603
///
604604
///
605605
__magic_type(DifferentiableType)
606+
[KnownBuiltin("IDifferentiable")]
606607
interface IDifferentiable
607608
{
608609
// Note: the compiler implementation requires the `Differential` associated type to be defined
@@ -645,6 +646,7 @@ interface IDifferentiable
645646
/// @remarks Support for this interface is still experimental and subject to change.
646647
///
647648
__magic_type(DifferentiablePtrType)
649+
[KnownBuiltin("IDifferentiablePtr")]
648650
interface IDifferentiablePtrType
649651
{
650652
__builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
@@ -2598,20 +2600,20 @@ for(auto fixity : kIncDecFixities)
25982600
$(fixity.qual)
25992601
__generic<T : __BuiltinArithmeticType>
26002602
[__unsafeForceInlineEarly]
2601-
T operator$(op.name)(in out T value)
2602-
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
2603+
T operator$(op.name)( in out T value)
2604+
{ $(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
26032605

26042606
$(fixity.qual)
26052607
__generic<T : __BuiltinArithmeticType, let N : int>
26062608
[__unsafeForceInlineEarly]
26072609
vector<T,N> operator$(op.name)(in out vector<T,N> value)
2608-
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
2610+
{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
26092611

26102612
$(fixity.qual)
26112613
__generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int>
26122614
[__unsafeForceInlineEarly]
26132615
matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
2614-
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
2616+
{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }
26152617

26162618
$(fixity.qual)
26172619
__generic<T, let addrSpace : uint64_t>

source/slang/diff.meta.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
176176
// for internal use.
177177
//
178178
[__AutoDiffBuiltin]
179-
export struct NullDifferential : IDifferentiable
179+
struct NullDifferential : IDifferentiable
180180
{
181181
// for now, we'll use at least one field to make sure the type is non-empty
182182
uint dummy;

source/slang/hlsl.meta.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -20309,7 +20309,7 @@ void ReorderThread( HitObject HitOrMiss )
2030920309
///
2031020310
/// There doesn't appear to be an equivalent for debugBreak for HLSL
2031120311

20312-
20312+
[require(glsl)]
2031320313
__specialized_for_target(glsl)
2031420314
[[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]]
2031520315
void __glslDebugBreak();

source/slang/slang-check-decl.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -10452,6 +10452,12 @@ void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDec
1045210452
{
1045310453
return;
1045410454
}
10455+
10456+
if (getText(moduleDecl->getName()) == "glsl")
10457+
{
10458+
getShared()->glslModuleDecl = moduleDecl;
10459+
}
10460+
1045510461
importedModulesList.add(moduleDecl);
1045610462
importedModulesSet.add(moduleDecl);
1045710463

source/slang/slang-check-impl.h

+28-5
Original file line numberDiff line numberDiff line change
@@ -172,15 +172,17 @@ struct BasicTypeKeyPair
172172

173173
struct OperatorOverloadCacheKey
174174
{
175-
intptr_t operatorName;
175+
int32_t operatorName;
176+
bool isGLSLMode;
176177
BasicTypeKey args[2];
177178
bool operator==(OperatorOverloadCacheKey key) const
178179
{
179-
return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1];
180+
return operatorName == key.operatorName && args[0] == key.args[0] &&
181+
args[1] == key.args[1] && isGLSLMode == key.isGLSLMode;
180182
}
181183
HashCode getHashCode() const
182184
{
183-
return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw());
185+
return combineHash(operatorName, args[0].getRaw(), args[1].getRaw(), isGLSLMode ? 1 : 0);
184186
}
185187
bool fromOperatorExpr(OperatorExpr* opExpr)
186188
{
@@ -299,10 +301,28 @@ struct OverloadCandidate
299301
SubstitutionSet subst;
300302
};
301303

302-
struct TypeCheckingCache
304+
struct ResolvedOperatorOverload
303305
{
304-
Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
306+
// The resolved decl.
307+
Decl* decl;
308+
309+
// The cached overload candidate in the current TypeCheckingCache.
310+
// Note that a `OverloadCandidate` object is not migratable over different
311+
// Linkages (compile sessions), so we will need to use `cacheVersion` to track
312+
// if this `candidate` is valid for the current session. If not, we will
313+
// recreate it from `decl`.
314+
OverloadCandidate candidate;
315+
// The version of the TypeCheckingCache for which the cached candidate is valid.
316+
int cacheVersion;
317+
};
318+
319+
struct TypeCheckingCache : public RefObject
320+
{
321+
Dictionary<OperatorOverloadCacheKey, ResolvedOperatorOverload> resolvedOperatorOverloadCache;
305322
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;
323+
324+
// The version used to invalidate the cached declRefs in ResolvedOperatorOverload entries.
325+
int version = 0;
306326
};
307327

308328
enum class CoercionSite
@@ -635,6 +655,9 @@ struct SharedSemanticsContext : public RefObject
635655

636656
DiagnosticSink* m_sink = nullptr;
637657

658+
// The GLSL module's declaration if the current module has imported it; nullptr otherwise.
659+
ModuleDecl* glslModuleDecl = nullptr;
660+
638661
/// (optional) modules that comes from previously processed translation units in the
639662
/// front-end request that are made visible to the module being checked. This allows
640663
/// `import` to use them instead of trying to find the files in file system.

source/slang/slang-check-overload.cpp

+49-22
Original file line numberDiff line numberDiff line change
@@ -2509,27 +2509,6 @@ String SemanticsVisitor::getCallSignatureString(OverloadResolveContext& context)
25092509
Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
25102510
{
25112511
OverloadResolveContext context;
2512-
// check if this is a core module operator call, if so we want to use cached results
2513-
// to speed up compilation
2514-
bool shouldAddToCache = false;
2515-
OperatorOverloadCacheKey key;
2516-
TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
2517-
if (auto opExpr = as<OperatorExpr>(expr))
2518-
{
2519-
if (key.fromOperatorExpr(opExpr))
2520-
{
2521-
OverloadCandidate candidate;
2522-
if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
2523-
{
2524-
context.bestCandidateStorage = candidate;
2525-
context.bestCandidate = &context.bestCandidateStorage;
2526-
}
2527-
else
2528-
{
2529-
shouldAddToCache = true;
2530-
}
2531-
}
2532-
}
25332512

25342513
// Look at the base expression for the call, and figure out how to invoke it.
25352514
auto funcExpr = expr->functionExpr;
@@ -2569,6 +2548,43 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
25692548
context.loc = expr->loc;
25702549
context.sourceScope = m_outerScope;
25712550
context.baseExpr = GetBaseExpr(funcExpr);
2551+
2552+
// check if this is a core module operator call, if so we want to use cached results
2553+
// to speed up compilation
2554+
bool shouldAddToCache = false;
2555+
OperatorOverloadCacheKey key;
2556+
TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
2557+
if (auto opExpr = as<OperatorExpr>(expr))
2558+
{
2559+
if (key.fromOperatorExpr(opExpr))
2560+
{
2561+
key.isGLSLMode = getShared()->glslModuleDecl != nullptr;
2562+
ResolvedOperatorOverload candidate;
2563+
if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
2564+
{
2565+
// We should only use the cached candidate if it is persistent direct declref
2566+
// created from GlobalSession's ASTBuilder, or it is created in the current Linkage.
2567+
if (candidate.cacheVersion == typeCheckingCache->version ||
2568+
as<DirectDeclRef>(candidate.candidate.item.declRef.declRefBase))
2569+
{
2570+
context.bestCandidateStorage = candidate.candidate;
2571+
context.bestCandidate = &context.bestCandidateStorage;
2572+
}
2573+
else
2574+
{
2575+
LookupResultItem overloadCandidate = {};
2576+
overloadCandidate.declRef = getOuterGenericOrSelf(candidate.decl);
2577+
AddDeclRefOverloadCandidates(overloadCandidate, context, 0);
2578+
shouldAddToCache = true;
2579+
}
2580+
}
2581+
else
2582+
{
2583+
shouldAddToCache = true;
2584+
}
2585+
}
2586+
}
2587+
25722588
// We run a special case here where an `InvokeExpr`
25732589
// with a single argument where the base/func expression names
25742590
// a type should always be treated as an explicit type coercion
@@ -2731,7 +2747,18 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
27312747
// We will report errors for this one candidate, then, to give
27322748
// the user the most help we can.
27332749
if (shouldAddToCache)
2734-
typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate;
2750+
{
2751+
if (isFromCoreModule(context.bestCandidate->item.declRef.getDecl()) ||
2752+
getShared()->glslModuleDecl ==
2753+
getModuleDecl(context.bestCandidate->item.declRef.getDecl()))
2754+
{
2755+
ResolvedOperatorOverload overloadResult;
2756+
overloadResult.candidate = *context.bestCandidate;
2757+
overloadResult.decl = context.bestCandidate->item.declRef.getDecl();
2758+
overloadResult.cacheVersion = typeCheckingCache->version;
2759+
typeCheckingCache->resolvedOperatorOverloadCache[key] = overloadResult;
2760+
}
2761+
}
27352762

27362763
// Now that we have resolved the overload candidate, we need to undo an `openExistential`
27372764
// operation that was applied to `out` arguments.

source/slang/slang-compiler.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -2209,7 +2209,7 @@ class Linkage : public RefObject, public slang::ISession
22092209
TypeCheckingCache* getTypeCheckingCache();
22102210
void destroyTypeCheckingCache();
22112211

2212-
TypeCheckingCache* m_typeCheckingCache = nullptr;
2212+
RefPtr<RefObject> m_typeCheckingCache = nullptr;
22132213

22142214
// Modules that have been dynamically loaded via `import`
22152215
//
@@ -3589,6 +3589,9 @@ class Session : public RefObject, public slang::IGlobalSession
35893589

35903590
int m_typeDictionarySize = 0;
35913591

3592+
RefPtr<RefObject> m_typeCheckingCache;
3593+
TypeCheckingCache* getTypeCheckingCache();
3594+
35923595
private:
35933596
struct BuiltinModuleInfo
35943597
{

source/slang/slang-diagnostic-defs.h

+2-1
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,7 @@ DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'")
4949
DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'")
5050
DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'")
5151
DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'")
52+
DIAGNOSTIC(-1, Note, seeCallOfFunc, "see call to '$0'")
5253
DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition")
5354
DIAGNOSTIC(
5455
-1,
@@ -2309,7 +2310,7 @@ DIAGNOSTIC(
23092310
41402,
23102311
Error,
23112312
staticAssertionConditionNotConstant,
2312-
"condition for static assertion cannot be evaluated at the compile-time.")
2313+
"condition for static assertion cannot be evaluated at compile time.")
23132314

23142315
DIAGNOSTIC(
23152316
41402,

source/slang/slang-emit.cpp

+26
Original file line numberDiff line numberDiff line change
@@ -475,6 +475,31 @@ void calcRequiredLoweringPassSet(
475475
}
476476
}
477477

478+
/// Emits a chain of `seeCallOfFunc` notes starting from `inst`.
///
/// At each level the function enclosing `inst` is looked up, the first use of
/// that function whose user is an `IRCall` is reported, and the walk continues
/// from that call. The walk stops when there is no enclosing function, when no
/// call user is found, or after `maxDepth` levels.
void diagnoseCallStack(IRInst* inst, DiagnosticSink* sink)
{
    // Upper bound on the number of levels reported.
    static const int maxDepth = 5;

    int depth = 0;
    while (depth < maxDepth)
    {
        auto parentFunc = getParentFunc(inst);
        if (!parentFunc)
            break;

        // Scan the use list until the first user that is a call instruction.
        IRCall* callSite = nullptr;
        for (auto use = parentFunc->firstUse; use && !callSite; use = use->nextUse)
            callSite = as<IRCall>(use->getUser());

        if (!callSite)
            break;

        sink->diagnose(callSite, Diagnostics::seeCallOfFunc, parentFunc);

        // Continue the walk from the call that was just reported.
        inst = callSite;
        ++depth;
    }
}
502+
478503
bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
479504
{
480505
switch (inst->getOp())
@@ -498,6 +523,7 @@ bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
498523
{
499524
sink->diagnose(inst, Diagnostics::staticAssertionFailureWithoutMessage);
500525
}
526+
diagnoseCallStack(inst, sink);
501527
}
502528
}
503529
else

source/slang/slang-ir-autodiff-fwd.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -2012,6 +2012,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
20122012
case kIROp_MakeArrayFromElement:
20132013
case kIROp_MakeTuple:
20142014
case kIROp_MakeValuePack:
2015+
case kIROp_BuiltinCast:
20152016
return transcribeConstruct(builder, origInst);
20162017
case kIROp_MakeStruct:
20172018
return transcribeMakeStruct(builder, origInst);

0 commit comments

Comments
 (0)