Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve performance when compiling small shaders. #6396

Merged
merged 18 commits into from
Feb 23, 2025
6 changes: 0 additions & 6 deletions docs/design/stdlib-intrinsics.md
Original file line number Diff line number Diff line change
@@ -114,12 +114,6 @@ Sections of the `expansion` string that are to be replaced are prefixed by the `
* $XH - Ray tracing hit object attribute
* $P - Type-based prefix as used for CUDA and C++ targets (I8 for int8_t, F32 - float etc)

## __specialized_for_target(target)

Specialized for target allows defining an implementation *body* for a particular target. The target is the same as is used for [__target_intrinsic](#target-intrinsic).

A declaration can consist of multiple definitions with bodies (for each target) using, `specialized_for_target`, as well as having `target_intrinsic` if that is applicable for a target.

## __attributeTarget(astClassName)

For an attribute, specifies the AST class (and derived class) the attribute can be applied to.
51 changes: 51 additions & 0 deletions source/core/slang-string.h
Original file line number Diff line number Diff line change
@@ -790,6 +790,57 @@ class SLANG_RT_API String
UnownedStringSlice getUnownedSlice() const { return StringRepresentation::asSlice(m_buffer); }
};

class ImmutableHashedString
{
public:
String slice;
HashCode64 hashCode;
Comment on lines +796 to +797
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Probably we should make the members private.
Not sure if we want to make them const as well since it is Immutable.

ImmutableHashedString()
: hashCode(0)
{
}
ImmutableHashedString(const UnownedStringSlice& slice)
: slice(slice), hashCode(slice.getHashCode())
{
}
ImmutableHashedString(const char* begin, const char* end)
: slice(begin, end), hashCode(slice.getHashCode())
{
}
ImmutableHashedString(const char* begin, size_t len)
: slice(UnownedStringSlice(begin, len)), hashCode(slice.getHashCode())
{
}
ImmutableHashedString(const char* begin)
: slice(begin), hashCode(slice.getHashCode())
{
}
ImmutableHashedString(const String& str)
: slice(str), hashCode(str.getHashCode())
{
}
ImmutableHashedString(String&& str)
: slice(_Move(str)), hashCode(str.getHashCode())
{
}
ImmutableHashedString(const ImmutableHashedString& other) = default;
ImmutableHashedString& operator=(const ImmutableHashedString& other) = default;
bool operator==(const ImmutableHashedString& other) const
{
return hashCode == other.hashCode && slice == other.slice;
}
bool operator!=(const ImmutableHashedString& other) const
{
return hashCode != other.hashCode || slice != other.slice;
}
bool operator==(const UnownedStringSlice& other) const { return slice == other; }
bool operator!=(const UnownedStringSlice& other) const { return slice != other; }
bool operator==(const String& other) const { return slice == other.getUnownedSlice(); }
bool operator!=(const String& other) const { return slice != other.getUnownedSlice(); }
bool operator==(const char* other) const { return slice == UnownedStringSlice(other); }
HashCode64 getHashCode() const { return hashCode; }
};

class SLANG_RT_API StringBuilder : public String
{
private:
10 changes: 6 additions & 4 deletions source/slang/core.meta.slang
Original file line number Diff line number Diff line change
@@ -603,6 +603,7 @@ void static_assert(constexpr bool condition, NativeString errorMessage);
///
///
__magic_type(DifferentiableType)
[KnownBuiltin("IDifferentiable")]
interface IDifferentiable
{
// Note: the compiler implementation requires the `Differential` associated type to be defined
@@ -645,6 +646,7 @@ interface IDifferentiable
/// @remarks Support for this interface is still experimental and subject to change.
///
__magic_type(DifferentiablePtrType)
[KnownBuiltin("IDifferentiablePtr")]
interface IDifferentiablePtrType
{
__builtin_requirement($( (int)BuiltinRequirementKind::DifferentialPtrType) )
@@ -2598,20 +2600,20 @@ for(auto fixity : kIncDecFixities)
$(fixity.qual)
__generic<T : __BuiltinArithmeticType>
[__unsafeForceInlineEarly]
T operator$(op.name)(in out T value)
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
T operator$(op.name)( in out T value)
{ $(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }

$(fixity.qual)
__generic<T : __BuiltinArithmeticType, let N : int>
[__unsafeForceInlineEarly]
vector<T,N> operator$(op.name)(in out vector<T,N> value)
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }

$(fixity.qual)
__generic<T : __BuiltinArithmeticType, let R : int, let C : int, let L : int>
[__unsafeForceInlineEarly]
matrix<T,R,C> operator$(op.name)(in out matrix<T,R,C,L> value)
{$(fixity.bodyPrefix) value = value $(op.binOp) T(1); return $(fixity.returnVal); }
{$(fixity.bodyPrefix) value = value $(op.binOp) __builtin_cast<T>(1); return $(fixity.returnVal); }

$(fixity.qual)
__generic<T, let addrSpace : uint64_t>
2 changes: 1 addition & 1 deletion source/slang/diff.meta.slang
Original file line number Diff line number Diff line change
@@ -176,7 +176,7 @@ attribute_syntax [NoDiffThis] : NoDiffThisAttribute;
// for internal use.
//
[__AutoDiffBuiltin]
export struct NullDifferential : IDifferentiable
struct NullDifferential : IDifferentiable
{
// for now, we'll use at least one field to make sure the type is non-empty
uint dummy;
2 changes: 1 addition & 1 deletion source/slang/hlsl.meta.slang
Original file line number Diff line number Diff line change
@@ -20309,7 +20309,7 @@ void ReorderThread( HitObject HitOrMiss )
///
/// There doesn't appear to be an equivalent for debugBreak for HLSL


[require(glsl)]
__specialized_for_target(glsl)
[[vk::spirv_instruction(1, "NonSemantic.DebugBreak")]]
void __glslDebugBreak();
Comment on lines +20312 to 20315
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe we should move this to glsl.meta.slang?

6 changes: 6 additions & 0 deletions source/slang/slang-check-decl.cpp
Original file line number Diff line number Diff line change
@@ -10452,6 +10452,12 @@ void SemanticsVisitor::importModuleIntoScope(Scope* scope, ModuleDecl* moduleDec
{
return;
}

if (getText(moduleDecl->getName()) == "glsl")
{
getShared()->glslModuleDecl = moduleDecl;
}

importedModulesList.add(moduleDecl);
importedModulesSet.add(moduleDecl);

33 changes: 28 additions & 5 deletions source/slang/slang-check-impl.h
Original file line number Diff line number Diff line change
@@ -172,15 +172,17 @@ struct BasicTypeKeyPair

struct OperatorOverloadCacheKey
{
intptr_t operatorName;
int32_t operatorName;
bool isGLSLMode;
BasicTypeKey args[2];
bool operator==(OperatorOverloadCacheKey key) const
{
return operatorName == key.operatorName && args[0] == key.args[0] && args[1] == key.args[1];
return operatorName == key.operatorName && args[0] == key.args[0] &&
args[1] == key.args[1] && isGLSLMode == key.isGLSLMode;
}
HashCode getHashCode() const
{
return combineHash((int)(UInt64)(void*)(operatorName), args[0].getRaw(), args[1].getRaw());
return combineHash(operatorName, args[0].getRaw(), args[1].getRaw(), isGLSLMode ? 1 : 0);
}
bool fromOperatorExpr(OperatorExpr* opExpr)
{
@@ -299,10 +301,28 @@ struct OverloadCandidate
SubstitutionSet subst;
};

struct TypeCheckingCache
struct ResolvedOperatorOverload
{
Dictionary<OperatorOverloadCacheKey, OverloadCandidate> resolvedOperatorOverloadCache;
// The resolved decl.
Decl* decl;

// The cached overload candidate in the current TypeCheckingCache.
// Note that a `OverloadCandidate` object is not migratable over different
// Linkages (compile sessions), so we will need to use `cacheVersion` to track
// if this `candidate` is valid for the current session. If not, we will
// recreate it from `decl`.
OverloadCandidate candidate;
// The version of the TypeCheckingCache for which the cached candidate is valid.
int cacheVersion;
};

struct TypeCheckingCache : public RefObject
{
Dictionary<OperatorOverloadCacheKey, ResolvedOperatorOverload> resolvedOperatorOverloadCache;
Dictionary<BasicTypeKeyPair, ConversionCost> conversionCostCache;

// The version used to invalidate the cached declRefs in ResolvedOperatorOverload entries.
int version = 0;
};

enum class CoercionSite
@@ -635,6 +655,9 @@ struct SharedSemanticsContext : public RefObject

DiagnosticSink* m_sink = nullptr;

// Whether the current module has imported the GLSL module.
ModuleDecl* glslModuleDecl = nullptr;

/// (optional) modules that comes from previously processed translation units in the
/// front-end request that are made visible to the module being checked. This allows
/// `import` to use them instead of trying to find the files in file system.
71 changes: 49 additions & 22 deletions source/slang/slang-check-overload.cpp
Original file line number Diff line number Diff line change
@@ -2509,27 +2509,6 @@ String SemanticsVisitor::getCallSignatureString(OverloadResolveContext& context)
Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
{
OverloadResolveContext context;
// check if this is a core module operator call, if so we want to use cached results
// to speed up compilation
bool shouldAddToCache = false;
OperatorOverloadCacheKey key;
TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
if (auto opExpr = as<OperatorExpr>(expr))
{
if (key.fromOperatorExpr(opExpr))
{
OverloadCandidate candidate;
if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
{
context.bestCandidateStorage = candidate;
context.bestCandidate = &context.bestCandidateStorage;
}
else
{
shouldAddToCache = true;
}
}
}

// Look at the base expression for the call, and figure out how to invoke it.
auto funcExpr = expr->functionExpr;
@@ -2569,6 +2548,43 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
context.loc = expr->loc;
context.sourceScope = m_outerScope;
context.baseExpr = GetBaseExpr(funcExpr);

// check if this is a core module operator call, if so we want to use cached results
// to speed up compilation
bool shouldAddToCache = false;
OperatorOverloadCacheKey key;
TypeCheckingCache* typeCheckingCache = getLinkage()->getTypeCheckingCache();
if (auto opExpr = as<OperatorExpr>(expr))
{
if (key.fromOperatorExpr(opExpr))
{
key.isGLSLMode = getShared()->glslModuleDecl != nullptr;
ResolvedOperatorOverload candidate;
if (typeCheckingCache->resolvedOperatorOverloadCache.tryGetValue(key, candidate))
{
// We should only use the cached candidate if it is persistent direct declref
// created from GlobalSession's ASTBuilder, or it is created in the current Linkage.
if (candidate.cacheVersion == typeCheckingCache->version ||
as<DirectDeclRef>(candidate.candidate.item.declRef.declRefBase))
{
context.bestCandidateStorage = candidate.candidate;
context.bestCandidate = &context.bestCandidateStorage;
}
else
{
LookupResultItem overloadCandidate = {};
overloadCandidate.declRef = getOuterGenericOrSelf(candidate.decl);
AddDeclRefOverloadCandidates(overloadCandidate, context, 0);
shouldAddToCache = true;
}
}
else
{
shouldAddToCache = true;
}
}
}

// We run a special case here where an `InvokeExpr`
// with a single argument where the base/func expression names
// a type should always be treated as an explicit type coercion
@@ -2731,7 +2747,18 @@ Expr* SemanticsVisitor::ResolveInvoke(InvokeExpr* expr)
// We will report errors for this one candidate, then, to give
// the user the most help we can.
if (shouldAddToCache)
typeCheckingCache->resolvedOperatorOverloadCache[key] = *context.bestCandidate;
{
if (isFromCoreModule(context.bestCandidate->item.declRef.getDecl()) ||
getShared()->glslModuleDecl ==
getModuleDecl(context.bestCandidate->item.declRef.getDecl()))
{
ResolvedOperatorOverload overloadResult;
overloadResult.candidate = *context.bestCandidate;
overloadResult.decl = context.bestCandidate->item.declRef.getDecl();
overloadResult.cacheVersion = typeCheckingCache->version;
typeCheckingCache->resolvedOperatorOverloadCache[key] = overloadResult;
}
}

// Now that we have resolved the overload candidate, we need to undo an `openExistential`
// operation that was applied to `out` arguments.
5 changes: 4 additions & 1 deletion source/slang/slang-compiler.h
Original file line number Diff line number Diff line change
@@ -2209,7 +2209,7 @@ class Linkage : public RefObject, public slang::ISession
TypeCheckingCache* getTypeCheckingCache();
void destroyTypeCheckingCache();

TypeCheckingCache* m_typeCheckingCache = nullptr;
RefPtr<RefObject> m_typeCheckingCache = nullptr;

// Modules that have been dynamically loaded via `import`
//
@@ -3589,6 +3589,9 @@ class Session : public RefObject, public slang::IGlobalSession

int m_typeDictionarySize = 0;

RefPtr<RefObject> m_typeCheckingCache;
TypeCheckingCache* getTypeCheckingCache();

private:
struct BuiltinModuleInfo
{
3 changes: 2 additions & 1 deletion source/slang/slang-diagnostic-defs.h
Original file line number Diff line number Diff line change
@@ -49,6 +49,7 @@ DIAGNOSTIC(-1, Note, seeUsingOf, "see using of '$0'")
DIAGNOSTIC(-1, Note, seeDefinitionOfShader, "see definition of shader '$0'")
DIAGNOSTIC(-1, Note, seeInclusionOf, "see inclusion of '$0'")
DIAGNOSTIC(-1, Note, seeModuleBeingUsedIn, "see module '$0' being used in '$1'")
DIAGNOSTIC(-1, Note, seeCallOfFunc, "see call to '$0'")
DIAGNOSTIC(-1, Note, seePipelineRequirementDefinition, "see pipeline requirement definition")
DIAGNOSTIC(
-1,
@@ -2309,7 +2310,7 @@ DIAGNOSTIC(
41402,
Error,
staticAssertionConditionNotConstant,
"condition for static assertion cannot be evaluated at the compile-time.")
"condition for static assertion cannot be evaluated at compile time.")

DIAGNOSTIC(
41402,
26 changes: 26 additions & 0 deletions source/slang/slang-emit.cpp
Original file line number Diff line number Diff line change
@@ -475,6 +475,31 @@ void calcRequiredLoweringPassSet(
}
}

void diagnoseCallStack(IRInst* inst, DiagnosticSink* sink)
{
static const int maxDepth = 5;
for (int i = 0; i < maxDepth; i++)
{
auto func = getParentFunc(inst);
if (!func)
return;
bool shouldContinue = false;
for (auto use = func->firstUse; use; use = use->nextUse)
{
auto user = use->getUser();
if (auto call = as<IRCall>(user))
{
sink->diagnose(call, Diagnostics::seeCallOfFunc, func);
inst = call;
shouldContinue = true;
break;
}
}
if (!shouldContinue)
return;
}
}

bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
{
switch (inst->getOp())
@@ -498,6 +523,7 @@ bool checkStaticAssert(IRInst* inst, DiagnosticSink* sink)
{
sink->diagnose(inst, Diagnostics::staticAssertionFailureWithoutMessage);
}
diagnoseCallStack(inst, sink);
}
}
else
1 change: 1 addition & 0 deletions source/slang/slang-ir-autodiff-fwd.cpp
Original file line number Diff line number Diff line change
@@ -2012,6 +2012,7 @@ InstPair ForwardDiffTranscriber::transcribeInstImpl(IRBuilder* builder, IRInst*
case kIROp_MakeArrayFromElement:
case kIROp_MakeTuple:
case kIROp_MakeValuePack:
case kIROp_BuiltinCast:
return transcribeConstruct(builder, origInst);
case kIROp_MakeStruct:
return transcribeMakeStruct(builder, origInst);
Loading