Skip to content

Commit d6130ba

Browse files
committed
fixup global generic parameters
1. simplify RoundUpToAlignment() 2. add new a render-compute test case to cover the situation where the entry-point interface (parameter/return types of an entry-point function) is dependent on the global generic type. 3. initial fixes to get this test case to compile (but is not producing correct HLSL output yet)
1 parent 3dff5a5 commit d6130ba

9 files changed

+172
-25
lines changed

source/slang/ir.cpp

+22-1
Original file line numberDiff line numberDiff line change
@@ -3103,10 +3103,22 @@ namespace Slang
31033103
IRGlobalVar* cloneGlobalVar(IRSpecContext* context, IRGlobalVar* originalVar);
31043104
IRFunc* cloneFunc(IRSpecContext* context, IRFunc* originalFunc);
31053105
IRWitnessTable* cloneWitnessTable(IRSpecContext* context, IRWitnessTable* originalVar);
3106+
RefPtr<Substitutions> cloneSubstitutions(
3107+
IRSpecContext* context,
3108+
Substitutions* subst);
31063109

31073110
RefPtr<Type> IRSpecContext::maybeCloneType(Type* originalType)
31083111
{
3109-
return originalType->Substitute(subst).As<Type>();
3112+
auto rsType = originalType->Substitute(subst).As<Type>();
3113+
if (auto declRefType = rsType.As<DeclRefType>())
3114+
{
3115+
if (subst)
3116+
{
3117+
auto newSubst = cloneSubstitutions(this, subst);
3118+
insertSubstAtBottom(declRefType->declRef.substitutions, newSubst);
3119+
}
3120+
}
3121+
return rsType;
31103122
}
31113123

31123124
IRValue* IRSpecContext::maybeCloneValue(IRValue* originalValue)
@@ -3243,6 +3255,15 @@ namespace Slang
32433255
newSubst->outer = cloneSubstitutions(context, subst->outer);
32443256
return newSubst;
32453257
}
3258+
else if (auto genTypeSubst = dynamic_cast<GlobalGenericParamSubstitution*>(subst))
3259+
{
3260+
RefPtr<GlobalGenericParamSubstitution> newSubst = new GlobalGenericParamSubstitution();
3261+
newSubst->actualType = genTypeSubst->actualType;
3262+
newSubst->paramDecl = genTypeSubst->paramDecl;
3263+
newSubst->witnessTables = genTypeSubst->witnessTables;
3264+
newSubst->outer = cloneSubstitutions(context, subst->outer);
3265+
return newSubst;
3266+
}
32463267
else
32473268
SLANG_UNREACHABLE("unimplemented cloneSubstitution");
32483269
UNREACHABLE_RETURN(nullptr);

source/slang/parameter-binding.cpp

+26-12
Original file line numberDiff line numberDiff line change
@@ -1425,6 +1425,16 @@ static RefPtr<TypeLayout> processEntryPointParameter(
14251425

14261426
return structLayout;
14271427
}
1428+
else if (auto globalGenericParam = declRef.As<GlobalGenericParamDecl>())
1429+
{
1430+
auto genParamTypeLayout = new GenericParamTypeLayout();
1431+
// we should have already populated ProgramLayout::genericEntryPointParams list at this point,
1432+
// so we can find the index of this generic param decl in the list
1433+
genParamTypeLayout->type = type;
1434+
genParamTypeLayout->paramIndex = findGenericParam(context->shared->programLayout->globalGenericParams, globalGenericParam.getDecl());
1435+
genParamTypeLayout->findOrAddResourceInfo(LayoutResourceKind::GenericResource)->count++;
1436+
return genParamTypeLayout;
1437+
}
14281438
else
14291439
{
14301440
SLANG_UNEXPECTED("unhandled type kind");
@@ -1442,7 +1452,8 @@ static RefPtr<TypeLayout> processEntryPointParameter(
14421452

14431453
static void collectEntryPointParameters(
14441454
ParameterBindingContext* context,
1445-
EntryPointRequest* entryPoint)
1455+
EntryPointRequest* entryPoint,
1456+
Substitutions* typeSubst)
14461457
{
14471458
FuncDecl* entryPointFuncDecl = entryPoint->decl;
14481459
if (!entryPointFuncDecl)
@@ -1507,7 +1518,7 @@ static void collectEntryPointParameters(
15071518
auto paramTypeLayout = processEntryPointParameterWithPossibleSemantic(
15081519
context,
15091520
paramDecl.Ptr(),
1510-
paramDecl->type.type,
1521+
paramDecl->type.type->Substitute(typeSubst).As<Type>(),
15111522
state,
15121523
paramVarLayout);
15131524

@@ -1539,7 +1550,7 @@ static void collectEntryPointParameters(
15391550
auto resultTypeLayout = processEntryPointParameterWithPossibleSemantic(
15401551
context,
15411552
entryPointFuncDecl,
1542-
resultType,
1553+
resultType->Substitute(typeSubst).As<Type>(),
15431554
state,
15441555
resultLayout);
15451556

@@ -1632,7 +1643,7 @@ static void collectParameters(
16321643
for( auto& entryPoint : translationUnit->entryPoints )
16331644
{
16341645
context->stage = entryPoint->profile.GetStage();
1635-
collectEntryPointParameters(context, entryPoint.Ptr());
1646+
collectEntryPointParameters(context, entryPoint.Ptr(), nullptr);
16361647
}
16371648
}
16381649

@@ -1891,13 +1902,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
18911902
newProgramLayout = new ProgramLayout();
18921903
newProgramLayout->bindingForHackSampler = programLayout->bindingForHackSampler;
18931904
newProgramLayout->hackSamplerVar = programLayout->hackSamplerVar;
1894-
for (auto & entryPoint : programLayout->entryPoints)
1895-
{
1896-
RefPtr<EntryPointLayout> newEntryPoint = new EntryPointLayout(*entryPoint);
1897-
// TODO: for now just copy existing entry point layouts, but we eventually need to
1898-
// specialize these as well...
1899-
newProgramLayout->entryPoints.Add(newEntryPoint);
1900-
}
1905+
newProgramLayout->globalGenericParams = programLayout->globalGenericParams;
19011906

19021907
List<RefPtr<TypeLayout>> paramTypeLayouts;
19031908
auto globalStructLayout = getGlobalStructLayout(programLayout);
@@ -1919,7 +1924,7 @@ RefPtr<ProgramLayout> specializeProgramLayout(
19191924
SharedParameterBindingContext sharedContext;
19201925
sharedContext.compileRequest = targetReq->compileRequest;
19211926
sharedContext.defaultLayoutRules = layoutContext.getRulesFamily();
1922-
sharedContext.programLayout = programLayout;
1927+
sharedContext.programLayout = newProgramLayout;
19231928

19241929
// Create a sub-context to collect parameters that get
19251930
// declared into the global scope
@@ -1928,6 +1933,15 @@ RefPtr<ProgramLayout> specializeProgramLayout(
19281933
context.translationUnit = nullptr;
19291934
context.layoutContext = layoutContext;
19301935

1936+
1937+
for (auto & translationUnit : targetReq->compileRequest->translationUnits)
1938+
{
1939+
for (auto & entryPoint : translationUnit->entryPoints)
1940+
{
1941+
collectEntryPointParameters(&context, entryPoint, typeSubst);
1942+
}
1943+
}
1944+
19311945
auto constantBufferRules = context.getRulesFamily()->getConstantBufferRules();
19321946
structLayout->rules = constantBufferRules;
19331947

source/slang/syntax.cpp

+16-1
Original file line numberDiff line numberDiff line change
@@ -1709,7 +1709,22 @@ void Type::accept(IValVisitor* visitor, void* extra)
17091709
return sb.ProduceString();
17101710
}
17111711

1712-
1712+
void insertSubstAtBottom(RefPtr<Substitutions> & substHead, RefPtr<Substitutions> substToInsert)
1713+
{
1714+
if (!substHead)
1715+
{
1716+
substHead = substToInsert;
1717+
return;
1718+
}
1719+
auto subst = substHead;
1720+
RefPtr<Substitutions> lastSubst = subst;
1721+
while (subst->outer)
1722+
{
1723+
lastSubst = subst;
1724+
subst = subst->outer;
1725+
}
1726+
lastSubst->outer = substToInsert;
1727+
}
17131728

17141729
void insertSubstAtTop(DeclRefBase & declRef, RefPtr<Substitutions> substToInsert)
17151730
{

source/slang/syntax.h

+1
Original file line numberDiff line numberDiff line change
@@ -1156,6 +1156,7 @@ namespace Slang
11561156
Session* session,
11571157
Decl* decl);
11581158

1159+
void insertSubstAtBottom(RefPtr<Substitutions> & substHead, RefPtr<Substitutions> substToInsert);
11591160
RefPtr<ThisTypeSubstitution> getNewThisTypeSubst(DeclRefBase & declRef);
11601161
RefPtr<ThisTypeSubstitution> getThisTypeSubst(DeclRefBase & declRef, bool insertSubstEntry);
11611162
void removeSubstitution(DeclRefBase & declRef, RefPtr<Substitutions> subst);

source/slang/type-layout.h

+1-1
Original file line numberDiff line numberDiff line change
@@ -676,7 +676,7 @@ createStructuredBufferTypeLayout(
676676
RefPtr<Type> structuredBufferType,
677677
RefPtr<Type> elementType);
678678

679-
679+
int findGenericParam(List<RefPtr<GenericParamLayout>> & genericParameters, GlobalGenericParamDecl * decl);
680680
//
681681

682682
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,95 @@
1+
//TEST(compute):COMPARE_RENDER_COMPUTE:-xslang -use-ir
2+
//TEST_INPUT: ubuffer(data=[0 0 0 0], stride=4):dxbinding(0),glbinding(0),out
3+
//TEST_INPUT: type VertImpl
4+
5+
interface IVertInterpolant
6+
{
7+
float4 getColor();
8+
}
9+
10+
__generic_param TVertInterpolant : IVertInterpolant;
11+
12+
struct VertImpl : IVertInterpolant
13+
{
14+
float3 color;
15+
float4 getColor()
16+
{
17+
return float4(1.0);
18+
}
19+
};
20+
21+
RWStructuredBuffer<float> outputBuffer;
22+
23+
cbuffer Uniforms
24+
{
25+
float4x4 modelViewProjection;
26+
}
27+
28+
struct AssembledVertex
29+
{
30+
float3 position;
31+
TVertInterpolant interpolants;
32+
float2 uv;
33+
};
34+
35+
struct CoarseVertex
36+
{
37+
TVertInterpolant interpolants;
38+
float2 uv;
39+
};
40+
41+
struct Fragment
42+
{
43+
float4 color;
44+
};
45+
46+
47+
// Vertex Shader
48+
49+
struct VertexStageInput
50+
{
51+
AssembledVertex assembledVertex : A;
52+
};
53+
54+
struct VertexStageOutput
55+
{
56+
CoarseVertex coarseVertex : CoarseVertex;
57+
float4 sv_position : SV_Position;
58+
};
59+
60+
VertexStageOutput vertexMain(VertexStageInput input)
61+
{
62+
VertexStageOutput output;
63+
64+
float3 position = input.assembledVertex.position;
65+
output.coarseVertex.interpolants = input.assembledVertex.interpolants;
66+
output.sv_position = mul(modelViewProjection, float4(position, 1.0));
67+
output.coarseVertex.uv = input.assembledVertex.uv;
68+
return output;
69+
}
70+
71+
// Fragment Shader
72+
73+
struct FragmentStageInput
74+
{
75+
CoarseVertex coarseVertex : CoarseVertex;
76+
};
77+
78+
struct FragmentStageOutput
79+
{
80+
Fragment fragment : SV_Target;
81+
};
82+
83+
FragmentStageOutput fragmentMain(FragmentStageInput input)
84+
{
85+
FragmentStageOutput output;
86+
87+
float4 color = input.coarseVertex.interpolants.getColor();
88+
float2 uv = input.coarseVertex.uv;
89+
output.fragment.color = color;
90+
outputBuffer[0] = color.x;
91+
outputBuffer[1] = color.y;
92+
outputBuffer[2] = color.z;
93+
outputBuffer[3] = color.w;
94+
return output;
95+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
3F800000
2+
3F800000
3+
3F800000
4+
3F800000

tools/render-test/render-d3d11.cpp

+1-4
Original file line numberDiff line numberDiff line change
@@ -457,10 +457,7 @@ class D3D11Renderer : public Renderer, public ShaderCompiler
457457

458458
UInt RoundUpToAlignment(UInt size, UInt alignment)
459459
{
460-
if (size % alignment)
461-
return (size / alignment + 1) * alignment;
462-
else
463-
return Math::Max(size, alignment);
460+
return ((size + alignment - 1) / alignment) * alignment;
464461
}
465462

466463
virtual Buffer* createBuffer(BufferDesc const& desc) override

tools/render-test/slang-support.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -82,12 +82,12 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
8282
spSetCompileFlags(slangRequest, SLANG_COMPILE_FLAG_NO_CHECKING);
8383
}
8484
ShaderProgram * result = nullptr;
85+
Slang::List<const char*> rawTypeNames;
86+
for (auto typeName : request.entryPointTypeArguments)
87+
rawTypeNames.Add(typeName.Buffer());
8588
if (request.computeShader.name)
8689
{
87-
Slang::List<const char*> rawTypeNames;
88-
for (auto typeName : request.entryPointTypeArguments)
89-
rawTypeNames.Add(typeName.Buffer());
90-
int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit,
90+
int computeEntryPoint = spAddEntryPointEx(slangRequest, computeTranslationUnit,
9191
computeEntryPointName,
9292
spFindProfile(slangSession, request.computeShader.profile),
9393
(int)rawTypeNames.Count(),
@@ -107,8 +107,8 @@ struct SlangShaderCompilerWrapper : public ShaderCompiler
107107
}
108108
else
109109
{
110-
int vertexEntryPoint = spAddEntryPoint(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile));
111-
int fragmentEntryPoint = spAddEntryPoint(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile));
110+
int vertexEntryPoint = spAddEntryPointEx(slangRequest, vertexTranslationUnit, vertexEntryPointName, spFindProfile(slangSession, request.vertexShader.profile), rawTypeNames.Count(), rawTypeNames.Buffer());
111+
int fragmentEntryPoint = spAddEntryPointEx(slangRequest, fragmentTranslationUnit, fragmentEntryPointName, spFindProfile(slangSession, request.fragmentShader.profile), rawTypeNames.Count(), rawTypeNames.Buffer());
112112

113113
int compileErr = spCompile(slangRequest);
114114
if (auto diagnostics = spGetDiagnosticOutput(slangRequest))

0 commit comments

Comments
 (0)