@@ -10915,7 +10915,61 @@ void checkDerivativeAttributeImpl(
10915
10915
SemanticsContext::ExprLocalScope scope;
10916
10916
auto ctx = visitor->withExprLocalScope(&scope);
10917
10917
auto subVisitor = SemanticsVisitor(ctx);
10918
- auto checkedFuncExpr = visitor->dispatchExpr(attr->funcExpr, ctx);
10918
+
10919
+ auto exprToCheck = attr->funcExpr;
10920
+
10921
+ // If this is a generic, we want to wrap the call to the derivative method
10922
+ // with the generic parameters of the source.
10923
+ //
10924
+ if (as<GenericDecl>(funcDecl->parentDecl) && !as<GenericAppExpr>(attr->funcExpr))
10925
+ {
10926
+ auto genericDecl = as<GenericDecl>(funcDecl->parentDecl);
10927
+ auto substArgs = getDefaultSubstitutionArgs(ctx.getASTBuilder(), visitor, genericDecl);
10928
+ auto appExpr = ctx.getASTBuilder()->create<GenericAppExpr>();
10929
+
10930
+ Index count = 0;
10931
+ for (auto member : genericDecl->members)
10932
+ {
10933
+ if (as<GenericTypeParamDecl>(member) || as<GenericValueParamDecl>(member) ||
10934
+ as<GenericTypePackParamDecl>(member))
10935
+ count++;
10936
+ }
10937
+
10938
+ appExpr->functionExpr = attr->funcExpr;
10939
+
10940
+ for (auto arg : substArgs)
10941
+ {
10942
+ if (count == 0)
10943
+ break;
10944
+
10945
+ if (auto declRefType = as<DeclRefType>(arg))
10946
+ {
10947
+ auto baseTypeExpr = ctx.getASTBuilder()->create<SharedTypeExpr>();
10948
+ baseTypeExpr->base.type = declRefType;
10949
+ auto baseTypeType = ctx.getASTBuilder()->getOrCreate<TypeType>(declRefType);
10950
+ baseTypeExpr->type.type = baseTypeType;
10951
+
10952
+ appExpr->arguments.add(baseTypeExpr);
10953
+ }
10954
+ else if (auto genericValParam = as<GenericParamIntVal>(arg))
10955
+ {
10956
+ auto declRef = genericValParam->getDeclRef();
10957
+ appExpr->arguments.add(
10958
+ subVisitor
10959
+ .ConstructDeclRefExpr(declRef, nullptr, nullptr, SourceLoc(), nullptr));
10960
+ }
10961
+ else
10962
+ {
10963
+ SLANG_UNEXPECTED("Unhandled substitution arg type");
10964
+ }
10965
+
10966
+ count--;
10967
+ }
10968
+
10969
+ exprToCheck = appExpr;
10970
+ }
10971
+
10972
+ auto checkedFuncExpr = visitor->dispatchExpr(exprToCheck, ctx);
10919
10973
attr->funcExpr = checkedFuncExpr;
10920
10974
if (attr->args.getCount())
10921
10975
attr->args[0] = attr->funcExpr;
@@ -11427,6 +11481,26 @@ void checkDerivativeOfAttributeImpl(
11427
11481
calleeDeclRef = calleeDeclRefExpr->declRef;
11428
11482
11429
11483
auto calleeFunc = as<FunctionDeclBase>(calleeDeclRef.getDecl());
11484
+
11485
+ if (!calleeFunc)
11486
+ {
11487
+ // If we couldn't find a direct function, it might be a generic.
11488
+ if (auto genericDecl = as<GenericDecl>(calleeDeclRef.getDecl()))
11489
+ {
11490
+ calleeFunc = as<FunctionDeclBase>(genericDecl->inner);
11491
+
11492
+ if (as<ErrorType>(resolved->type.type))
11493
+ {
11494
+ // If we can't resolve a type, something went wrong. If we're working with a generic
11495
+ // decl, the most likely cause is a failure of generic argument inference.
11496
+ //
11497
+ visitor->getSink()->diagnose(
11498
+ derivativeOfAttr,
11499
+ Diagnostics::cannotResolveGenericArgumentForDerivativeFunction);
11500
+ }
11501
+ }
11502
+ }
11503
+
11430
11504
if (!calleeFunc)
11431
11505
{
11432
11506
visitor->getSink()->diagnose(
0 commit comments