Skip to content

Commit 94d3f2b

Browse files
authored
Add API for whole program compilation. (shader-slang#1562)
* Add API for whole program compilation. This change exposes a new target flag: `SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM` that can be set on a target with `spSetTargetFlags`. When this flag is set, `spCompile` function generates target code for the entire input module instead of just the specified entrypoints. The resulting code will include all the entrypoints defined in the input source. The resulting whole program code can be retrieved with two new functions: `spGetTargetCodeBlob` and `spGetTargetHostCallable`. This change also cleans up the unnecessary `entryPointIndices` parameter of `TargetProgram::getOrCreateWholeProgramResult`, and modifies the `cpu-hello-world` example to make use of the new whole-program compilation API to simplify its logic. * Update comments.
1 parent b72353e commit 94d3f2b

File tree

7 files changed

+122
-37
lines changed

7 files changed

+122
-37
lines changed

examples/cpu-hello-world/main.cpp

+6-10
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,10 @@ static SlangResult _innerMain(int argc, char** argv)
8585
// If we wanted a just a shared library/dll, we could have used SLANG_SHARED_LIBRARY.
8686
int targetIndex = spAddCodeGenTarget(slangRequest, SLANG_HOST_CALLABLE);
8787

88+
// Set the target flag to indicate that we want to compile all the entrypoints in the
89+
// slang shader file into a library.
90+
spSetTargetFlags(slangRequest, targetIndex, SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM);
91+
8892
// A compile request can include one or more "translation units," which more or
8993
// less amount to individual source files (think `.c` files, not the `.h` files they
9094
// might include).
@@ -99,14 +103,6 @@ static SlangResult _innerMain(int argc, char** argv)
99103
// There are also variations of this API for adding source code from application-provided buffers.
100104
//
101105
spAddTranslationUnitSourceFile(slangRequest, translationUnitIndex, "shader.slang");
102-
103-
// Next we will specify the entry points we'd like to compile.
104-
// It is often convenient to put more than one entry point in the same file,
105-
// and the Slang API makes it convenient to use a single run of the compiler
106-
// to compile all entry points.
107-
//
108-
const char entryPointName[] = "computeMain";
109-
int computeIndex = spAddEntryPoint(slangRequest, translationUnitIndex, entryPointName, SLANG_STAGE_COMPUTE);
110106

111107
// Once all of the input options for the compiler have been specified,
112108
// we can invoke `spCompile` to run the compiler and see if any errors
@@ -133,18 +129,18 @@ static SlangResult _innerMain(int argc, char** argv)
133129
// Get the 'shared library' (note that this doesn't necessarily have to be implemented as a shared library
134130
// it's just an interface to executable code).
135131
ComPtr<ISlangSharedLibrary> sharedLibrary;
136-
SLANG_RETURN_ON_FAIL(spGetEntryPointHostCallable(slangRequest, 0, 0, sharedLibrary.writeRef()));
132+
SLANG_RETURN_ON_FAIL(spGetTargetHostCallable(slangRequest, 0, sharedLibrary.writeRef()));
137133

138134
// Once we have the sharedLibrary, we no longer need the request
139135
// unless we want to use reflection, to for example workout how 'UniformState' and 'UniformEntryPointParams' are laid out
140136
// at runtime. We don't do that here - as we hard code the structures.
141137
spDestroyCompileRequest(slangRequest);
142138

143139
// Get the function we are going to execute
140+
const char entryPointName[] = "computeMain";
144141
CPPPrelude::ComputeFunc func = (CPPPrelude::ComputeFunc)sharedLibrary->findFuncByName(entryPointName);
145142
if (!func)
146143
{
147-
spDestroyCompileRequest(slangRequest);
148144
return SLANG_FAIL;
149145
}
150146

examples/cpu-hello-world/shader.slang

+1
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
//TEST_INPUT:ubuffer(random(float, 4096, -1.0, 1.0), stride=4):name=ioBuffer
44
RWStructuredBuffer<float> ioBuffer;
55

6+
[shader("compute")]
67
[numthreads(4, 1, 1)]
78
void computeMain(uint3 dispatchThreadID : SV_DispatchThreadID)
89
{

slang.h

+34-2
Original file line numberDiff line numberDiff line change
@@ -584,6 +584,12 @@ extern "C"
584584
@deprecated This behavior is now enabled unconditionally.
585585
*/
586586
SLANG_TARGET_FLAG_PARAMETER_BLOCKS_USE_REGISTER_SPACES = 1 << 4,
587+
588+
/* When set, will generate target code that contains all entrypoints defined
589+
in the input source or specified via the `spAddEntryPoint` function in a
590+
single output module (library/source file).
591+
*/
592+
SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM = 1 << 8
587593
};
588594

589595
/*!
@@ -1632,8 +1638,6 @@ extern "C"
16321638
@param targetIndex The index of the target to get code for (default: zero).
16331639
@param outBlob A pointer that will receive the blob of code
16341640
@returns A `SlangResult` to indicate success or failure.
1635-
1636-
The lifetime of the output pointer is the same as `request`.
16371641
*/
16381642
SLANG_API SlangResult spGetEntryPointCodeBlob(
16391643
SlangCompileRequest* request,
@@ -1659,6 +1663,34 @@ extern "C"
16591663
int targetIndex,
16601664
ISlangSharedLibrary** outSharedLibrary);
16611665

1666+
/** Get the output code associated with a specific target.
1667+
1668+
@param request The request
1669+
@param targetIndex The index of the target to get code for (default: zero).
1670+
@param outBlob A pointer that will receive the blob of code
1671+
@returns A `SlangResult` to indicate success or failure.
1672+
*/
1673+
SLANG_API SlangResult spGetTargetCodeBlob(
1674+
SlangCompileRequest* request,
1675+
int targetIndex,
1676+
ISlangBlob** outBlob);
1677+
1678+
/** Get 'callable' functions for a target accessible through the ISlangSharedLibrary interface.
1679+
1680+
That the functions remain in scope as long as the ISlangSharedLibrary interface is in scope.
1681+
1682+
NOTE! Requires a compilation target of SLANG_HOST_CALLABLE.
1683+
1684+
@param request The request
1685+
@param targetIndex The index of the target to get code for (default: zero).
1686+
@param outSharedLibrary A pointer to a ISharedLibrary interface which functions can be queried on.
1687+
@returns A `SlangResult` to indicate success or failure.
1688+
*/
1689+
SLANG_API SlangResult spGetTargetHostCallable(
1690+
SlangCompileRequest* request,
1691+
int targetIndex,
1692+
ISlangSharedLibrary** outSharedLibrary);
1693+
16621694
/** Get the output bytecode associated with an entire compile request.
16631695
16641696
The lifetime of the output pointer is the same as `request` and the last spCompile.

source/slang/slang-compiler.cpp

+11-20
Original file line numberDiff line numberDiff line change
@@ -2197,24 +2197,19 @@ SlangResult dissassembleDXILUsingDXC(
21972197
}
21982198

21992199
CompileResult& TargetProgram::_createWholeProgramResult(
2200-
const List<Int>& entryPointIndices,
22012200
BackEndCompileRequest* backEndRequest,
22022201
EndToEndCompileRequest* endToEndRequest)
22032202
{
2204-
for (auto entryPointIndex = entryPointIndices.begin(); entryPointIndex != entryPointIndices.end(); entryPointIndex++) {
2205-
if (*entryPointIndex >= m_entryPointResults.getCount())
2206-
m_entryPointResults.setCount(*entryPointIndex + 1);
2203+
// We want to call `emitEntryPoints` function to generate code that contains
2204+
// all the entrypoints defined in `m_program`.
2205+
// The current logic of `emitEntryPoints` takes a list of entry-point indices to
2206+
// emit code for, so we construct such a list first.
2207+
List<Int> entryPointIndices;
2208+
m_entryPointResults.setCount(m_program->getEntryPointCount());
2209+
entryPointIndices.setCount(m_program->getEntryPointCount());
2210+
for (Index i = 0; i < entryPointIndices.getCount(); i++)
2211+
entryPointIndices[i] = i;
22072212

2208-
// It is possible that entry points goot added to the `Program`
2209-
// *after* we created this `TargetProgram`, so there might be
2210-
// a request for an entry point that we didn't allocate space for.
2211-
//
2212-
// TODO: Change the construction logic so that a `Program` is
2213-
// constructed all at once rather than incrementally, to avoid
2214-
// this problem.
2215-
//
2216-
//auto entryPoint = m_program->getEntryPoint(*entryPointIndex);
2217-
}
22182213
auto& result = m_wholeProgramResult;
22192214
result = emitEntryPoints(
22202215
m_program,
@@ -2224,7 +2219,6 @@ SlangResult dissassembleDXILUsingDXC(
22242219
endToEndRequest);
22252220

22262221
return result;
2227-
22282222
}
22292223

22302224
CompileResult& TargetProgram::_createEntryPointResult(
@@ -2256,7 +2250,6 @@ SlangResult dissassembleDXILUsingDXC(
22562250
}
22572251

22582252
CompileResult& TargetProgram::getOrCreateWholeProgramResult(
2259-
const List<Int>& entryPointIndices,
22602253
DiagnosticSink* sink)
22612254
{
22622255
auto& result = m_wholeProgramResult;
@@ -2278,7 +2271,6 @@ SlangResult dissassembleDXILUsingDXC(
22782271
m_program);
22792272

22802273
return _createWholeProgramResult(
2281-
entryPointIndices,
22822274
backEndRequest,
22832275
nullptr);
22842276
}
@@ -2325,10 +2317,9 @@ SlangResult dissassembleDXILUsingDXC(
23252317
// Generate target code any entry points that
23262318
// have been requested for compilation.
23272319
auto entryPointCount = program->getEntryPointCount();
2328-
if (targetReq->isWholeProgramRequest)
2320+
if (targetReq->isWholeProgramRequest())
23292321
{
23302322
targetProgram->_createWholeProgramResult(
2331-
List<Int>(),
23322323
compileReq,
23332324
endToEndReq);
23342325
}
@@ -2497,7 +2488,7 @@ SlangResult dissassembleDXILUsingDXC(
24972488
for (auto targetReq : linkage->targets)
24982489
{
24992490
Index entryPointCount = program->getEntryPointCount();
2500-
if (targetReq->isWholeProgramRequest) {
2491+
if (targetReq->isWholeProgramRequest()) {
25012492
writeWholeProgramResult(
25022493
compileRequest,
25032494
targetReq);

source/slang/slang-compiler.h

+5-3
Original file line numberDiff line numberDiff line change
@@ -1132,7 +1132,10 @@ namespace Slang
11321132
SlangTargetFlags targetFlags = 0;
11331133
Slang::Profile targetProfile = Slang::Profile();
11341134
FloatingPointMode floatingPointMode = FloatingPointMode::Default;
1135-
bool isWholeProgramRequest = false;
1135+
bool isWholeProgramRequest()
1136+
{
1137+
return (targetFlags & SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM) != 0;
1138+
}
11361139

11371140
Linkage* getLinkage() { return linkage; }
11381141
CodeGenTarget getTarget() { return target; }
@@ -1673,7 +1676,7 @@ namespace Slang
16731676
/// code generation to the given `sink`.
16741677
///
16751678
CompileResult& getOrCreateEntryPointResult(Int entryPointIndex, DiagnosticSink* sink);
1676-
CompileResult& getOrCreateWholeProgramResult(const List<Int>& entryPointIndices, DiagnosticSink* sink);
1679+
CompileResult& getOrCreateWholeProgramResult(DiagnosticSink* sink);
16771680

16781681

16791682
CompileResult& getExistingWholeProgramResult()
@@ -1691,7 +1694,6 @@ namespace Slang
16911694
}
16921695

16931696
CompileResult& _createWholeProgramResult(
1694-
const List<Int>& entryPointIndices,
16951697
BackEndCompileRequest* backEndRequest,
16961698
EndToEndCompileRequest* endToEndRequest);
16971699
/// Internal helper for `getOrCreateEntryPointResult`.

source/slang/slang-options.cpp

+1-2
Original file line numberDiff line numberDiff line change
@@ -160,7 +160,6 @@ struct OptionsParser
160160
SlangTargetFlags targetFlags = 0;
161161
int targetID = -1;
162162
FloatingPointMode floatingPointMode = FloatingPointMode::Default;
163-
bool isWholeProgramRequest = false;
164163

165164
// State for tracking command-line errors
166165
bool conflictingProfilesSet = false;
@@ -1511,7 +1510,7 @@ struct OptionsParser
15111510
}
15121511
else
15131512
{
1514-
target->isWholeProgramRequest = true;
1513+
target->targetFlags |= SLANG_TARGET_FLAG_GENERATE_WHOLE_PROGRAM;
15151514
targetInfo->wholeTargetOutputPath = rawOutput.path;
15161515
}
15171516
}

source/slang/slang.cpp

+64
Original file line numberDiff line numberDiff line change
@@ -3613,6 +3613,33 @@ static SlangResult _getEntryPointResult(
36133613
return SLANG_OK;
36143614
}
36153615

3616+
static SlangResult _getWholeProgramResult(
3617+
SlangCompileRequest* request,
3618+
int targetIndex,
3619+
Slang::CompileResult** outCompileResult)
3620+
{
3621+
using namespace Slang;
3622+
if (!request)
3623+
return SLANG_ERROR_INVALID_PARAMETER;
3624+
3625+
auto req = Slang::asInternal(request);
3626+
auto linkage = req->getLinkage();
3627+
auto program = req->getSpecializedGlobalAndEntryPointsComponentType();
3628+
3629+
Index targetCount = linkage->targets.getCount();
3630+
if ((targetIndex < 0) || (targetIndex >= targetCount))
3631+
{
3632+
return SLANG_ERROR_INVALID_PARAMETER;
3633+
}
3634+
auto targetReq = linkage->targets[targetIndex];
3635+
3636+
auto targetProgram = program->getTargetProgram(targetReq);
3637+
if (!targetProgram)
3638+
return SLANG_FAIL;
3639+
*outCompileResult = &targetProgram->getExistingWholeProgramResult();
3640+
return SLANG_OK;
3641+
}
3642+
36163643
SLANG_API SlangResult spGetEntryPointCodeBlob(
36173644
SlangCompileRequest* request,
36183645
int entryPointIndex,
@@ -3648,6 +3675,43 @@ SLANG_API SlangResult spGetEntryPointHostCallable(
36483675
return SLANG_OK;
36493676
}
36503677

3678+
SLANG_API SlangResult spGetTargetCodeBlob(
3679+
SlangCompileRequest* request,
3680+
int targetIndex,
3681+
ISlangBlob** outBlob)
3682+
{
3683+
using namespace Slang;
3684+
if (!outBlob)
3685+
return SLANG_ERROR_INVALID_PARAMETER;
3686+
Slang::CompileResult* compileResult = nullptr;
3687+
SLANG_RETURN_ON_FAIL(
3688+
_getWholeProgramResult(request, targetIndex, &compileResult));
3689+
3690+
ComPtr<ISlangBlob> blob;
3691+
SLANG_RETURN_ON_FAIL(compileResult->getBlob(blob));
3692+
*outBlob = blob.detach();
3693+
return SLANG_OK;
3694+
}
3695+
3696+
SLANG_API SlangResult spGetTargetHostCallable(
3697+
SlangCompileRequest* request,
3698+
int targetIndex,
3699+
ISlangSharedLibrary** outSharedLibrary)
3700+
{
3701+
using namespace Slang;
3702+
if (!outSharedLibrary)
3703+
return SLANG_ERROR_INVALID_PARAMETER;
3704+
3705+
Slang::CompileResult* compileResult = nullptr;
3706+
SLANG_RETURN_ON_FAIL(
3707+
_getWholeProgramResult(request, targetIndex, &compileResult));
3708+
3709+
ComPtr<ISlangSharedLibrary> sharedLibrary;
3710+
SLANG_RETURN_ON_FAIL(compileResult->getSharedLibrary(sharedLibrary));
3711+
*outSharedLibrary = sharedLibrary.detach();
3712+
return SLANG_OK;
3713+
}
3714+
36513715
SLANG_API char const* spGetEntryPointSource(
36523716
SlangCompileRequest* request,
36533717
int entryPointIndex)

0 commit comments

Comments
 (0)