Skip to content

Commit e93cb8a

Browse files
authoredDec 20, 2024
Check subscript/property accessor for differentiability. (shader-slang#5922)
1 parent 5c9f011 commit e93cb8a

File tree

3 files changed

+73
-18
lines changed

3 files changed

+73
-18
lines changed
 

‎source/slang/slang-check-decl.cpp

+27-17
Original file line numberDiff line numberDiff line change
@@ -143,6 +143,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase,
143143

144144
void visitAssocTypeDecl(AssocTypeDecl* decl);
145145

146+
void checkDifferentiableCallableCommon(CallableDecl* decl);
147+
146148
void checkCallableDeclCommon(CallableDecl* decl);
147149

148150
void visitFuncDecl(FuncDecl* funcDecl);
@@ -9109,24 +9111,8 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(
91099111
}
91109112
}
91119113

9112-
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
9114+
void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon(CallableDecl* decl)
91139115
{
9114-
for (auto paramDecl : decl->getParameters())
9115-
{
9116-
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
9117-
}
9118-
9119-
auto errorType = decl->errorType;
9120-
if (errorType.exp)
9121-
{
9122-
errorType = CheckProperType(errorType);
9123-
}
9124-
else
9125-
{
9126-
errorType = TypeExp(m_astBuilder->getBottomType());
9127-
}
9128-
decl->errorType = errorType;
9129-
91309116
if (auto interfaceDecl = findParentInterfaceDecl(decl))
91319117
{
91329118
bool isDiffFunc = false;
@@ -9248,6 +9234,27 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
92489234
}
92499235
}
92509236
}
9237+
}
9238+
9239+
void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
9240+
{
9241+
for (auto paramDecl : decl->getParameters())
9242+
{
9243+
ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
9244+
}
9245+
9246+
auto errorType = decl->errorType;
9247+
if (errorType.exp)
9248+
{
9249+
errorType = CheckProperType(errorType);
9250+
}
9251+
else
9252+
{
9253+
errorType = TypeExp(m_astBuilder->getBottomType());
9254+
}
9255+
decl->errorType = errorType;
9256+
9257+
checkDifferentiableCallableCommon(decl);
92519258

92529259
// If this method is intended to be a CUDA kernel, verify that the return type is void.
92539260
if (decl->findModifier<CudaKernelAttribute>())
@@ -9709,6 +9716,8 @@ void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl)
97099716
// for `GetterDecl`s.
97109717
//
97119718
decl->returnType.type = _getAccessorStorageType(decl);
9719+
9720+
checkDifferentiableCallableCommon(decl);
97129721
}
97139722

97149723
void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
@@ -9799,6 +9808,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
97999808
newValueType);
98009809
}
98019810
}
9811+
checkDifferentiableCallableCommon(decl);
98029812
}
98039813

98049814
GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl)

‎source/slang/slang-syntax.cpp

+13-1
Original file line numberDiff line numberDiff line change
@@ -851,7 +851,7 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR
851851
List<Type*> paramTypes;
852852
auto resultType = getResultType(astBuilder, declRef);
853853
auto errorType = getErrorCodeType(astBuilder, declRef);
854-
for (auto paramDeclRef : getParameters(astBuilder, declRef))
854+
auto visitParamDecl = [&](DeclRef<ParamDecl> paramDeclRef)
855855
{
856856
auto paramDecl = paramDeclRef.getDecl();
857857
auto paramType = getParamType(astBuilder, paramDeclRef);
@@ -875,6 +875,18 @@ FuncType* getFuncType(ASTBuilder* astBuilder, DeclRef<CallableDecl> const& declR
875875
}
876876
}
877877
paramTypes.add(paramType);
878+
};
879+
auto parent = declRef.getParent();
880+
if (as<SubscriptDecl>(parent) || as<PropertyDecl>(parent))
881+
{
882+
for (auto paramDeclRef : getParameters(astBuilder, parent.as<CallableDecl>()))
883+
{
884+
visitParamDecl(paramDeclRef);
885+
}
886+
}
887+
for (auto paramDeclRef : getParameters(astBuilder, declRef))
888+
{
889+
visitParamDecl(paramDeclRef);
878890
}
879891

880892
FuncType* funcType =

‎tests/autodiff/subscript.slang

+33
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHK): -output-using-type
2+
3+
interface ITest
4+
{
5+
__subscript(int i) -> float
6+
{
7+
[BackwardDifferentiable] get;
8+
}
9+
}
10+
struct Test : ITest
11+
{
12+
__subscript(int i) -> float
13+
{
14+
[BackwardDifferentiable] get { return 5.0f * i; }
15+
}
16+
}
17+
18+
[Differentiable]
19+
float test(ITest arg)
20+
{
21+
return arg[1];
22+
}
23+
24+
//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4)
25+
RWStructuredBuffer<float> output;
26+
27+
[numthreads(1,1,1)]
28+
void computeMain()
29+
{
30+
Test t = {};
31+
output[0] = test(t);
32+
// CHK: 5.0
33+
}

0 commit comments

Comments
 (0)