Skip to content

Commit 27b2229

Browse files
authored
Make precompileForTargets work with Slang API (shader-slang#4845)
* Make precompileForTargets work with Slang API precompileForTargets, renamed to precompileForTarget, does not need an EndToEndCompileRequest and some objects created from it are not necessary either. Take only a target enum and a diagnostic blob as input and handle everything else internally, such as creating the TargetReq with chosen profile. Fixes shader-slang#4790 * Update slang-module.cpp * Update slang-module.cpp
1 parent 99673d7 commit 27b2229

File tree

7 files changed

+86
-67
lines changed

7 files changed

+86
-67
lines changed

include/slang.h

+4
Original file line numberDiff line numberDiff line change
@@ -5444,6 +5444,10 @@ namespace slang
54445444
SlangInt32 index) = 0;
54455445

54465446
virtual SLANG_NO_THROW DeclReflection* SLANG_MCALL getModuleReflection() = 0;
5447+
5448+
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget(
5449+
SlangCompileTarget target,
5450+
ISlangBlob** outDiagnostics) = 0;
54475451
};
54485452

54495453
#define SLANG_UUID_IModule IModule::getTypeGuid()

source/slang-record-replay/record/slang-module.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -213,6 +213,17 @@ namespace SlangRecord
213213
return res;
214214
}
215215

216+
SLANG_NO_THROW SlangResult ModuleRecorder::precompileForTarget(
217+
SlangCompileTarget target,
218+
ISlangBlob** outDiagnostics)
219+
{
220+
// TODO: We should record this call
221+
// https://github.com/shader-slang/slang/issues/4853
222+
slangRecordLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);
223+
SlangResult res = m_actualModule->precompileForTarget(target, outDiagnostics);
224+
return res;
225+
}
226+
216227
SLANG_NO_THROW slang::ISession* ModuleRecorder::getSession()
217228
{
218229
slangRecordLog(LogLevel::Verbose, "%s\n", __PRETTY_FUNCTION__);

source/slang-record-replay/record/slang-module.h

+3
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,9 @@ namespace SlangRecord
3939
virtual SLANG_NO_THROW SlangInt32 SLANG_MCALL getDependencyFileCount() override;
4040
virtual SLANG_NO_THROW char const* SLANG_MCALL getDependencyFilePath(
4141
SlangInt32 index) override;
42+
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget(
43+
SlangCompileTarget target,
44+
ISlangBlob** outDiagnostics) override;
4245

4346
// Interfaces for `IComponentType`
4447
virtual SLANG_NO_THROW slang::ISession* SLANG_MCALL getSession() override;

source/slang/slang-compiler-tu.cpp

+28-56
Original file line numberDiff line numberDiff line change
@@ -8,60 +8,24 @@
88

99
namespace Slang
1010
{
11-
SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTargets(
12-
DiagnosticSink* sink,
13-
EndToEndCompileRequest* endToEndReq,
14-
TargetRequest* targetReq)
11+
SLANG_NO_THROW SlangResult SLANG_MCALL Module::precompileForTarget(
12+
SlangCompileTarget target,
13+
slang::IBlob** outDiagnostics)
1514
{
16-
auto module = getIRModule();
17-
Slang::Session* session = endToEndReq->getSession();
18-
Slang::ASTBuilder* astBuilder = session->getGlobalASTBuilder();
19-
Slang::Linkage* builtinLinkage = session->getBuiltinLinkage();
20-
Slang::Linkage linkage(session, astBuilder, builtinLinkage);
21-
22-
CapabilityName precompileRequirement = CapabilityName::Invalid;
23-
switch (targetReq->getTarget())
15+
if (target != SLANG_DXIL)
2416
{
25-
case CodeGenTarget::DXIL:
26-
linkage.addTarget(Slang::CodeGenTarget::DXIL);
27-
precompileRequirement = CapabilityName::dxil_lib;
28-
break;
29-
default:
30-
assert(!"Unhandled target");
31-
break;
17+
return SLANG_FAIL;
3218
}
33-
SLANG_ASSERT(precompileRequirement != CapabilityName::Invalid);
19+
CodeGenTarget targetEnum = CodeGenTarget(target);
3420

35-
// Ensure precompilation capability requirements are met.
36-
auto targetCaps = targetReq->getTargetCaps();
37-
auto precompileRequirementsCapabilitySet = CapabilitySet(precompileRequirement);
38-
if (targetCaps.atLeastOneSetImpliedInOther(precompileRequirementsCapabilitySet) == CapabilitySet::ImpliesReturnFlags::NotImplied)
39-
{
40-
// If `RestrictiveCapabilityCheck` is true we will error, else we will warn.
41-
// error ...: dxil libraries require $0, entry point compiled with $1.
42-
// warn ...: dxil libraries require $0, entry point compiled with $1, implicitly upgrading capabilities.
43-
maybeDiagnoseWarningOrError(
44-
sink,
45-
targetReq->getOptionSet(),
46-
DiagnosticCategory::Capability,
47-
SourceLoc(),
48-
Diagnostics::incompatibleWithPrecompileLib,
49-
Diagnostics::incompatibleWithPrecompileLibRestrictive,
50-
precompileRequirementsCapabilitySet,
51-
targetCaps);
52-
53-
// add precompile requirements to the cooked targetCaps
54-
targetCaps.join(precompileRequirementsCapabilitySet);
55-
if (targetCaps.isInvalid())
56-
{
57-
sink->diagnose(SourceLoc(), Diagnostics::unknownCapability, targetCaps);
58-
return SLANG_FAIL;
59-
}
60-
else
61-
{
62-
targetReq->setTargetCaps(targetCaps);
63-
}
64-
}
21+
auto module = getIRModule();
22+
auto linkage = getLinkage();
23+
24+
DiagnosticSink sink(linkage->getSourceManager(), Lexer::sourceLocationLexer);
25+
applySettingsToDiagnosticSink(&sink, &sink, linkage->m_optionSet);
26+
applySettingsToDiagnosticSink(&sink, &sink, m_optionSet);
27+
28+
TargetRequest* targetReq = new TargetRequest(linkage, targetEnum);
6529

6630
List<RefPtr<ComponentType>> allComponentTypes;
6731
allComponentTypes.add(this); // Add Module as a component type
@@ -72,23 +36,34 @@ namespace Slang
7236
}
7337

7438
auto composite = CompositeComponentType::create(
75-
&linkage,
39+
linkage,
7640
allComponentTypes);
7741

7842
TargetProgram tp(composite, targetReq);
79-
tp.getOrCreateLayout(sink);
43+
tp.getOrCreateLayout(&sink);
8044
Slang::Index const entryPointCount = m_entryPoints.getCount();
45+
tp.getOptionSet().add(CompilerOptionName::GenerateWholeProgram, true);
46+
47+
switch (targetReq->getTarget())
48+
{
49+
case CodeGenTarget::DXIL:
50+
tp.getOptionSet().add(CompilerOptionName::Profile, Profile::RawEnum::DX_Lib_6_6);
51+
break;
52+
}
8153

8254
CodeGenContext::EntryPointIndices entryPointIndices;
8355

8456
entryPointIndices.setCount(entryPointCount);
8557
for (Index i = 0; i < entryPointCount; i++)
8658
entryPointIndices[i] = i;
87-
CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, sink, endToEndReq);
59+
CodeGenContext::Shared sharedCodeGenContext(&tp, entryPointIndices, &sink, nullptr);
8860
CodeGenContext codeGenContext(&sharedCodeGenContext);
8961

9062
ComPtr<IArtifact> outArtifact;
9163
SlangResult res = codeGenContext.emitTranslationUnit(outArtifact);
64+
65+
sink.getBlobIfNeeded(outDiagnostics);
66+
9267
if (res != SLANG_OK)
9368
{
9469
return res;
@@ -105,9 +80,6 @@ namespace Slang
10580
case CodeGenTarget::DXIL:
10681
builder.emitEmbeddedDXIL(blob);
10782
break;
108-
default:
109-
assert(!"Unhandled target");
110-
break;
11183
}
11284

11385
return SLANG_OK;

source/slang/slang-compiler.h

+3-4
Original file line numberDiff line numberDiff line change
@@ -1482,10 +1482,9 @@ namespace Slang
14821482
SlangInt32 index) override;
14831483

14841484
/// Precompile TU to target language
1485-
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTargets(
1486-
DiagnosticSink* sink,
1487-
EndToEndCompileRequest* endToEndReq,
1488-
TargetRequest* targetReq);
1485+
virtual SLANG_NO_THROW SlangResult SLANG_MCALL precompileForTarget(
1486+
SlangCompileTarget target,
1487+
slang::IBlob** outDiagnostics) override;
14891488

14901489
virtual void buildHash(DigestBuilder<SHA1>& builder) SLANG_OVERRIDE;
14911490

source/slang/slang.cpp

+4-4
Original file line numberDiff line numberDiff line change
@@ -3252,10 +3252,10 @@ SlangResult EndToEndCompileRequest::executeActionsInner()
32523252

32533253
for (auto translationUnit : frontEndReq->translationUnits)
32543254
{
3255-
translationUnit->getModule()->precompileForTargets(
3256-
getSink(),
3257-
this,
3258-
targetReq);
3255+
SlangCompileTarget target = SlangCompileTarget(targetReq->getTarget());
3256+
translationUnit->getModule()->precompileForTarget(
3257+
target,
3258+
nullptr);
32593259
}
32603260
}
32613261
}

tools/gfx-unit-test/precompiled-module-2.cpp

+33-3
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@ namespace gfx_test
1717
static Slang::Result precompileProgram(
1818
gfx::IDevice* device,
1919
ISlangMutableFileSystem* fileSys,
20-
const char* shaderModuleName)
20+
const char* shaderModuleName,
21+
bool precompileToTarget)
2122
{
2223
Slang::ComPtr<slang::ISession> slangSession;
2324
SLANG_RETURN_ON_FAIL(device->getSlangSession(slangSession.writeRef()));
@@ -34,6 +35,20 @@ namespace gfx_test
3435
if (!module)
3536
return SLANG_FAIL;
3637

38+
if (precompileToTarget)
39+
{
40+
SlangCompileTarget target;
41+
switch (device->getDeviceInfo().deviceType)
42+
{
43+
case gfx::DeviceType::DirectX12:
44+
target = SLANG_DXIL;
45+
break;
46+
default:
47+
return SLANG_FAIL;
48+
}
49+
module->precompileForTarget(target, diagnosticsBlob.writeRef());
50+
}
51+
3752
// Write loaded modules to memory file system.
3853
for (SlangInt i = 0; i < slangSession->getLoadedModuleCount(); i++)
3954
{
@@ -50,7 +65,7 @@ namespace gfx_test
5065
return SLANG_OK;
5166
}
5267

53-
void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
68+
void precompiledModule2TestImplCommon(IDevice* device, UnitTestContext* context, bool precompileToTarget)
5469
{
5570
Slang::ComPtr<ITransientResourceHeap> transientHeap;
5671
ITransientResourceHeap::Desc transientHeapDesc = {};
@@ -63,7 +78,7 @@ namespace gfx_test
6378

6479
ComPtr<IShaderProgram> shaderProgram;
6580
slang::ProgramLayout* slangReflection;
66-
GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported"));
81+
GFX_CHECK_CALL_ABORT(precompileProgram(device, memoryFileSystem.get(), "precompiled-module-imported", precompileToTarget));
6782

6883
// Next, load the precompiled slang program.
6984
Slang::ComPtr<slang::ISession> slangSession;
@@ -168,11 +183,26 @@ namespace gfx_test
168183
Slang::makeArray<float>(3.0f, 3.0f, 3.0f, 3.0f));
169184
}
170185

186+
void precompiledModule2TestImpl(IDevice* device, UnitTestContext* context)
187+
{
188+
precompiledModule2TestImplCommon(device, context, false);
189+
}
190+
191+
void precompiledTargetModule2TestImpl(IDevice* device, UnitTestContext* context)
192+
{
193+
precompiledModule2TestImplCommon(device, context, true);
194+
}
195+
171196
SLANG_UNIT_TEST(precompiledModule2D3D12)
172197
{
173198
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
174199
}
175200

201+
SLANG_UNIT_TEST(precompiledTargetModule2D3D12)
202+
{
203+
runTestImpl(precompiledTargetModule2TestImpl, unitTestContext, Slang::RenderApiFlag::D3D12);
204+
}
205+
176206
SLANG_UNIT_TEST(precompiledModule2Vulkan)
177207
{
178208
runTestImpl(precompiledModule2TestImpl, unitTestContext, Slang::RenderApiFlag::Vulkan);

0 commit comments

Comments
 (0)