Skip to content

Commit cd27fbd

Browse files
csyonghealeino-nv
andauthored
Add a unit test to cover type-conformance compilation API. (shader-slang#6178)
Co-authored-by: Anders Leino <aleino@nvidia.com>
1 parent 31bb5ea commit cd27fbd

File tree

2 files changed

+173
-0
lines changed

2 files changed

+173
-0
lines changed

source/slang/slang.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -1337,6 +1337,7 @@ void Linkage::addTarget(slang::TargetDesc const& desc)
13371337
optionSet.setProfile(Profile(desc.profile));
13381338
optionSet.set(CompilerOptionName::LineDirectiveMode, LineDirectiveMode(desc.lineDirectiveMode));
13391339
optionSet.set(CompilerOptionName::GLSLForceScalarLayout, desc.forceGLSLScalarBufferLayout);
1340+
optionSet.load(desc.compilerOptionEntryCount, desc.compilerOptionEntries);
13401341
}
13411342

13421343
#if 0
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,172 @@
1+
// unit-test-type-conformance.cpp
2+
3+
#include "../../source/core/slang-io.h"
4+
#include "../../source/core/slang-process.h"
5+
#include "slang-com-ptr.h"
6+
#include "slang.h"
7+
#include "unit-test/slang-unit-test.h"
8+
9+
#include <stdio.h>
10+
#include <stdlib.h>
11+
12+
using namespace Slang;
13+
14+
// Test the compilation API for adding type conformances.
15+
16+
SLANG_UNIT_TEST(typeConformance)
17+
{
18+
const char* userSourceBody = R"(
19+
struct SurfaceInteraction {
20+
};
21+
22+
__generic<T>
23+
struct InterfacePtr {
24+
T *dptr;
25+
};
26+
27+
struct BsdfSample {
28+
float3 wo;
29+
float pdf;
30+
bool delta;
31+
float3 spectrum;
32+
};
33+
interface IBsdf {
34+
35+
BsdfSample sample(SurfaceInteraction si, float2 uv);
36+
};
37+
struct Diffuse : IBsdf {
38+
float3 _reflectance;
39+
40+
BsdfSample sample(SurfaceInteraction si, float2 uv) {
41+
BsdfSample sample;
42+
sample.wo = float3(uv, 1.0f);
43+
sample.pdf = uv.x;
44+
sample.delta = false;
45+
sample.spectrum = _reflectance;
46+
return sample;
47+
}
48+
};
49+
50+
interface IShape {
51+
property InterfacePtr<IBsdf> bsdf;
52+
};
53+
struct Mesh : IShape {
54+
InterfacePtr<IBsdf> bsdf;
55+
};
56+
struct Sphere : IShape {
57+
InterfacePtr<IBsdf> bsdf;
58+
};
59+
60+
[[vk::push_constant]] IShape *shapes;
61+
struct Path {
62+
float3 sample(IShape *shapes) {
63+
float3 spectrum = { 0.0f, 0.0f, 0.0f };
64+
float3 throughput = { 1.0f, 1.0f, 1.0f };
65+
66+
while (true) {
67+
SurfaceInteraction si = {};
68+
69+
if (true) {
70+
const float p = min(max(throughput.r, max(throughput.g, throughput.b)), 0.95f);
71+
if (1.0f >= p) return spectrum;
72+
}
73+
74+
BsdfSample sample = shapes[0].bsdf.dptr.sample(si, float2(1.0f));
75+
throughput *= sample.spectrum;
76+
}
77+
return spectrum;
78+
}
79+
};
80+
81+
[[vk::binding(0, 0)]] RWTexture2D<float4> output;
82+
83+
[shader("compute"), numthreads(1, 1, 1)]
84+
void computeMain() {
85+
Path path = Path();
86+
float3 spectrum = path.sample(nullptr);
87+
output[uint2(0,0)] += float4(spectrum, 1.0f);
88+
}
89+
)";
90+
ComPtr<slang::IGlobalSession> globalSession;
91+
SlangGlobalSessionDesc globalDesc = {};
92+
globalDesc.enableGLSL = true;
93+
SLANG_CHECK(slang_createGlobalSession2(&globalDesc, globalSession.writeRef()) == SLANG_OK);
94+
slang::TargetDesc targetDesc = {};
95+
targetDesc.format = SLANG_SPIRV;
96+
targetDesc.profile = globalSession->findProfile("spirv_1_5");
97+
targetDesc.compilerOptionEntryCount = 1;
98+
slang::CompilerOptionEntry entry;
99+
entry.name = slang::CompilerOptionName::Optimization;
100+
entry.value.kind = slang::CompilerOptionValueKind::Int;
101+
entry.value.intValue0 = 0;
102+
targetDesc.compilerOptionEntries = &entry;
103+
104+
slang::SessionDesc sessionDesc = {};
105+
sessionDesc.targetCount = 1;
106+
sessionDesc.targets = &targetDesc;
107+
108+
ComPtr<slang::ISession> session;
109+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
110+
111+
ComPtr<slang::IBlob> diagnosticBlob;
112+
auto module = session->loadModuleFromSourceString(
113+
"m",
114+
"m.slang",
115+
userSourceBody,
116+
diagnosticBlob.writeRef());
117+
SLANG_CHECK(module != nullptr);
118+
119+
ComPtr<slang::IEntryPoint> entryPoint;
120+
module->findAndCheckEntryPoint(
121+
"computeMain",
122+
SLANG_STAGE_COMPUTE,
123+
entryPoint.writeRef(),
124+
diagnosticBlob.writeRef());
125+
126+
auto layout = module->getLayout();
127+
128+
auto diffuse = layout->findTypeByName("Diffuse");
129+
auto ibsdf = layout->findTypeByName("IBsdf");
130+
auto ishape = layout->findTypeByName("IShape");
131+
auto mesh = layout->findTypeByName("Mesh");
132+
auto sphere = layout->findTypeByName("Sphere");
133+
134+
ComPtr<slang::ITypeConformance> diffuseIBsdf;
135+
ComPtr<slang::ITypeConformance> meshIShape;
136+
ComPtr<slang::ITypeConformance> sphereIShape;
137+
session->createTypeConformanceComponentType(
138+
diffuse,
139+
ibsdf,
140+
diffuseIBsdf.writeRef(),
141+
0,
142+
diagnosticBlob.writeRef());
143+
session->createTypeConformanceComponentType(
144+
mesh,
145+
ishape,
146+
meshIShape.writeRef(),
147+
0,
148+
diagnosticBlob.writeRef());
149+
session->createTypeConformanceComponentType(
150+
sphere,
151+
ishape,
152+
sphereIShape.writeRef(),
153+
0,
154+
diagnosticBlob.writeRef());
155+
156+
slang::IComponentType* componentTypes[5] =
157+
{module, entryPoint.get(), diffuseIBsdf, meshIShape, sphereIShape};
158+
ComPtr<slang::IComponentType> composedProgram;
159+
session->createCompositeComponentType(
160+
componentTypes,
161+
5,
162+
composedProgram.writeRef(),
163+
diagnosticBlob.writeRef());
164+
165+
ComPtr<slang::IComponentType> linkedProgram;
166+
composedProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
167+
168+
ComPtr<slang::IBlob> code;
169+
linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef());
170+
171+
SLANG_CHECK(code != nullptr);
172+
}

0 commit comments

Comments
 (0)