Skip to content

Commit 5248a02

Browse files
authored
Fix codegen bug when targeting PTX with new API (#6506)
* Add cuda codegen bug repro This just compiles tests/compute/simlpe.slang for PTX with the new compilation API, in order to reproduce a code generation bug. * Detect entrypoint more robustly when applying ConstRef hack during lowring For shaders like tests/compute/simple.slang, which have a 'numthreads' attribute but no 'shader' attribute, the old compile request API would add an EntryPointAttribute to the AST node of the entry point. However, the new API doesn't, and so a certain ConstRef hack doesn't get applied when using the new API, leading to subsequent code generation issues. This patch also checks for a 'numthreads' attribute when deciding whether to apply the ConstRef hack. This closes issue #6507 and helps to resolve issue #4760. * Add expected failure list for GitHub runners Our GitHub runners don't have the CUDA toolkits installed, so they can't run all tests.
1 parent 6f56b47 commit 5248a02

File tree

4 files changed

+72
-3
lines changed

4 files changed

+72
-3
lines changed

.github/workflows/ci.yml

+4-2
Original file line numberDiff line numberDiff line change
@@ -176,14 +176,16 @@ jobs:
176176
-category ${{ matrix.test-category }} \
177177
-api all-dx12 \
178178
-expected-failure-list tests/expected-failure-github.txt \
179-
-expected-failure-list tests/expected-failure-record-replay-tests.txt
179+
-expected-failure-list tests/expected-failure-record-replay-tests.txt \
180+
-expected-failure-list tests/expected-failure-github-runner.txt
180181
else
181182
"$bin_dir/slang-test" \
182183
-use-test-server \
183184
-category ${{ matrix.test-category }} \
184185
-api all-dx12 \
185186
-expected-failure-list tests/expected-failure-github.txt \
186-
-expected-failure-list tests/expected-failure-record-replay-tests.txt
187+
-expected-failure-list tests/expected-failure-record-replay-tests.txt \
188+
-expected-failure-list tests/expected-failure-github-runner.txt
187189
fi
188190
- name: Run Slang examples
189191
if: steps.filter.outputs.should-run == 'true' && matrix.platform != 'wasm' && matrix.full-gpu-tests

source/slang/slang-lower-to-ir.cpp

+2-1
Original file line numberDiff line numberDiff line change
@@ -3214,7 +3214,8 @@ void collectParameterLists(
32143214
// For now we will rely on a follow up pass to remove unnecessary temporary variables if
32153215
// we can determine that they are never actually writtten to by the user.
32163216
//
3217-
bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>();
3217+
bool lowerVaryingInputAsConstRef = declRef.getDecl()->hasModifier<EntryPointAttribute>() ||
3218+
declRef.getDecl()->hasModifier<NumThreadsAttribute>();
32183219

32193220
// Don't collect parameters from the outer scope if
32203221
// we are in a `static` context.
+1
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
slang-unit-test-tool/cudaCodeGenBug.internal

tools/slang-unit-test/unit-test-find-check-entrypoint.cpp

+65
Original file line numberDiff line numberDiff line change
@@ -71,3 +71,68 @@ SLANG_UNIT_TEST(findAndCheckEntryPoint)
7171
SLANG_CHECK(code != nullptr);
7272
SLANG_CHECK(code->getBufferSize() != 0);
7373
}
74+
75+
// This test reproduces issue #6507, where it was noticed that compilation of
76+
// tests/compute/simple.slang for PTX target generates invalid code.
77+
// TODO: Remove this when issue #4760 is resolved, because at that point
78+
// tests/compute/simple.slang should cover the same issue.
79+
SLANG_UNIT_TEST(cudaCodeGenBug)
80+
{
81+
// Source for a module that contains an undecorated entrypoint.
82+
const char* userSourceBody = R"(
83+
RWStructuredBuffer<float> outputBuffer;
84+
85+
[numthreads(4, 1, 1)]
86+
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
87+
{
88+
outputBuffer[dispatchThreadID.x] = float(dispatchThreadID.x);
89+
}
90+
)";
91+
92+
auto moduleName = "moduleG" + String(Process::getId());
93+
String userSource = "import " + moduleName + ";\n" + userSourceBody;
94+
ComPtr<slang::IGlobalSession> globalSession;
95+
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
96+
slang::TargetDesc targetDesc = {};
97+
targetDesc.format = SLANG_PTX;
98+
slang::SessionDesc sessionDesc = {};
99+
sessionDesc.targetCount = 1;
100+
sessionDesc.targets = &targetDesc;
101+
ComPtr<slang::ISession> session;
102+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
103+
104+
ComPtr<slang::IBlob> diagnosticBlob;
105+
auto module = session->loadModuleFromSourceString(
106+
"m",
107+
"m.slang",
108+
userSourceBody,
109+
diagnosticBlob.writeRef());
110+
SLANG_CHECK(module != nullptr);
111+
112+
ComPtr<slang::IEntryPoint> entryPoint;
113+
module->findAndCheckEntryPoint(
114+
"computeMain",
115+
SLANG_STAGE_COMPUTE,
116+
entryPoint.writeRef(),
117+
diagnosticBlob.writeRef());
118+
SLANG_CHECK(entryPoint != nullptr);
119+
120+
ComPtr<slang::IComponentType> compositeProgram;
121+
slang::IComponentType* components[] = {module, entryPoint.get()};
122+
session->createCompositeComponentType(
123+
components,
124+
2,
125+
compositeProgram.writeRef(),
126+
diagnosticBlob.writeRef());
127+
SLANG_CHECK(compositeProgram != nullptr);
128+
129+
ComPtr<slang::IComponentType> linkedProgram;
130+
compositeProgram->link(linkedProgram.writeRef(), diagnosticBlob.writeRef());
131+
SLANG_CHECK(linkedProgram != nullptr);
132+
133+
ComPtr<slang::IBlob> code;
134+
auto res = linkedProgram->getEntryPointCode(0, 0, code.writeRef(), diagnosticBlob.writeRef());
135+
SLANG_CHECK(res == SLANG_OK);
136+
SLANG_CHECK(code != nullptr);
137+
SLANG_CHECK(code->getBufferSize() != 0);
138+
}

0 commit comments

Comments
 (0)