Skip to content

Commit

Permalink
Fix interface requirement lowering for generic accessors
Browse files Browse the repository at this point in the history
  • Loading branch information
saipraveenb25 committed Jan 17, 2025
1 parent f68d493 commit ef5b233
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
10 changes: 10 additions & 0 deletions source/slang/slang-lower-to-ir.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8633,6 +8633,9 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
UInt operandCount = 0;
for (auto requirementDecl : decl->members)
{
if (as<GenericDecl>(requirementDecl))
requirementDecl = getInner(requirementDecl);

if (as<SubscriptDecl>(requirementDecl) || as<PropertyDecl>(requirementDecl))
{
for (auto accessorDecl : as<ContainerDecl>(requirementDecl)->members)
Expand Down Expand Up @@ -8782,6 +8785,13 @@ struct DeclLoweringVisitor : DeclVisitor<DeclLoweringVisitor, LoweredValInfo>
auto requirementKey = getInterfaceRequirementKey(requirementDecl);
if (!requirementKey)
{
if (auto genericDecl = as<GenericDecl>(requirementDecl))
{
// We need to form a declref into the inner decls in case of a generic
// requirement.
requirementDecl = getInner(genericDecl);
}

if (as<PropertyDecl>(requirementDecl) || as<SubscriptDecl>(requirementDecl))
{
for (auto member : as<ContainerDecl>(requirementDecl)->members)
Expand Down
32 changes: 32 additions & 0 deletions tests/autodiff/generic-accessors.slang
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHK): -output-using-type

interface ITest
{
__generic<I : IInteger>
__subscript(I i) -> float
{
[BackwardDifferentiable] get;
}
}
struct Test : ITest
{
__generic<I : IInteger>
__subscript(I i) -> float
{
[BackwardDifferentiable] get { return 5.0f * i.toInt(); }
}
}
[Differentiable]
float test(ITest arg)
{
return arg[1];
}
//TEST_INPUT:set output = out ubuffer(data=[0 0 0 0], stride=4)
RWStructuredBuffer<float> output;
[numthreads(1,1,1)]
void computeMain()
{
Test t = {};
output[0] = test(t);
// CHK: 5.0
}

0 comments on commit ef5b233

Please sign in to comment.