Skip to content

Commit 3d435f7

Browse files
csyongheTim Foley
authored and
Tim Foley
committed
Bug fixes for Slang integration (shader-slang#356)
* fix shader-slang#353 * move validateEntryPoint to after all entrypoints has been checked * bug fix: DeclRefType::SubstituteImpl should change ioDiff * bug fix: generic resource usage should have count of 1 instead of 0. * update test case
1 parent e90dfcf commit 3d435f7

9 files changed

+101
-23
lines changed

source/slang/check.cpp

+1-16
Original file line numberDiff line numberDiff line change
@@ -6830,7 +6830,7 @@ namespace Slang
68306830
entryPoint->genericParameterTypes.Count());
68316831
return;
68326832
}
6833-
// if number of entry-point type arguments matches parameters, try find
6833+
// if entry-point type arguments matches parameters, try find
68346834
// SubtypeWitness for each argument
68356835
int index = 0;
68366836
for (auto & gParam : globalGenericParams)
@@ -6871,21 +6871,6 @@ namespace Slang
68716871
// checking that is required on all declarations
68726872
// in the translation unit.
68736873
visitor.checkDecl(translationUnit->SyntaxNode);
6874-
6875-
// Next, do follow-up validation on any entry
6876-
// points that the user declared via API or
6877-
// command line, to ensure that they meet
6878-
// requirements.
6879-
//
6880-
// Note: We may eventually have syntax to
6881-
// identify entry points via a modifier on
6882-
// declarations, and in this case they should
6883-
// probably get validated as part of orindary
6884-
// checking above.
6885-
for (auto entryPoint : translationUnit->entryPoints)
6886-
{
6887-
validateEntryPoint(entryPoint);
6888-
}
68896874
}
68906875

68916876

source/slang/ir.cpp

+20-2
Original file line numberDiff line numberDiff line change
@@ -3095,6 +3095,12 @@ namespace Slang
30953095
{
30963096
return declRef;
30973097
}
3098+
3099+
// A callback used to clone (or not) a Val
3100+
virtual RefPtr<Val> maybeCloneVal(Val* val)
3101+
{
3102+
return val;
3103+
}
30983104
};
30993105

31003106
void registerClonedValue(
@@ -3203,6 +3209,7 @@ namespace Slang
32033209
virtual DeclRef<Decl> maybeCloneDeclRef(DeclRef<Decl> const& declRef) override;
32043210

32053211
virtual RefPtr<Type> maybeCloneType(Type* originalType) override;
3212+
virtual RefPtr<Val> maybeCloneVal(Val* val) override;
32063213
};
32073214

32083215

@@ -3216,6 +3223,11 @@ namespace Slang
32163223
return originalType->Substitute(subst).As<Type>();
32173224
}
32183225

3226+
RefPtr<Val> IRSpecContext::maybeCloneVal(Val * val)
3227+
{
3228+
return val->Substitute(subst);
3229+
}
3230+
32193231
IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
32203232
{
32213233
switch (originalValue->op)
@@ -3316,7 +3328,7 @@ namespace Slang
33163328
}
33173329
else
33183330
{
3319-
return val;
3331+
return context->maybeCloneVal(val);
33203332
}
33213333
}
33223334

@@ -3439,7 +3451,7 @@ namespace Slang
34393451
IRGlobalVar* originalVar,
34403452
IROriginalValuesForClone const& originalValues)
34413453
{
3442-
auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType()));
3454+
auto clonedVar = context->builder->createGlobalVar(context->maybeCloneType(originalVar->getType()->getValueType()));
34433455
registerClonedValue(context, clonedVar, originalValues);
34443456

34453457
auto mangledName = originalVar->mangledName;
@@ -4229,6 +4241,7 @@ namespace Slang
42294241
virtual IRValue* maybeCloneValue(IRValue* originalVal) override;
42304242

42314243
virtual RefPtr<Type> maybeCloneType(Type* originalType) override;
4244+
virtual RefPtr<Val> maybeCloneVal(Val* val) override;
42324245
};
42334246

42344247
// Convert a type-level value into an IR-level equivalent.
@@ -4352,6 +4365,11 @@ namespace Slang
43524365
return originalType->Substitute(subst).As<Type>();
43534366
}
43544367

4368+
RefPtr<Val> IRGenericSpecContext::maybeCloneVal(Val * val)
4369+
{
4370+
return val->Substitute(subst);
4371+
}
4372+
43554373
// Given a list of substitutions, return the inner-most
43564374
// generic substitution in the list, or NULL if there
43574375
// are no generic substitutions.

source/slang/parameter-binding.cpp

+3-4
Original file line numberDiff line numberDiff line change
@@ -1077,7 +1077,7 @@ static void completeBindingsForParameter(
10771077
else if (kind == LayoutResourceKind::GenericResource)
10781078
{
10791079
bindingInfo.space = 0;
1080-
bindingInfo.count = 0;
1080+
bindingInfo.count = 1;
10811081
bindingInfo.index = 0;
10821082
continue;
10831083
}
@@ -2118,8 +2118,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
21182118
for (auto & varLayout : globalStructLayout->fields)
21192119
{
21202120
// To recover layout context, we skip generic resources in the first pass
2121-
// If the var is a generic resource, its resourceInfos will be empty.
2122-
if (varLayout->resourceInfos.Count() == 0)
2121+
if (varLayout->FindResourceInfo(LayoutResourceKind::GenericResource))
21232122
continue;
21242123
SLANG_ASSERT(varLayout->resourceInfos.Count() == varLayout->typeLayout->resourceInfos.Count());
21252124
auto uniformInfo = varLayout->FindResourceInfo(LayoutResourceKind::Uniform);
@@ -2140,7 +2139,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
21402139
usedRangeSet->usedResourceRanges[(int)resInfo.kind].Add(
21412140
nullptr, // we don't need to track parameter info here
21422141
resInfo.index,
2143-
resInfo.index + varLayout->typeLayout->resourceInfos[0].count);
2142+
resInfo.index + tresInfo.count);
21442143
}
21452144
structLayout->fields.Add(varLayout);
21462145
varLayoutMapping[varLayout] = varLayout;

source/slang/slang.cpp

+20
Original file line numberDiff line numberDiff line change
@@ -174,6 +174,8 @@ void CompileRequest::parseTranslationUnit(
174174
}
175175
}
176176

177+
void validateEntryPoint(EntryPointRequest*);
178+
177179
void CompileRequest::checkAllTranslationUnits()
178180
{
179181
// Iterate over all translation units and
@@ -182,6 +184,24 @@ void CompileRequest::checkAllTranslationUnits()
182184
{
183185
checkTranslationUnit(translationUnit.Ptr());
184186
}
187+
188+
for (auto& translationUnit : translationUnits)
189+
{
190+
// Next, do follow-up validation on any entry
191+
// points that the user declared via API or
192+
// command line, to ensure that they meet
193+
// requirements.
194+
//
195+
// Note: We may eventually have syntax to
196+
// identify entry points via a modifier on
197+
// declarations, and in this case they should
198+
// probably get validated as part of orindary
199+
// checking above.
200+
for (auto entryPoint : translationUnit->entryPoints)
201+
{
202+
validateEntryPoint(entryPoint);
203+
}
204+
}
185205
}
186206

187207
void CompileRequest::generateIR()

source/slang/syntax.cpp

+17-1
Original file line numberDiff line numberDiff line change
@@ -358,6 +358,21 @@ void Type::accept(IValVisitor* visitor, void* extra)
358358
return (ArrayLength == arrType->ArrayLength && baseType->Equals(arrType->baseType.Ptr()));
359359
}
360360

361+
RefPtr<Val> ArrayExpressionType::SubstituteImpl(Substitutions* subst, int* ioDiff)
362+
{
363+
int diff = 0;
364+
auto elementType = baseType->SubstituteImpl(subst, &diff).As<Type>();
365+
if (diff)
366+
{
367+
*ioDiff = 1;
368+
auto rsType = getArrayType(
369+
elementType,
370+
ArrayLength);
371+
return rsType;
372+
}
373+
return this;
374+
}
375+
361376
Type* ArrayExpressionType::CreateCanonicalType()
362377
{
363378
auto canonicalElementType = baseType->GetCanonicalType();
@@ -531,6 +546,7 @@ void Type::accept(IValVisitor* visitor, void* extra)
531546
{
532547
if (genericSubst->paramDecl == globalGenParam)
533548
{
549+
(*ioDiff)++;
534550
return genericSubst->actualType;
535551
}
536552
}
@@ -1393,7 +1409,7 @@ void Type::accept(IValVisitor* visitor, void* extra)
13931409
{
13941410
int diff = 0;
13951411
RefPtr<Substitutions> substSubst = substituteSubstitutions(substitutions, subst, &diff);
1396-
1412+
13971413
if (!diff)
13981414
return *this;
13991415

source/slang/type-defs.h

+1
Original file line numberDiff line numberDiff line change
@@ -310,6 +310,7 @@ RAW(
310310
protected:
311311
virtual bool EqualsImpl(Type * type) override;
312312
virtual Type* CreateCanonicalType() override;
313+
virtual RefPtr<Val> SubstituteImpl(Substitutions* subst, int* ioDiff) override;
313314
virtual int GetHashCode() override;
314315
)
315316
END_SYNTAX_CLASS()
+37
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
//TEST(smoke,compute):COMPARE_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT: cbuffer(data=[1.0], stride=4):dxbinding(0),glbinding(0)
3+
//TEST_INPUT: ubuffer(data=[0], stride=4):dxbinding(0),glbinding(0),out
4+
//TEST_INPUT: type Impl
5+
6+
RWStructuredBuffer<float> outputBuffer;
7+
8+
interface IBase
9+
{
10+
float compute();
11+
}
12+
13+
struct Impl : IBase
14+
{
15+
float base; // = 1.0
16+
float compute()
17+
{
18+
return 1.0;
19+
}
20+
};
21+
22+
__generic_param TImpl : IBase;
23+
24+
ParameterBlock<TImpl> impl;
25+
26+
float doCompute<T:IBase>(T t)
27+
{
28+
return t.compute();
29+
}
30+
31+
[numthreads(1, 1, 1)]
32+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
33+
{
34+
uint tid = dispatchThreadID.x;
35+
float outVal = doCompute<TImpl>(impl);
36+
outputBuffer[tid] = outVal;
37+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3F800000

tools/render-test/test.txt

+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
3F800000

0 commit comments

Comments
 (0)