@@ -143,6 +143,8 @@ struct SemanticsDeclHeaderVisitor : public SemanticsDeclVisitorBase,
143
143
144
144
void visitAssocTypeDecl(AssocTypeDecl* decl);
145
145
146
+ void checkDifferentiableCallableCommon(CallableDecl* decl);
147
+
146
148
void checkCallableDeclCommon(CallableDecl* decl);
147
149
148
150
void visitFuncDecl(FuncDecl* funcDecl);
@@ -9109,24 +9111,8 @@ void SemanticsDeclHeaderVisitor::setFuncTypeIntoRequirementDecl(
9109
9111
}
9110
9112
}
9111
9113
9112
- void SemanticsDeclHeaderVisitor::checkCallableDeclCommon (CallableDecl* decl)
9114
+ void SemanticsDeclHeaderVisitor::checkDifferentiableCallableCommon (CallableDecl* decl)
9113
9115
{
9114
- for (auto paramDecl : decl->getParameters())
9115
- {
9116
- ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
9117
- }
9118
-
9119
- auto errorType = decl->errorType;
9120
- if (errorType.exp)
9121
- {
9122
- errorType = CheckProperType(errorType);
9123
- }
9124
- else
9125
- {
9126
- errorType = TypeExp(m_astBuilder->getBottomType());
9127
- }
9128
- decl->errorType = errorType;
9129
-
9130
9116
if (auto interfaceDecl = findParentInterfaceDecl(decl))
9131
9117
{
9132
9118
bool isDiffFunc = false;
@@ -9248,6 +9234,27 @@ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
9248
9234
}
9249
9235
}
9250
9236
}
9237
+ }
9238
+
9239
+ void SemanticsDeclHeaderVisitor::checkCallableDeclCommon(CallableDecl* decl)
9240
+ {
9241
+ for (auto paramDecl : decl->getParameters())
9242
+ {
9243
+ ensureDecl(paramDecl, DeclCheckState::ReadyForReference);
9244
+ }
9245
+
9246
+ auto errorType = decl->errorType;
9247
+ if (errorType.exp)
9248
+ {
9249
+ errorType = CheckProperType(errorType);
9250
+ }
9251
+ else
9252
+ {
9253
+ errorType = TypeExp(m_astBuilder->getBottomType());
9254
+ }
9255
+ decl->errorType = errorType;
9256
+
9257
+ checkDifferentiableCallableCommon(decl);
9251
9258
9252
9259
// If this method is intended to be a CUDA kernel, verify that the return type is void.
9253
9260
if (decl->findModifier<CudaKernelAttribute>())
@@ -9709,6 +9716,8 @@ void SemanticsDeclHeaderVisitor::visitAccessorDecl(AccessorDecl* decl)
9709
9716
// for `GetterDecl`s.
9710
9717
//
9711
9718
decl->returnType.type = _getAccessorStorageType(decl);
9719
+
9720
+ checkDifferentiableCallableCommon(decl);
9712
9721
}
9713
9722
9714
9723
void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
@@ -9799,6 +9808,7 @@ void SemanticsDeclHeaderVisitor::visitSetterDecl(SetterDecl* decl)
9799
9808
newValueType);
9800
9809
}
9801
9810
}
9811
+ checkDifferentiableCallableCommon(decl);
9802
9812
}
9803
9813
9804
9814
GenericDecl* SemanticsVisitor::GetOuterGeneric(Decl* decl)
0 commit comments