Skip to content

Commit d4902e9

Browse files
authored
Merge branch 'master' into master
2 parents 7c3b729 + 551bbb5 commit d4902e9

18 files changed

+315
-40
lines changed

source/compiler-core/slang-dxc-compiler.cpp

+7
Original file line numberDiff line numberDiff line change
@@ -479,6 +479,13 @@ SlangResult DXCDownstreamCompiler::compile(const CompileOptions& inOptions, IArt
479479
args.add(compilerSpecific[i]);
480480
}
481481

482+
// This can be re-enabled when we add PAQs: https://github.com/shader-slang/slang/issues/3448
483+
const bool enablePAQs = false;
484+
if (!enablePAQs)
485+
args.add(L"-disable-payload-qualifiers");
486+
else
487+
args.add(L"-enable-payload-qualifiers");
488+
482489
// TODO: deal with
483490
bool treatWarningsAsErrors = false;
484491
if (treatWarningsAsErrors)

source/slang/hlsl.meta.slang

+12-5
Original file line numberDiff line numberDiff line change
@@ -16464,6 +16464,13 @@ __generic<T>
1646416464
__intrinsic_op($(kIROp_ForceVarIntoStructTemporarily))
1646516465
Ref<T> __forceVarIntoStructTemporarily(inout T maybeStruct);
1646616466

16467+
// Some functions require a struct type which is decorated with a [raypayload]
16468+
// attribute. This will do the same as __forceVarIntoStructTemporarily and also
16469+
// ensure that the struct type in question is decorated appropriately.
16470+
__generic<T>
16471+
__intrinsic_op($(kIROp_ForceVarIntoRayPayloadStructTemporarily))
16472+
Ref<T> __forceVarIntoRayPayloadStructTemporarily(inout T maybeStruct);
16473+
1646716474
__generic<payload_t>
1646816475
[require(hlsl, raytracing)]
1646916476
void __traceRayHLSL(
@@ -16548,7 +16555,7 @@ void TraceRay(
1654816555
MultiplierForGeometryContributionToHitGroupIndex,
1654916556
MissShaderIndex,
1655016557
Ray,
16551-
__forceVarIntoStructTemporarily(Payload));
16558+
__forceVarIntoRayPayloadStructTemporarily(Payload));
1655216559
return;
1655316560
case cuda: __intrinsic_asm "traceOptiXRay";
1655416561
case glsl:
@@ -16686,7 +16693,7 @@ void TraceMotionRay(
1668616693
MissShaderIndex,
1668716694
Ray,
1668816695
CurrentTime,
16689-
__forceVarIntoStructTemporarily(Payload));
16696+
__forceVarIntoRayPayloadStructTemporarily(Payload));
1669016697
return;
1669116698
case glsl:
1669216699
{
@@ -18830,7 +18837,7 @@ struct HitObject
1883018837
MultiplierForGeometryContributionToHitGroupIndex,
1883118838
MissShaderIndex,
1883218839
Ray,
18833-
__forceVarIntoStructTemporarily(Payload),
18840+
__forceVarIntoRayPayloadStructTemporarily(Payload),
1883418841
hitObj);
1883518842
return hitObj;
1883618843
}
@@ -18923,7 +18930,7 @@ struct HitObject
1892318930
MissShaderIndex,
1892418931
Ray,
1892518932
CurrentTime,
18926-
__forceVarIntoStructTemporarily(Payload));
18933+
__forceVarIntoRayPayloadStructTemporarily(Payload));
1892718934
case glsl:
1892818935
{
1892918936
[__vulkanRayPayload]
@@ -19441,7 +19448,7 @@ struct HitObject
1944119448
__InvokeHLSL(
1944219449
AccelerationStructure,
1944319450
HitOrMiss,
19444-
__forceVarIntoStructTemporarily(Payload));
19451+
__forceVarIntoRayPayloadStructTemporarily(Payload));
1944519452
case glsl:
1944619453
{
1944719454
[__vulkanRayPayload]

source/slang/slang-check-decl.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -12114,6 +12114,27 @@ static void checkDerivativeAttribute(
1211412114
imaginaryArguments.directions,
1211512115
imaginaryArguments.thisArg,
1211612116
imaginaryArguments.thisArgDirection);
12117+
12118+
// For primal-substitute we'd also want to make sure that the differentiability
12119+
// level of the target is as high as the funcDecl itself
12120+
//
12121+
if (auto declRefExpr = as<DeclRefExpr>(attr->funcExpr))
12122+
{
12123+
if (auto declRef = declRefExpr->declRef)
12124+
{
12125+
auto targetDiffLevel = visitor->getShared()->getFuncDifferentiableLevel(
12126+
declRef.as<FunctionDeclBase>().getDecl());
12127+
auto currDiffLevel = visitor->getShared()->getFuncDifferentiableLevel(funcDecl);
12128+
if (targetDiffLevel < currDiffLevel)
12129+
{
12130+
visitor->getSink()->diagnose(
12131+
attr->loc,
12132+
Diagnostics::primalSubstituteTargetMustHaveHigherDifferentiabilityLevel,
12133+
declRefExpr->declRef.getDecl(),
12134+
funcDecl);
12135+
}
12136+
}
12137+
}
1211712138
}
1211812139

1211912140
static void checkCudaKernelAttribute(

source/slang/slang-diagnostic-defs.h

+6
Original file line numberDiff line numberDiff line change
@@ -1208,6 +1208,12 @@ DIAGNOSTIC(
12081208
Error,
12091209
overloadedFuncUsedWithDerivativeOfAttributes,
12101210
"cannot resolve overloaded functions for derivative-of attributes.")
1211+
DIAGNOSTIC(
1212+
31158,
1213+
Error,
1214+
primalSubstituteTargetMustHaveHigherDifferentiabilityLevel,
1215+
"primal substitute function for differentiable method must also be differentiable. Use "
1216+
"[Differentiable] or [TreatAsDifferentiable] (for empty derivatives)")
12111217

12121218
DIAGNOSTIC(31200, Warning, deprecatedUsage, "$0 has been deprecated: $1")
12131219
DIAGNOSTIC(31201, Error, modifierNotAllowed, "modifier '$0' is not allowed here.")

source/slang/slang-emit-hlsl.cpp

+9
Original file line numberDiff line numberDiff line change
@@ -1669,6 +1669,15 @@ void HLSLSourceEmitter::emitPostKeywordTypeAttributesImpl(IRInst* inst)
16691669
{
16701670
m_writer->emit("[payload] ");
16711671
}
1672+
// This can be re-enabled when we add PAQs: https://github.com/shader-slang/slang/issues/3448
1673+
const bool enablePAQs = false;
1674+
if (enablePAQs)
1675+
{
1676+
if (const auto payloadDecoration = inst->findDecoration<IRRayPayloadDecoration>())
1677+
{
1678+
m_writer->emit("[raypayload] ");
1679+
}
1680+
}
16721681
}
16731682

16741683
void HLSLSourceEmitter::_emitPrefixTypeAttr(IRAttr* attr)

source/slang/slang-emit-spirv.cpp

+34-13
Original file line numberDiff line numberDiff line change
@@ -3192,6 +3192,17 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
31923192
void ensureAtomicCapability(IRInst* atomicInst, SpvOp op)
31933193
{
31943194
auto typeOp = atomicInst->getDataType()->getOp();
3195+
if (typeOp == kIROp_VoidType)
3196+
{
3197+
auto ptrType = atomicInst->getOperand(0)->getDataType();
3198+
IRBuilder builder(atomicInst);
3199+
if (auto valType = tryGetPointedToType(&builder, ptrType))
3200+
{
3201+
if (auto atomicType = as<IRAtomicType>(valType))
3202+
valType = atomicType->getElementType();
3203+
typeOp = valType->getOp();
3204+
}
3205+
}
31953206
switch (op)
31963207
{
31973208
case SpvOpAtomicFAddEXT:
@@ -5094,18 +5105,23 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
50945105
{
50955106
SpvBuiltIn builtinName;
50965107
SpvStorageClass storageClass = SpvStorageClassInput;
5108+
bool flat = false;
50975109
BuiltinSpvVarKey() = default;
5098-
BuiltinSpvVarKey(SpvBuiltIn builtin, SpvStorageClass storageClass)
5099-
: builtinName(builtin), storageClass(storageClass)
5110+
BuiltinSpvVarKey(SpvBuiltIn builtin, SpvStorageClass storageClass, bool isFlat)
5111+
: builtinName(builtin), storageClass(storageClass), flat(isFlat)
51005112
{
51015113
}
51025114
bool operator==(const BuiltinSpvVarKey& other) const
51035115
{
5104-
return builtinName == other.builtinName && storageClass == other.storageClass;
5116+
return builtinName == other.builtinName && storageClass == other.storageClass &&
5117+
flat == other.flat;
51055118
}
51065119
HashCode getHashCode() const
51075120
{
5108-
return combineHash(Slang::getHashCode(builtinName), Slang::getHashCode(storageClass));
5121+
return combineHash(
5122+
Slang::getHashCode(builtinName),
5123+
Slang::getHashCode(storageClass),
5124+
Slang::getHashCode(flat));
51095125
}
51105126
};
51115127
Dictionary<BuiltinSpvVarKey, SpvInst*> m_builtinGlobalVars;
@@ -5127,26 +5143,25 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
51275143
return false;
51285144
}
51295145

5130-
void maybeEmitFlatDecorationForBuiltinVar(IRInst* irInst, SpvInst* spvInst)
5146+
bool needFlatDecorationForBuiltinVar(IRInst* irInst)
51315147
{
51325148
if (!irInst)
5133-
return;
5149+
return false;
51345150
if (irInst->getOp() != kIROp_GlobalVar && irInst->getOp() != kIROp_GlobalParam)
5135-
return;
5151+
return false;
51365152
auto ptrType = as<IRPtrType>(irInst->getDataType());
51375153
if (!ptrType)
5138-
return;
5154+
return false;
51395155
auto addrSpace = ptrType->getAddressSpace();
51405156
if (addrSpace == AddressSpace::Input || addrSpace == AddressSpace::BuiltinInput)
51415157
{
51425158
if (isIntegralScalarOrCompositeType(ptrType->getValueType()))
51435159
{
51445160
if (isInstUsedInStage(irInst, Stage::Fragment))
5145-
_maybeEmitInterpolationModifierDecoration(
5146-
IRInterpolationMode::NoInterpolation,
5147-
getID(spvInst));
5161+
return true;
51485162
}
51495163
}
5164+
return false;
51505165
}
51515166

51525167
SpvInst* getBuiltinGlobalVar(IRType* type, SpvBuiltIn builtinVal, IRInst* irInst)
@@ -5155,7 +5170,8 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
51555170
auto ptrType = as<IRPtrTypeBase>(type);
51565171
SLANG_ASSERT(ptrType && "`getBuiltinGlobalVar`: `type` must be ptr type.");
51575172
auto storageClass = addressSpaceToStorageClass(ptrType->getAddressSpace());
5158-
auto key = BuiltinSpvVarKey(builtinVal, storageClass);
5173+
bool isFlat = needFlatDecorationForBuiltinVar(irInst);
5174+
auto key = BuiltinSpvVarKey(builtinVal, storageClass, isFlat);
51595175
if (m_builtinGlobalVars.tryGetValue(key, result))
51605176
{
51615177
return result;
@@ -5185,7 +5201,12 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
51855201
}
51865202
m_builtinGlobalVars[key] = varInst;
51875203

5188-
maybeEmitFlatDecorationForBuiltinVar(irInst, varInst);
5204+
if (isFlat)
5205+
{
5206+
_maybeEmitInterpolationModifierDecoration(
5207+
IRInterpolationMode::NoInterpolation,
5208+
getID(varInst));
5209+
}
51895210

51905211
return varInst;
51915212
}

source/slang/slang-ir-autodiff.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -2471,6 +2471,7 @@ void stripAutoDiffDecorationsFromChildren(IRInst* parent)
24712471
case kIROp_BackwardDerivativePrimalDecoration:
24722472
case kIROp_BackwardDerivativePrimalContextDecoration:
24732473
case kIROp_BackwardDerivativePrimalReturnDecoration:
2474+
case kIROp_PrimalSubstituteDecoration:
24742475
case kIROp_AutoDiffOriginalValueDecoration:
24752476
case kIROp_UserDefinedBackwardDerivativeDecoration:
24762477
case kIROp_IntermediateContextFieldDifferentialTypeDecoration:

source/slang/slang-ir-hlsl-legalize.cpp

+17-8
Original file line numberDiff line numberDiff line change
@@ -29,14 +29,20 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
2929
for (UInt i = 0; i < call->getArgCount(); i++)
3030
{
3131
auto arg = call->getArg(i);
32-
if (arg->getOp() != kIROp_ForceVarIntoStructTemporarily)
32+
const bool isForcedStruct = arg->getOp() == kIROp_ForceVarIntoStructTemporarily;
33+
const bool isForcedRayPayloadStruct =
34+
arg->getOp() == kIROp_ForceVarIntoRayPayloadStructTemporarily;
35+
if (!(isForcedStruct || isForcedRayPayloadStruct))
3336
continue;
3437
auto forceStructArg = arg->getOperand(0);
3538
auto forceStructBaseType =
3639
as<IRType>(forceStructArg->getDataType()->getOperand(0));
40+
IRBuilder builder(call);
3741
if (forceStructBaseType->getOp() == kIROp_StructType)
3842
{
3943
call->setArg(i, arg->getOperand(0));
44+
if (isForcedRayPayloadStruct)
45+
builder.addRayPayloadDecoration(forceStructBaseType);
4046
continue;
4147
}
4248

@@ -47,14 +53,19 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
4753
// `__forceVarIntoStructTemporarily` is a parameter to a side effect type
4854
// (`ref`, `out`, `inout`) we copy the struct back into our original non-struct
4955
// parameter.
50-
IRBuilder builder(call);
56+
57+
const auto typeNameHint = isForcedRayPayloadStruct
58+
? "RayPayload_t"
59+
: "ForceVarIntoStructTemporarily_t";
60+
const auto varNameHint =
61+
isForcedRayPayloadStruct ? "rayPayload" : "forceVarIntoStructTemporarily";
5162

5263
builder.setInsertBefore(call->getCallee());
5364
auto structType = builder.createStructType();
5465
StringBuilder structName;
55-
builder.addNameHintDecoration(
56-
structType,
57-
UnownedStringSlice("ForceVarIntoStructTemporarily_t"));
66+
builder.addNameHintDecoration(structType, UnownedStringSlice(typeNameHint));
67+
if (isForcedRayPayloadStruct)
68+
builder.addRayPayloadDecoration(structType);
5869

5970
auto elementBufferKey = builder.createStructKey();
6071
builder.addNameHintDecoration(elementBufferKey, UnownedStringSlice("data"));
@@ -65,9 +76,7 @@ void searchChildrenForForceVarIntoStructTemporarily(IRModule* module, IRInst* in
6576

6677
builder.setInsertBefore(call);
6778
auto structVar = builder.emitVar(structType);
68-
builder.addNameHintDecoration(
69-
structVar,
70-
UnownedStringSlice("forceVarIntoStructTemporarily"));
79+
builder.addNameHintDecoration(structVar, UnownedStringSlice(varNameHint));
7180
builder.emitStore(
7281
builder.emitFieldAddress(
7382
builder.getPtrType(_dataField->getFieldType()),

source/slang/slang-ir-inst-defs.h

+4
Original file line numberDiff line numberDiff line change
@@ -759,6 +759,9 @@ INST(GetPerVertexInputArray, GetPerVertexInputArray, 1, HOISTABLE)
759759
INST(ResolveVaryingInputRef, ResolveVaryingInputRef, 1, HOISTABLE)
760760

761761
INST(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, 1, 0)
762+
INST(ForceVarIntoRayPayloadStructTemporarily, ForceVarIntoRayPayloadStructTemporarily, 1, 0)
763+
INST_RANGE(ForceVarIntoStructTemporarily, ForceVarIntoStructTemporarily, ForceVarIntoRayPayloadStructTemporarily)
764+
762765
INST(MetalAtomicCast, MetalAtomicCast, 1, 0)
763766

764767
INST(IsTextureAccess, IsTextureAccess, 1, 0)
@@ -992,6 +995,7 @@ INST_RANGE(BindingQuery, GetRegisterIndex, GetRegisterSpace)
992995
INST(GLSLLocationDecoration, glslLocation, 1, 0)
993996
INST(GLSLOffsetDecoration, glslOffset, 1, 0)
994997
INST(PayloadDecoration, payload, 0, 0)
998+
INST(RayPayloadDecoration, raypayload, 0, 0)
995999

9961000
/* Mesh Shader outputs */
9971001
INST(VerticesDecoration, vertices, 1, 0)

source/slang/slang-ir-insts.h

+7
Original file line numberDiff line numberDiff line change
@@ -1605,6 +1605,11 @@ struct IRPayloadDecoration : public IRDecoration
16051605
IR_LEAF_ISA(PayloadDecoration)
16061606
};
16071607

1608+
struct IRRayPayloadDecoration : public IRDecoration
1609+
{
1610+
IR_LEAF_ISA(RayPayloadDecoration)
1611+
};
1612+
16081613
// Mesh shader decorations
16091614

16101615
struct IRMeshOutputDecoration : public IRDecoration
@@ -5289,6 +5294,8 @@ struct IRBuilder
52895294
{
52905295
addDecoration(inst, kIROp_EntryPointParamDecoration, entryPointFunc);
52915296
}
5297+
5298+
void addRayPayloadDecoration(IRType* inst) { addDecoration(inst, kIROp_RayPayloadDecoration); }
52925299
};
52935300

52945301
// Helper to establish the source location that will be used
+46
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type -g0
2+
3+
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBufferPrimal
4+
RWStructuredBuffer<float> outputBufferPrimal;
5+
6+
//TEST_INPUT:ubuffer(data=[1 2 3 4], stride=4):name=gradBuffer
7+
RWStructuredBuffer<float> gradBuffer;
8+
9+
struct BufferWithGrad
10+
{
11+
RWStructuredBuffer<float> primal;
12+
RWStructuredBuffer<float> grad;
13+
14+
[Differentiable]
15+
void add(float value) { primal[0] = primal[0] + detach(value); }
16+
17+
[PrimalSubstituteOf(add), Differentiable]
18+
void add_subst(float value)
19+
{
20+
}
21+
22+
[BackwardDerivativeOf(add)]
23+
void add_bwd(inout DifferentialPair<float> d)
24+
{
25+
d = diffPair(d.p, grad[0]);
26+
}
27+
}
28+
29+
[Differentiable]
30+
void diffCall(BufferWithGrad result)
31+
{
32+
result.add(1.0f);
33+
}
34+
35+
[shader("compute")]
36+
[numthreads(1, 1, 1)]
37+
void computeMain(uint3 dispatchThreadID: SV_DispatchThreadID)
38+
{
39+
BufferWithGrad bg = {outputBufferPrimal, gradBuffer};
40+
diffCall(bg);
41+
bwd_diff(diffCall)(bg);
42+
43+
// CHECK: type: float
44+
// CHECK-NEXT: 1.0
45+
// CHECK-NEXT: 0.0
46+
}

0 commit comments

Comments
 (0)