Skip to content

Commit 31548e1

Browse files
committed
Hack unit test to see the generated CUDA code
1 parent 182e5c6 commit 31548e1

File tree

1 file changed

+24
-8
lines changed

1 file changed

+24
-8
lines changed

tools/slang-unit-test/unit-test-find-check-entrypoint.cpp

+24-8
Original file line numberDiff line numberDiff line change
@@ -18,23 +18,37 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
1818
{
1919
// Source for a module that contains an undecorated entrypoint.
2020
const char* userSourceBody = R"(
21-
float4 fragMain(float4 pos:SV_Position) : SV_Target
22-
{
23-
return pos;
24-
}
21+
RWStructuredBuffer<float> outputBuffer;
22+
23+
[numthreads(4, 1, 1)]
24+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
25+
{
26+
outputBuffer[dispatchThreadID.x] = float(dispatchThreadID.x);
27+
}
2528
)";
2629

2730
auto moduleName = "moduleG" + String(Process::getId());
2831
String userSource = "import " + moduleName + ";\n" + userSourceBody;
2932
ComPtr<slang::IGlobalSession> globalSession;
3033
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
3134
slang::TargetDesc targetDesc = {};
32-
targetDesc.format = SLANG_CUDA_SOURCE;
33-
targetDesc.profile = globalSession->findProfile("spirv_1_5");
35+
targetDesc.format = SLANG_PTX;
3436
slang::SessionDesc sessionDesc = {};
3537
sessionDesc.targetCount = 1;
3638
sessionDesc.targets = &targetDesc;
3739
ComPtr<slang::ISession> session;
40+
List<slang::CompilerOptionEntry> sessionOptionEntries;
41+
{
42+
slang::CompilerOptionEntry entry = {};
43+
entry.name = slang::CompilerOptionName::DumpIntermediates;
44+
entry.value.kind = slang::CompilerOptionValueKind::Int;
45+
entry.value.intValue0 = int(false);
46+
sessionOptionEntries.add(entry);
47+
}
48+
sessionDesc.compilerOptionEntries = sessionOptionEntries.getBuffer();
49+
sessionDesc.compilerOptionEntryCount= sessionOptionEntries.getCount();
50+
targetDesc.compilerOptionEntries = sessionOptionEntries.getBuffer();
51+
targetDesc.compilerOptionEntryCount = sessionOptionEntries.getCount();
3852
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
3953

4054
ComPtr<slang::IBlob> diagnosticBlob;
@@ -47,7 +61,7 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
4761

4862
ComPtr<slang::IEntryPoint> entryPoint;
4963
module->findAndCheckEntryPoint(
50-
"fragMain",
64+
"computeMain",
5165
SLANG_STAGE_COMPUTE,
5266
entryPoint.writeRef(),
5367
diagnosticBlob.writeRef());
@@ -67,7 +81,9 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
6781
SLANG_CHECK(linkedProgram != nullptr);
6882

6983
ComPtr<slang::IBlob> code;
70-
linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
84+
auto res = linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
85+
if (res != SLANG_OK)
86+
std::cout << "diagnostic: " << (char*)diagnosticBlob->getBufferPointer() << std::endl;
7187
SLANG_CHECK(code != nullptr);
7288
SLANG_CHECK(code->getBufferSize() != 0);
7389

0 commit comments

Comments
 (0)