Skip to content

Commit 039c233

Browse files
author
Tim Foley
authored
Add callable shader support for Vulkan ray tracing (shader-slang#718)
* Add callable shader support for Vulkan ray tracing This change extends the previous work to update Vulkan ray tracing support for the finished `GL_NV_ray_tracing` spec. One of the features missing in the experimental extension that was added to the final spec is "callable shaders," which allow ray tracing shaders to call other shaders as general-purpose subroutines. Most of the implementation work here mirrors what was done for the `TraceRay()` function to map it to `traceNV()`. We map the generic `CallShader<P>` function to the non-generic `executeCallableNV`, with a payload identifier that indicates a specific global variable of type `P` (the global variable being generated from a `static` local in `CallShader`). A new modifier is added to identify the payload structure, and the parameter binding/layout logic introduces a new resource kind for callable-shader payload data (where previously the logic had assumed ray and callable payloads should use the same resource kind). Two test shaders are included: one for the callable shader (`callable.slang`) and one for a ray generation shader that calls it (`callable-caller.slang`). Just for kicks, the payload data type is defined in a shared file so that we can be sure the two agree (trying to emulate what might be good practice, and ensure that ray tracing support works together with other Slang mechanisms). * Typo fix: assocaited->associated One instance was found in review, but I went ahead and fixed a bunch since I seem to make this typo a lot. * Typo fix: defintiion->definition
1 parent c07f60a commit 039c233

24 files changed

+334
-14
lines changed

examples/model-viewer/shaders.slang

+2-2
Original file line numberDiff line numberDiff line change
@@ -155,7 +155,7 @@ struct SimpleMaterial : IMaterial
155155
// To satisfy the requirements of the `IMaterial` interface, our
156156
// material type needs to provide a suitable `BRDF` type. We
157157
// do this by using a simple `typedef`, although a nested
158-
// `struct` type can also satisfy an assocaited type requirement.
158+
// `struct` type can also satisfy an associated type requirement.
159159
//
160160
// A future version of the Slang compiler may allow the "right"
161161
// associated type definition to be inferred from the signature
@@ -459,7 +459,7 @@ float4 fragmentMain(
459459
// from different light sources.
460460
//
461461
// Note that the return type here is `TMaterial.BRDF`,
462-
// which is the `BRDF` type *assocaited* with the (unknown)
462+
// which is the `BRDF` type *associated* with the (unknown)
463463
// `TMaterial` type. When `TMaterial` gets substituted for
464464
// a concrete type later (e.g., `SimpleMaterial`) this
465465
// will resolve to a concrete type too (e.g., `SimpleMaterial.BRDF`

slang.h

+2
Original file line numberDiff line numberDiff line change
@@ -1378,6 +1378,7 @@ extern "C"
13781378

13791379
SLANG_PARAMETER_CATEGORY_RAY_PAYLOAD,
13801380
SLANG_PARAMETER_CATEGORY_HIT_ATTRIBUTES,
1381+
SLANG_PARAMETER_CATEGORY_CALLABLE_PAYLOAD,
13811382

13821383
//
13831384
SLANG_PARAMETER_CATEGORY_COUNT,
@@ -1681,6 +1682,7 @@ namespace slang
16811682

16821683
RayPayload = SLANG_PARAMETER_CATEGORY_RAY_PAYLOAD,
16831684
HitAttributes = SLANG_PARAMETER_CATEGORY_HIT_ATTRIBUTES,
1685+
CallablePayload = SLANG_PARAMETER_CATEGORY_CALLABLE_PAYLOAD,
16841686

16851687
// DEPRECATED:
16861688
VertexInput = SLANG_PARAMETER_CATEGORY_VERTEX_INPUT,

source/slang-glslang/slang-glslang.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ static int glslang_compileGLSLToSPIRV(glslang_CompileRequest* request)
102102
CASE(ANY_HIT, AnyHitNV);
103103
CASE(CLOSEST_HIT, ClosestHitNV);
104104
CASE(MISS, MissNV);
105+
CASE(CALLABLE, CallableNV);
105106

106107
#undef CASE
107108

source/slang/core.meta.slang

+3
Original file line numberDiff line numberDiff line change
@@ -1233,6 +1233,9 @@ attribute_syntax [numthreads(x: int, y: int = 1, z: int = 1)] : NumThreadsAttr
12331233
__attributeTarget(VarDeclBase)
12341234
attribute_syntax [__vulkanRayPayload] : VulkanRayPayloadAttribute;
12351235

1236+
__attributeTarget(VarDeclBase)
1237+
attribute_syntax [__vulkanCallablePayload] : VulkanCallablePayloadAttribute;
1238+
12361239
__attributeTarget(VarDeclBase)
12371240
attribute_syntax [__vulkanHitAttributes] : VulkanHitAttributesAttribute;
12381241

source/slang/core.meta.slang.h

+3
Original file line numberDiff line numberDiff line change
@@ -1252,6 +1252,9 @@ SLANG_RAW("__attributeTarget(VarDeclBase)\n")
12521252
SLANG_RAW("attribute_syntax [__vulkanRayPayload] : VulkanRayPayloadAttribute;\n")
12531253
SLANG_RAW("\n")
12541254
SLANG_RAW("__attributeTarget(VarDeclBase)\n")
1255+
SLANG_RAW("attribute_syntax [__vulkanCallablePayload] : VulkanCallablePayloadAttribute;\n")
1256+
SLANG_RAW("\n")
1257+
SLANG_RAW("__attributeTarget(VarDeclBase)\n")
12551258
SLANG_RAW("attribute_syntax [__vulkanHitAttributes] : VulkanHitAttributesAttribute;\n")
12561259
SLANG_RAW("\n")
12571260
SLANG_RAW("__attributeTarget(FunctionDeclBase)\n")

source/slang/emit.cpp

+46-1
Original file line numberDiff line numberDiff line change
@@ -148,6 +148,7 @@ struct SharedEmitContext
148148
DiagnosticSink* getSink() { return &entryPoint->compileRequest->mSink; }
149149

150150
Dictionary<IRInst*, UInt> mapIRValueToRayPayloadLocation;
151+
Dictionary<IRInst*, UInt> mapIRValueToCallablePayloadLocation;
151152
};
152153

153154
struct EmitContext
@@ -3315,7 +3316,7 @@ struct EmitVisitor
33153316
case 'P':
33163317
{
33173318
// The `$XP` case handles looking up
3318-
// the assocaited `location` for a variable
3319+
// the associated `location` for a variable
33193320
// used as the argument ray payload at a
33203321
// trace call site.
33213322

@@ -3329,6 +3330,23 @@ struct EmitVisitor
33293330
}
33303331
break;
33313332

3333+
case 'C':
3334+
{
3335+
// The `$XC` case handles looking up
3336+
// the associated `location` for a variable
3337+
// used as the argument callable payload at a
3338+
// call site.
3339+
3340+
UInt argIndex = 0;
3341+
SLANG_RELEASE_ASSERT(argCount > argIndex);
3342+
auto arg = args[argIndex].get();
3343+
auto argLoad = as<IRLoad>(arg);
3344+
SLANG_RELEASE_ASSERT(argLoad);
3345+
auto argVar = argLoad->getOperand(0);
3346+
Emit(getCallablePayloadLocation(ctx, argVar));
3347+
}
3348+
break;
3349+
33323350
case 'T':
33333351
{
33343352
// The `$XT` case handles selecting between
@@ -5347,6 +5365,20 @@ struct EmitVisitor
53475365
return value;
53485366
}
53495367

5368+
UInt getCallablePayloadLocation(
5369+
EmitContext* ctx,
5370+
IRInst* inst)
5371+
{
5372+
auto& map = ctx->shared->mapIRValueToCallablePayloadLocation;
5373+
UInt value = 0;
5374+
if(map.TryGetValue(inst, value))
5375+
return value;
5376+
5377+
value = map.Count();
5378+
map.Add(inst, value);
5379+
return value;
5380+
}
5381+
53505382
void emitIRVarModifiers(
53515383
EmitContext* ctx,
53525384
VarLayout* layout,
@@ -5365,6 +5397,13 @@ struct EmitVisitor
53655397
emit(")\n");
53665398
emit("rayPayloadNV\n");
53675399
}
5400+
if(varDecl->findDecoration<IRVulkanCallablePayloadDecoration>())
5401+
{
5402+
emit("layout(location = ");
5403+
Emit(getCallablePayloadLocation(ctx, varDecl));
5404+
emit(")\n");
5405+
emit("callableDataNV\n");
5406+
}
53685407
if(varDecl->findDecoration<IRVulkanHitAttributesDecoration>())
53695408
{
53705409
emit("hitAttributeNV\n");
@@ -5520,6 +5559,12 @@ struct EmitVisitor
55205559
}
55215560
break;
55225561

5562+
case LayoutResourceKind::CallablePayload:
5563+
{
5564+
emit("callableDataInNV ");
5565+
}
5566+
break;
5567+
55235568
case LayoutResourceKind::HitAttributes:
55245569
{
55255570
emit("hitAttributeNV ");

source/slang/hlsl.meta.slang

+32-2
Original file line numberDiff line numberDiff line change
@@ -1355,8 +1355,38 @@ struct BuiltInTriangleIntersectionAttributes
13551355
// 10.3 - Intrinsics
13561356

13571357
// 10.3.1
1358-
__target_intrinsic(glsl, "callableShadersAreNotYetAvailableInVulkan")
1359-
void CallShader<param_t>(uint ShaderIndex, inout param_t Parameter);
1358+
1359+
void CallShader<Payload>(uint shaderIndex, inout Payload payload);
1360+
1361+
// `executeCallableNV` is the GLSL intrinsic that will be used to implement
1362+
// `CallShader()` for GLSL-based targets.
1363+
//
1364+
__target_intrinsic(glsl, "executeCallableNV")
1365+
void __executeCallableNV(uint shaderIndex, int payloadLocation);
1366+
1367+
// Next is the custom intrinsic that will compute the payload location
1368+
// for a type being used in a `CallShader()` call for GLSL-based targets.
1369+
//
1370+
__generic<Payload>
1371+
__target_intrinsic(glsl, "$XC")
1372+
[__readNone]
1373+
int __callablePayloadLocation(Payload payload);
1374+
1375+
// Now we provide a hard-coded definition of `CallShader()` for GLSL-based
1376+
// targets, which maps the generic HLSL operation into the non-generic
1377+
// GLSL equivalent.
1378+
//
1379+
__generic<Payload>
1380+
__specialized_for_target(glsl)
1381+
void CallShader(uint shaderIndex, inout Payload payload)
1382+
{
1383+
[__vulkanRayPayload]
1384+
static Payload p;
1385+
1386+
p = payload;
1387+
__executeCallableNV(shaderIndex, __callablePayloadLocation(p));
1388+
payload = p;
1389+
}
13601390

13611391
// 10.3.2
13621392
void TraceRay<payload_t>(

source/slang/hlsl.meta.slang.h

+32-2
Original file line numberDiff line numberDiff line change
@@ -1403,8 +1403,38 @@ SLANG_RAW("\n")
14031403
SLANG_RAW("// 10.3 - Intrinsics\n")
14041404
SLANG_RAW("\n")
14051405
SLANG_RAW("// 10.3.1\n")
1406-
SLANG_RAW("__target_intrinsic(glsl, \"callableShadersAreNotYetAvailableInVulkan\")\n")
1407-
SLANG_RAW("void CallShader<param_t>(uint ShaderIndex, inout param_t Parameter);\n")
1406+
SLANG_RAW("\n")
1407+
SLANG_RAW("void CallShader<Payload>(uint shaderIndex, inout Payload payload);\n")
1408+
SLANG_RAW("\n")
1409+
SLANG_RAW("// `executeCallableNV` is the GLSL intrinsic that will be used to implement\n")
1410+
SLANG_RAW("// `CallShader()` for GLSL-based targets.\n")
1411+
SLANG_RAW("//\n")
1412+
SLANG_RAW("__target_intrinsic(glsl, \"executeCallableNV\")\n")
1413+
SLANG_RAW("void __executeCallableNV(uint shaderIndex, int payloadLocation);\n")
1414+
SLANG_RAW("\n")
1415+
SLANG_RAW("// Next is the custom intrinsic that will compute the payload location\n")
1416+
SLANG_RAW("// for a type being used in a `CallShader()` call for GLSL-based targets.\n")
1417+
SLANG_RAW("//\n")
1418+
SLANG_RAW("__generic<Payload>\n")
1419+
SLANG_RAW("__target_intrinsic(glsl, \"$XC\")\n")
1420+
SLANG_RAW("[__readNone]\n")
1421+
SLANG_RAW("int __callablePayloadLocation(Payload payload);\n")
1422+
SLANG_RAW("\n")
1423+
SLANG_RAW("// Now we provide a hard-coded definition of `CallShader()` for GLSL-based\n")
1424+
SLANG_RAW("// targets, which maps the generic HLSL operation into the non-generic\n")
1425+
SLANG_RAW("// GLSL equivalent.\n")
1426+
SLANG_RAW("//\n")
1427+
SLANG_RAW("__generic<Payload>\n")
1428+
SLANG_RAW("__specialized_for_target(glsl)\n")
1429+
SLANG_RAW("void CallShader(uint shaderIndex, inout Payload payload)\n")
1430+
SLANG_RAW("{\n")
1431+
SLANG_RAW(" [__vulkanRayPayload]\n")
1432+
SLANG_RAW(" static Payload p;\n")
1433+
SLANG_RAW("\n")
1434+
SLANG_RAW(" p = payload;\n")
1435+
SLANG_RAW(" __executeCallableNV(shaderIndex, __callablePayloadLocation(p));\n")
1436+
SLANG_RAW(" payload = p;\n")
1437+
SLANG_RAW("}\n")
14081438
SLANG_RAW("\n")
14091439
SLANG_RAW("// 10.3.2\n")
14101440
SLANG_RAW("void TraceRay<payload_t>(\n")

source/slang/ir-insts.h

+9-1
Original file line numberDiff line numberDiff line change
@@ -125,6 +125,14 @@ struct IRVulkanRayPayloadDecoration : IRDecoration
125125
enum { kDecorationOp = kIRDecorationOp_VulkanRayPayload };
126126
};
127127

128+
/// A decoration that indicates that a variable represents
129+
/// a vulkan callable shader payload, and should have a location assigned
130+
/// to it.
131+
struct IRVulkanCallablePayloadDecoration : IRDecoration
132+
{
133+
enum { kDecorationOp = kIRDecorationOp_VulkanCallablePayload };
134+
};
135+
128136
/// A decoration that indicates that a variable represents
129137
/// vulkan hit attributes, and should have a location assigned
130138
/// to it.
@@ -978,7 +986,7 @@ IRGlobalValue* getSpecializedGlobalValueForDeclRef(
978986
struct ExtensionUsageTracker;
979987

980988
// Clone the IR values reachable from the given entry point
981-
// into the IR module assocaited with the specialization state.
989+
// into the IR module associated with the specialization state.
982990
// When multiple definitions of a symbol are found, the one
983991
// that is best specialized for the given `targetReq` will be
984992
// used.

source/slang/ir-serialize.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -704,6 +704,7 @@ Result IRSerialWriter::write(IRModule* module, SourceManager* sourceManager, Opt
704704
break;
705705
}
706706
case kIRDecorationOp_VulkanRayPayload:
707+
case kIRDecorationOp_VulkanCallablePayload:
707708
case kIRDecorationOp_VulkanHitAttributes:
708709
case kIRDecorationOp_ReadNone:
709710
{
@@ -1560,6 +1561,12 @@ IRDecoration* IRSerialReader::_createDecoration(const Ser::Inst& srcInst)
15601561
SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Empty);
15611562
return decor;
15621563
}
1564+
case kIRDecorationOp_VulkanCallablePayload:
1565+
{
1566+
auto decor = createEmptyDecoration<IRVulkanCallablePayloadDecoration>(m_module);
1567+
SLANG_ASSERT(srcInst.m_payloadType == PayloadType::Empty);
1568+
return decor;
1569+
}
15631570
case kIRDecorationOp_VulkanHitAttributes:
15641571
{
15651572
auto decor = createEmptyDecoration<IRVulkanHitAttributesDecoration>(m_module);

source/slang/ir.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -2934,6 +2934,11 @@ namespace Slang
29342934
dump(context, "\n[__vulkanRayPayload]");
29352935
}
29362936
break;
2937+
case kIRDecorationOp_VulkanCallablePayload:
2938+
{
2939+
dump(context, "\n[__vulkanCallPayload]");
2940+
}
2941+
break;
29372942
case kIRDecorationOp_VulkanHitAttributes:
29382943
{
29392944
dump(context, "\n[__vulkanHitAttributes]");
@@ -5377,6 +5382,12 @@ namespace Slang
53775382
}
53785383
break;
53795384

5385+
case kIRDecorationOp_VulkanCallablePayload:
5386+
{
5387+
context->builder->addDecoration<IRVulkanCallablePayloadDecoration>(clonedValue);
5388+
}
5389+
break;
5390+
53805391
case kIRDecorationOp_VulkanHitAttributes:
53815392
{
53825393
context->builder->addDecoration<IRVulkanHitAttributesDecoration>(clonedValue);

source/slang/ir.h

+1
Original file line numberDiff line numberDiff line change
@@ -158,6 +158,7 @@ enum IRDecorationOp : uint16_t
158158
kIRDecorationOp_RequireGLSLVersion,
159159
kIRDecorationOp_RequireGLSLExtension,
160160
kIRDecorationOp_ReadNone,
161+
kIRDecorationOp_VulkanCallablePayload,
161162

162163
kIRDecorationOp_CountOf
163164
};

source/slang/lower-to-ir.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1260,6 +1260,10 @@ void addVarDecorations(
12601260
{
12611261
builder->addDecoration<IRVulkanRayPayloadDecoration>(inst);
12621262
}
1263+
else if(mod.As<VulkanCallablePayloadAttribute>())
1264+
{
1265+
builder->addDecoration<IRVulkanCallablePayloadDecoration>(inst);
1266+
}
12631267
else if(mod.As<VulkanHitAttributesAttribute>())
12641268
{
12651269
builder->addDecoration<IRVulkanHitAttributesDecoration>(inst);

source/slang/modifier-defs.h

+6
Original file line numberDiff line numberDiff line change
@@ -393,6 +393,12 @@ END_SYNTAX_CLASS()
393393
// ray tracing shader to pass per-ray payload information.
394394
SIMPLE_SYNTAX_CLASS(VulkanRayPayloadAttribute, Attribute)
395395

396+
// A `[__vulkanCallablePayload]` attribute, which is used in the
397+
// standard library implementation to indicate that a variable
398+
// actually represents the input/output interface for a Vulkan
399+
// ray tracing shader to pass payload information to/from a callee.
400+
SIMPLE_SYNTAX_CLASS(VulkanCallablePayloadAttribute, Attribute)
401+
396402
// A `[__vulkanHitAttributes]` attribute, which is used in the
397403
// standard library implementation to indicate that a variable
398404
// actually represents the output interface for a Vulkan

source/slang/options.cpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ struct OptionsParser
103103

104104
// An entry point represents a function to be checked and possibly have
105105
// code generated in one of our translation units. An entry point
106-
// needs to have an assocaited stage, which might come via the
106+
// needs to have an associated stage, which might come via the
107107
// `-stage` command line option, or a `[shader("...")]` attribute
108108
// in the source code.
109109
//
@@ -1176,7 +1176,7 @@ struct OptionsParser
11761176
// need to support output formats that can store multiple
11771177
// entry points in one file).
11781178

1179-
// If an output doesn't have a target assocaited with
1179+
// If an output doesn't have a target associated with
11801180
// it, then search for the target with the matching format.
11811181
if( rawOutput.targetIndex == -1 )
11821182
{

source/slang/parameter-binding.cpp

+7-1
Original file line numberDiff line numberDiff line change
@@ -1910,13 +1910,19 @@ static RefPtr<TypeLayout> processEntryPointParameter(
19101910
break;
19111911

19121912
case Stage::AnyHit:
1913-
case Stage::Callable:
19141913
case Stage::ClosestHit:
19151914
case Stage::Miss:
19161915
// `in out` or `out` parameter is payload
19171916
return CreateTypeLayout(context->layoutContext.with(
19181917
context->getRulesFamily()->getRayPayloadParameterRules()),
19191918
type);
1919+
1920+
case Stage::Callable:
1921+
// `in out` or `out` parameter is payload
1922+
return CreateTypeLayout(context->layoutContext.with(
1923+
context->getRulesFamily()->getCallablePayloadParameterRules()),
1924+
type);
1925+
19201926
}
19211927
}
19221928
else

source/slang/preprocessor.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -478,7 +478,7 @@ static PreprocessorMacro* LookupMacro(PreprocessorEnvironment* environment, Name
478478

479479
static PreprocessorEnvironment* GetCurrentEnvironment(Preprocessor* preprocessor)
480480
{
481-
// The environment we will use for looking up a macro is assocaited
481+
// The environment we will use for looking up a macro is associated
482482
// with the current input stream (because it may include entries
483483
// for macro arguments).
484484
//

0 commit comments

Comments
 (0)