Skip to content

Commit cc1b96d

Browse files
authored
Check mismatching method parameter direction against interface declaration. (shader-slang#5964)
1 parent 88e221b commit cc1b96d

11 files changed

+244
-11
lines changed

source/slang/slang-ast-support-types.cpp

+25
Original file line numberDiff line numberDiff line change
@@ -70,4 +70,29 @@ UnownedStringSlice getHigherOrderOperatorName(HigherOrderInvokeExpr* expr)
7070
return UnownedStringSlice();
7171
}
7272

73+
void printDiagnosticArg(StringBuilder& sb, ParameterDirection direction)
74+
{
75+
switch (direction)
76+
{
77+
case kParameterDirection_In:
78+
sb << "in";
79+
break;
80+
case kParameterDirection_Out:
81+
sb << "out";
82+
break;
83+
case kParameterDirection_Ref:
84+
sb << "ref";
85+
break;
86+
case kParameterDirection_InOut:
87+
sb << "inout";
88+
break;
89+
case kParameterDirection_ConstRef:
90+
sb << "constref";
91+
break;
92+
default:
93+
sb << "(" << int(direction) << ")";
94+
break;
95+
}
96+
}
97+
7398
} // namespace Slang

source/slang/slang-ast-support-types.h

+2
Original file line numberDiff line numberDiff line change
@@ -1634,6 +1634,8 @@ enum ParameterDirection
16341634
kParameterDirection_ConstRef, ///< By-const-reference
16351635
};
16361636

1637+
void printDiagnosticArg(StringBuilder& sb, ParameterDirection direction);
1638+
16371639
/// The kind of a builtin interface requirement that can be automatically synthesized.
16381640
enum class BuiltinRequirementKind
16391641
{

source/slang/slang-capability.cpp

-1
Original file line numberDiff line numberDiff line change
@@ -1165,7 +1165,6 @@ void printDiagnosticArg(StringBuilder& sb, List<CapabilityAtom>& list)
11651165
printDiagnosticArg(sb, set.newSetWithoutImpliedAtoms());
11661166
}
11671167

1168-
11691168
#ifdef UNIT_TEST_CAPABILITIES
11701169

11711170
#define CHECK_CAPS(inData) SLANG_ASSERT(inData > 0)

source/slang/slang-check-decl.cpp

+95-6
Original file line numberDiff line numberDiff line change
@@ -3412,7 +3412,9 @@ bool SemanticsVisitor::doesSignatureMatchRequirement(
34123412
{
34133413
auto requiredParam = requiredParams[paramIndex];
34143414
auto satisfyingParam = satisfyingParams[paramIndex];
3415-
3415+
if (getParameterDirection(requiredParam.getDecl()) !=
3416+
getParameterDirection(satisfyingParam.getDecl()))
3417+
return false;
34163418
auto requiredParamType = getType(m_astBuilder, requiredParam);
34173419
auto satisfyingParamType = getType(m_astBuilder, satisfyingParam);
34183420

@@ -4394,14 +4396,14 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl(
43944396
//
43954397
for (auto paramDeclRef : getParameters(m_astBuilder, requirement))
43964398
{
4397-
auto paramType = getType(m_astBuilder, paramDeclRef);
4399+
auto paramType = QualType(getType(m_astBuilder, paramDeclRef));
43984400

43994401
// For each parameter of the requirement, we create a matching
44004402
// parameter (same name and type) for the synthesized method.
44014403
//
44024404
auto synParamDecl = m_astBuilder->create<ParamDecl>();
44034405
synParamDecl->nameAndLoc = paramDeclRef.getDecl()->nameAndLoc;
4404-
synParamDecl->type.type = paramType;
4406+
synParamDecl->type.type = paramType.type;
44054407

44064408
// We need to add the parameter as a child declaration of
44074409
// the method we are building.
@@ -4410,6 +4412,7 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl(
44104412
synthesized->members.add(synParamDecl);
44114413

44124414
// Add modifiers
4415+
paramType.isLeftValue = true;
44134416
for (auto modifier : paramDeclRef.getDecl()->modifiers)
44144417
{
44154418
if (as<NoDiffModifier>(modifier))
@@ -4426,6 +4429,8 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl(
44264429
(Modifier*)m_astBuilder->createByNodeType(modifier->astNodeType);
44274430
clonedModifier->keywordName = modifier->keywordName;
44284431
addModifier(synParamDecl, clonedModifier);
4432+
if (as<ConstRefModifier>(modifier))
4433+
paramType.isLeftValue = false;
44294434
}
44304435
}
44314436

@@ -4445,6 +4450,7 @@ void SemanticsVisitor::addRequiredParamsToSynthesizedDecl(
44454450
synMemberExpr->base = synArg;
44464451
synMemberExpr->elementIndices.add((uint32_t)i);
44474452
synMemberExpr->type = elementType;
4453+
synMemberExpr->type.isLeftValue = paramType.isLeftValue;
44484454
synArgs.add(synMemberExpr);
44494455
}
44504456
}
@@ -4572,6 +4578,22 @@ static bool isWrapperTypeDecl(Decl* decl)
45724578
return false;
45734579
}
45744580

4581+
// Is it allowed to have an interface method parameter whose direction is `reqDir`, and an
4582+
// implementing method parameter whose direction is `implDir`?
4583+
//
4584+
static bool matchParamDirection(ParameterDirection implDir, ParameterDirection reqDir)
4585+
{
4586+
// If the parameter directions match exactly, then we are good.
4587+
if (implDir == reqDir)
4588+
return true;
4589+
// Otherwise, we only allow the cases where reqDir is `InOut` and implDir is `In` or `Out`.
4590+
if (implDir == kParameterDirection_In && reqDir == kParameterDirection_InOut)
4591+
return true;
4592+
if (implDir == kParameterDirection_Out && reqDir == kParameterDirection_InOut)
4593+
return true;
4594+
return false;
4595+
}
4596+
45754597
bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
45764598
ConformanceCheckingContext* context,
45774599
LookupResult const& lookupResult,
@@ -4749,7 +4771,13 @@ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
47494771
// If checking the generic app failed, we can't synthesize the witness.
47504772
//
47514773
if (tempSink.getErrorCount() != 0)
4774+
{
4775+
context->innerSink.diagnose(
4776+
SourceLoc(),
4777+
Diagnostics::genericSignatureDoesNotMatchRequirement,
4778+
baseOverloadedExpr->name);
47524779
return false;
4780+
}
47534781
}
47544782

47554783
// We now have the reference to the overload group we plan to call,
@@ -4791,11 +4819,67 @@ bool SemanticsVisitor::trySynthesizeMethodRequirementWitness(
47914819
// diagnose a generic "failed to satisfying requirement" error.
47924820
//
47934821
if (tempSink.getErrorCount() != 0)
4822+
{
4823+
context->innerSink.diagnose(
4824+
SourceLoc(),
4825+
Diagnostics::cannotResolveOverloadForMethodRequirement,
4826+
baseOverloadedExpr->name);
47944827
return false;
4828+
}
47954829

4796-
// If we were able to type-check the call, then we should
4797-
// be able to finish construction of a suitable witness.
4830+
// If we were able to type-check the call, we also need to make
4831+
// sure that the resolved callee member has consistent parameter
4832+
// direction as the requirement method.
4833+
//
4834+
// For example, if there is a requirement:
4835+
// ```
4836+
// interface IFoo { void method(out int x); }
4837+
// ```
4838+
// and a type:
4839+
// ```
4840+
// struct X : IFoo { void method(int x) { ... } }
4841+
// ```
4842+
// After we synthesize:
4843+
// ```
4844+
// void X::synthesized_method(out int x) { this.method(x); }
4845+
// ```
4846+
// The synthesized method will pass all type check just fine,
4847+
// but we don't want to allow this method to be used as a witness
4848+
// for the requirement due to inconsistent parameter direction.
4849+
// So let's check for this now.
47984850
//
4851+
if (auto checkedInvoke = as<InvokeExpr>(checkedCall))
4852+
{
4853+
if (auto declRefExpr = as<DeclRefExpr>(checkedInvoke->functionExpr))
4854+
{
4855+
if (auto callee = as<CallableDecl>(declRefExpr->declRef))
4856+
{
4857+
auto synParams = synFuncDecl->getParameters();
4858+
auto calleeParams = callee.getDecl()->getParameters();
4859+
auto synParamIter = synParams.begin();
4860+
auto calleeParamIter = calleeParams.begin();
4861+
for (; synParamIter != synParams.end() && calleeParamIter != calleeParams.end();
4862+
++synParamIter, ++calleeParamIter)
4863+
{
4864+
auto synParam = *synParamIter;
4865+
auto calleeParam = *calleeParamIter;
4866+
if (!matchParamDirection(
4867+
getParameterDirection(calleeParam),
4868+
getParameterDirection(synParam)))
4869+
{
4870+
context->innerSink.diagnose(
4871+
calleeParam,
4872+
Diagnostics::parameterDirectionDoesNotMatchRequirement,
4873+
calleeParam,
4874+
getParameterDirection(calleeParam),
4875+
getParameterDirection(synParam));
4876+
return false;
4877+
}
4878+
}
4879+
}
4880+
}
4881+
}
4882+
47994883
// We've already created the outer declaration (including its
48004884
// parameters), and the inner expression, so the main work
48014885
// that is left is defining the body of the new function,
@@ -5293,7 +5377,7 @@ bool SemanticsVisitor::trySynthesizeWrapperTypePropertyRequirementWitness(
52935377
base->name = getName("inner");
52945378
propertyRef->baseExpression = base;
52955379
innerProperty = innerAccessorDeclRef.getParent();
5296-
propertyRef->name = getParentDecl(innerAccessorDeclRef.getDecl())->getName();
5380+
propertyRef->name = requiredMemberDeclRef.getName();
52975381
auto checkedPropertyRefExpr = CheckExpr(propertyRef);
52985382

52995383
if (as<GetterDecl>(requiredAccessorDeclRef))
@@ -6608,6 +6692,7 @@ bool SemanticsVisitor::findWitnessForInterfaceRequirement(
66086692
// a wrapper type (struct Foo:IFoo=FooImpl), and we will synthesize
66096693
// wrappers that redirects the call into the inner element.
66106694
//
6695+
context->innerSink.reset();
66116696
if (trySynthesizeRequirementWitness(context, lookupResult, requiredMemberDeclRef, witnessTable))
66126697
{
66136698
return true;
@@ -6638,6 +6723,10 @@ bool SemanticsVisitor::findWitnessForInterfaceRequirement(
66386723
subType,
66396724
requiredMemberDeclRef);
66406725
}
6726+
if (context->innerSink.outputBuffer.getLength())
6727+
{
6728+
getSink()->diagnoseRaw(Severity::Note, context->innerSink.outputBuffer.getUnownedSlice());
6729+
}
66416730
getSink()->diagnose(
66426731
requiredMemberDeclRef,
66436732
Diagnostics::seeDeclarationOfInterfaceRequirement,

source/slang/slang-check-impl.h

+9
Original file line numberDiff line numberDiff line change
@@ -1143,6 +1143,12 @@ struct OuterScopeContextRAII
11431143
context, \
11441144
decl->ownedScope ? decl->ownedScope : context->getOuterScope())
11451145

1146+
struct RequirementSynthesisResult
1147+
{
1148+
bool suceeded = false;
1149+
operator bool() const { return suceeded; }
1150+
};
1151+
11461152
struct SemanticsVisitor : public SemanticsContext
11471153
{
11481154
typedef SemanticsContext Super;
@@ -1742,6 +1748,9 @@ struct SemanticsVisitor : public SemanticsContext
17421748
/// declaration)
17431749
ContainerDecl* parentDecl;
17441750

1751+
// An inner diagnostic sink to store diagnostics about why requirement synthesis failed.
1752+
DiagnosticSink innerSink;
1753+
17451754
Dictionary<DeclRef<InterfaceDecl>, RefPtr<WitnessTable>> mapInterfaceToWitnessTable;
17461755
};
17471756

source/slang/slang-diagnostic-defs.h

+19
Original file line numberDiff line numberDiff line change
@@ -70,6 +70,25 @@ DIAGNOSTIC(
7070
Note,
7171
seeDeclarationOfInterfaceRequirement,
7272
"see interface requirement declaration of '$0'")
73+
74+
DIAGNOSTIC(
75+
-1,
76+
Note,
77+
genericSignatureDoesNotMatchRequirement,
78+
"generic signature of '$0' does not match interface requirement.")
79+
80+
DIAGNOSTIC(
81+
-1,
82+
Note,
83+
cannotResolveOverloadForMethodRequirement,
84+
"none of the overloads of '$0' match the interface requirement.")
85+
86+
DIAGNOSTIC(
87+
-1,
88+
Note,
89+
parameterDirectionDoesNotMatchRequirement,
90+
"parameter '$0' is '$1' in the implementing member, but the interface requires '$2'.")
91+
7392
// An alternate wording of the above note, emphasing the position rather than content of the
7493
// declaration.
7594
DIAGNOSTIC(-1, Note, declaredHere, "declared here")

tests/diagnostics/interfaces/mutating-impl-of-non-mutating-req.slang

+3-3
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
// mutating-impl-of-non-mutating-req.slang
22

3-
//DIAGNOSTIC_TEST:SIMPLE:-target hlsl -entry main
4-
3+
//DIAGNOSTIC_TEST:SIMPLE(filecheck=CHECK):-target hlsl -entry main
54
interface IThing
65
{
76
int processValue(int inValue);
@@ -10,7 +9,8 @@ interface IThing
109
struct Counter : IThing
1110
{
1211
int state;
13-
12+
13+
// CHECK: ([[# @LINE+1]]): error 38105:
1414
[mutating] int processValue(int inValue)
1515
{
1616
int result = state;
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//TEST:SIMPLE(filecheck=CHECK): -target spirv
2+
3+
public interface ITest {
4+
public void testIn(int a);
5+
public void testOut(out int b);
6+
};
7+
8+
public struct TestImpl : ITest {
9+
// CHECK: ([[# @LINE + 1]]): error 38105
10+
public void testIn(out int a) {
11+
a = 5;
12+
}
13+
// CHECK: ([[# @LINE + 1]]): error 38105
14+
public void testOut(int b) {
15+
b = 6;
16+
}
17+
}
18+
19+
RWStructuredBuffer<int> output;
20+
21+
void doSomething<T>(T data) where T : ITest {
22+
int a = 516;
23+
data.testIn(a);
24+
int b = 687;
25+
data.testOut(b);
26+
27+
output[0] = a;
28+
output[1] = b;
29+
}
30+
31+
[shader("compute")]
32+
[numthreads(1,1,1)]
33+
void computeMain()
34+
{
35+
TestImpl data;
36+
doSomething(data);
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type
2+
3+
public interface ITest {
4+
public int testDir(inout int a);
5+
};
6+
7+
public struct TestImpl : ITest {
8+
public int testDir(inout int a) {
9+
int oldA = a;
10+
a = 5;
11+
return a;
12+
}
13+
}
14+
15+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=output
16+
RWStructuredBuffer<int> output;
17+
18+
public struct Test : ITest = TestImpl;
19+
20+
[shader("compute")]
21+
[numthreads(1,1,1)]
22+
void computeMain()
23+
{
24+
Test data;
25+
int a = 516;
26+
int b = data.testDir(a);
27+
// CHECK: 5
28+
output[0] = b;
29+
}

0 commit comments

Comments
 (0)