Skip to content

Commit 7c2ff54

Browse files
csyongheslangbot
andauthored
Various WGSL fixes. (shader-slang#5490)
* [WGSL] make sure switch has a default label. * Various WGSL fixes. * Update rhi submodule commit * format code * Remove unnecessary DISABLE_TEST directive on not applicable test. * Matrix comp mul + `select`. * Legalize binary ops for wgsl. --------- Co-authored-by: slangbot <186143334+slangbot@users.noreply.github.com>
1 parent 2c8dacf commit 7c2ff54

14 files changed

+231
-13
lines changed

source/slang/slang-emit-wgsl.cpp

+121-5
Original file line numberDiff line numberDiff line change
@@ -497,6 +497,34 @@ void WGSLSourceEmitter::emitLayoutQualifiersImpl(IRVarLayout* layout)
497497
}
498498
}
499499

500+
static bool isStaticConst(IRInst* inst)
501+
{
502+
if (inst->getParent()->getOp() == kIROp_Module)
503+
{
504+
return true;
505+
}
506+
switch (inst->getOp())
507+
{
508+
case kIROp_MakeVector:
509+
case kIROp_swizzle:
510+
case kIROp_swizzleSet:
511+
case kIROp_IntCast:
512+
case kIROp_FloatCast:
513+
case kIROp_CastFloatToInt:
514+
case kIROp_CastIntToFloat:
515+
case kIROp_BitCast:
516+
{
517+
for (UInt i = 0; i < inst->getOperandCount(); i++)
518+
{
519+
if (!isStaticConst(inst->getOperand(i)))
520+
return false;
521+
}
522+
return true;
523+
}
524+
}
525+
return false;
526+
}
527+
500528
void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
501529
{
502530
switch (varDecl->getOp())
@@ -505,14 +533,10 @@ void WGSLSourceEmitter::emitVarKeywordImpl(IRType* type, IRInst* varDecl)
505533
case kIROp_GlobalVar:
506534
case kIROp_Var: m_writer->emit("var"); break;
507535
default:
508-
if (as<IRModuleInst>(varDecl->getParent()))
509-
{
536+
if (isStaticConst(varDecl))
510537
m_writer->emit("const");
511-
}
512538
else
513-
{
514539
m_writer->emit("var");
515-
}
516540
break;
517541
}
518542

@@ -977,6 +1001,33 @@ void WGSLSourceEmitter::emitCallArg(IRInst* inst)
9771001
}
9781002
}
9791003

1004+
bool WGSLSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
1005+
{
1006+
bool result = CLikeSourceEmitter::shouldFoldInstIntoUseSites(inst);
1007+
if (result)
1008+
{
1009+
// If inst is a matrix, and is used in a component-wise multiply,
1010+
// we need to not fold it.
1011+
if (as<IRMatrixType>(inst->getDataType()))
1012+
{
1013+
for (auto use = inst->firstUse; use; use = use->nextUse)
1014+
{
1015+
auto user = use->getUser();
1016+
if (user->getOp() == kIROp_Mul)
1017+
{
1018+
if (as<IRMatrixType>(user->getOperand(0)->getDataType()) &&
1019+
as<IRMatrixType>(user->getOperand(1)->getDataType()))
1020+
{
1021+
return false;
1022+
}
1023+
}
1024+
}
1025+
}
1026+
}
1027+
return result;
1028+
}
1029+
1030+
9801031
bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOuterPrec)
9811032
{
9821033
EmitOpInfo outerPrec = inOuterPrec;
@@ -1126,6 +1177,71 @@ bool WGSLSourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
11261177
return true;
11271178
}
11281179
break;
1180+
1181+
case kIROp_GetStringHash:
1182+
{
1183+
auto getStringHashInst = as<IRGetStringHash>(inst);
1184+
auto stringLit = getStringHashInst->getStringLit();
1185+
1186+
if (stringLit)
1187+
{
1188+
auto slice = stringLit->getStringSlice();
1189+
emitType(inst->getDataType());
1190+
m_writer->emit("(");
1191+
m_writer->emit((int)getStableHashCode32(slice.begin(), slice.getLength()).hash);
1192+
m_writer->emit(")");
1193+
}
1194+
else
1195+
{
1196+
// Couldn't handle
1197+
diagnoseUnhandledInst(inst);
1198+
}
1199+
return true;
1200+
}
1201+
1202+
case kIROp_Mul:
1203+
{
1204+
if (!as<IRMatrixType>(inst->getOperand(0)->getDataType()) ||
1205+
!as<IRMatrixType>(inst->getOperand(1)->getDataType()))
1206+
{
1207+
return false;
1208+
}
1209+
// Mul(m1, m2) should be translated to component-wise multiplication in WGSL.
1210+
auto matrixType = as<IRMatrixType>(inst->getDataType());
1211+
auto rowCount = getIntVal(matrixType->getRowCount());
1212+
emitType(inst->getDataType());
1213+
m_writer->emit("(");
1214+
for (IRIntegerValue i = 0; i < rowCount; i++)
1215+
{
1216+
if (i != 0)
1217+
{
1218+
m_writer->emit(", ");
1219+
}
1220+
emitOperand(inst->getOperand(0), getInfo(EmitOp::Postfix));
1221+
m_writer->emit("[");
1222+
m_writer->emit(i);
1223+
m_writer->emit("] * ");
1224+
emitOperand(inst->getOperand(1), getInfo(EmitOp::Postfix));
1225+
m_writer->emit("[");
1226+
m_writer->emit(i);
1227+
m_writer->emit("]");
1228+
}
1229+
m_writer->emit(")");
1230+
1231+
return true;
1232+
}
1233+
1234+
case kIROp_Select:
1235+
{
1236+
m_writer->emit("select(");
1237+
emitOperand(inst->getOperand(2), getInfo(EmitOp::General));
1238+
m_writer->emit(", ");
1239+
emitOperand(inst->getOperand(1), getInfo(EmitOp::General));
1240+
m_writer->emit(", ");
1241+
emitOperand(inst->getOperand(0), getInfo(EmitOp::General));
1242+
m_writer->emit(")");
1243+
return true;
1244+
}
11291245
}
11301246

11311247
return false;

source/slang/slang-emit-wgsl.h

+2
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,8 @@ class WGSLSourceEmitter : public CLikeSourceEmitter
5050

5151
void emit(const AddressSpace addressSpace);
5252

53+
virtual bool shouldFoldInstIntoUseSites(IRInst* inst) SLANG_OVERRIDE;
54+
5355
private:
5456
// Emit the matrix type with 'rowCountWGSL' WGSL-rows and 'colCountWGSL' WGSL-columns
5557
void emitMatrixType(

source/slang/slang-ir-insts.h

+1
Original file line numberDiff line numberDiff line change
@@ -4021,6 +4021,7 @@ struct IRBuilder
40214021
IRInst* emitDifferentialPairGetPrimalUserCode(IRInst* diffPair);
40224022
IRInst* emitMakeVector(IRType* type, UInt argCount, IRInst* const* args);
40234023
IRInst* emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue);
4024+
IRInst* emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue);
40244025

40254026
IRInst* emitMakeVector(IRType* type, List<IRInst*> const& args)
40264027
{

source/slang/slang-ir-wgsl-legalize.cpp

+89-1
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,8 @@ struct LegalizeWGSLEntryPointContext
5151
String* optionalSemanticIndex,
5252
IRInst* parentVar);
5353
void legalizeCall(IRCall* call);
54+
void legalizeSwitch(IRSwitch* switchInst);
55+
void legalizeBinaryOp(IRInst* inst);
5456
void processInst(IRInst* inst);
5557
};
5658

@@ -349,11 +351,97 @@ void LegalizeWGSLEntryPointContext::legalizeCall(IRCall* call)
349351
}
350352
}
351353

354+
void LegalizeWGSLEntryPointContext::legalizeSwitch(IRSwitch* switchInst)
355+
{
356+
// WGSL Requires all switch statements to contain a default case.
357+
// If the switch statement does not contain a default case, we will add one.
358+
if (switchInst->getDefaultLabel() != switchInst->getBreakLabel())
359+
return;
360+
IRBuilder builder(switchInst);
361+
auto defaultBlock = builder.createBlock();
362+
builder.setInsertInto(defaultBlock);
363+
builder.emitBranch(switchInst->getBreakLabel());
364+
defaultBlock->insertBefore(switchInst->getBreakLabel());
365+
List<IRInst*> cases;
366+
for (UInt i = 0; i < switchInst->getCaseCount(); i++)
367+
{
368+
cases.add(switchInst->getCaseValue(i));
369+
cases.add(switchInst->getCaseLabel(i));
370+
}
371+
builder.setInsertBefore(switchInst);
372+
auto newSwitch = builder.emitSwitch(
373+
switchInst->getCondition(),
374+
switchInst->getBreakLabel(),
375+
defaultBlock,
376+
(UInt)cases.getCount(),
377+
cases.getBuffer());
378+
switchInst->transferDecorationsTo(newSwitch);
379+
switchInst->removeAndDeallocate();
380+
}
381+
382+
void LegalizeWGSLEntryPointContext::legalizeBinaryOp(IRInst* inst)
383+
{
384+
auto isVectorOrMatrix = [](IRType* type)
385+
{
386+
switch (type->getOp())
387+
{
388+
case kIROp_VectorType:
389+
case kIROp_MatrixType: return true;
390+
default: return false;
391+
}
392+
};
393+
if (isVectorOrMatrix(inst->getOperand(0)->getDataType()) &&
394+
as<IRBasicType>(inst->getOperand(1)->getDataType()))
395+
{
396+
IRBuilder builder(inst);
397+
builder.setInsertBefore(inst);
398+
auto newRhs = builder.emitMakeCompositeFromScalar(
399+
inst->getOperand(0)->getDataType(),
400+
inst->getOperand(1));
401+
builder.replaceOperand(inst->getOperands() + 1, newRhs);
402+
}
403+
else if (
404+
as<IRBasicType>(inst->getOperand(0)->getDataType()) &&
405+
isVectorOrMatrix(inst->getOperand(1)->getDataType()))
406+
{
407+
IRBuilder builder(inst);
408+
builder.setInsertBefore(inst);
409+
auto newLhs = builder.emitMakeCompositeFromScalar(
410+
inst->getOperand(1)->getDataType(),
411+
inst->getOperand(0));
412+
builder.replaceOperand(inst->getOperands(), newLhs);
413+
}
414+
}
415+
352416
void LegalizeWGSLEntryPointContext::processInst(IRInst* inst)
353417
{
354418
switch (inst->getOp())
355419
{
356-
case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
420+
case kIROp_Call: legalizeCall(static_cast<IRCall*>(inst)); break;
421+
case kIROp_Switch: legalizeSwitch(as<IRSwitch>(inst)); break;
422+
423+
// For all binary operators, make sure both side of the operator have the same type
424+
// (vector-ness and matrix-ness).
425+
case kIROp_Add:
426+
case kIROp_Sub:
427+
case kIROp_Mul:
428+
case kIROp_Div:
429+
case kIROp_FRem:
430+
case kIROp_IRem:
431+
case kIROp_And:
432+
case kIROp_Or:
433+
case kIROp_BitAnd:
434+
case kIROp_BitOr:
435+
case kIROp_BitXor:
436+
case kIROp_Lsh:
437+
case kIROp_Rsh:
438+
case kIROp_Eql:
439+
case kIROp_Neq:
440+
case kIROp_Greater:
441+
case kIROp_Less:
442+
case kIROp_Geq:
443+
case kIROp_Leq: legalizeBinaryOp(inst); break;
444+
357445
default:
358446
for (auto child : inst->getModifiableChildren())
359447
processInst(child);

source/slang/slang-ir.cpp

+11
Original file line numberDiff line numberDiff line change
@@ -4162,6 +4162,17 @@ IRInst* IRBuilder::emitMakeVectorFromScalar(IRType* type, IRInst* scalarValue)
41624162
return emitIntrinsicInst(type, kIROp_MakeVectorFromScalar, 1, &scalarValue);
41634163
}
41644164

4165+
IRInst* IRBuilder::emitMakeCompositeFromScalar(IRType* type, IRInst* scalarValue)
4166+
{
4167+
switch (type->getOp())
4168+
{
4169+
case kIROp_VectorType: return emitMakeVectorFromScalar(type, scalarValue);
4170+
case kIROp_MatrixType: return emitMakeMatrixFromScalar(type, scalarValue);
4171+
case kIROp_ArrayType: return emitMakeArrayFromElement(type, scalarValue);
4172+
default: SLANG_UNEXPECTED("unhandled composite type"); UNREACHABLE_RETURN(nullptr);
4173+
}
4174+
}
4175+
41654176
IRInst* IRBuilder::emitMatrixReshape(IRType* type, IRInst* inst)
41664177
{
41674178
return emitIntrinsicInst(type, kIROp_MatrixReshape, 1, &inst);

tests/autodiff-dstdlib/dstdlib-abs.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
22
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
3-
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
3+
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type
44

55
//TEST_INPUT:ubuffer(data=[0 0 0 0 0 0 0 0 0 0 0 0], stride=4):out,name=outputBuffer
66
RWStructuredBuffer<float> outputBuffer;

tests/autodiff/matrix-arithmetic-fwd.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -output-using-type
12
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
23
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3-
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
44

55
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
66
RWStructuredBuffer<float> outputBuffer;

tests/autodiff/reverse-loop-checkpoint-test.slang

+1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
//TEST(compute):COMPARE_COMPUTE_EX:-dx12 -compute -shaderobj -output-using-type
22
//TEST(compute):COMPARE_COMPUTE_EX:-cuda -compute -shaderobj -output-using-type
33
//TEST(compute):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
4+
//TEST(compute):COMPARE_COMPUTE_EX:-wgpu -compute -shaderobj -output-using-type
45
//TEST:SIMPLE(filecheck=CHECK): -target hlsl -profile cs_5_0 -entry computeMain -line-directive-mode none
56
//DISABLE_TEST:SIMPLE(filecheck=CHK):-target glsl -stage compute -entry computeMain -report-checkpoint-intermediates
67

tests/bugs/nested-switch.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
//TEST(compute):COMPARE_COMPUTE: -shaderobj
44
//TEST(compute):COMPARE_COMPUTE:-vk -shaderobj
55
//TEST(compute):COMPARE_COMPUTE:-cpu -shaderobj
6-
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
6+
//TEST(compute):COMPARE_COMPUTE:-wgpu
77

88
int test(int t, int r)
99
{

tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang

-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
11
//TEST:COMPILE: -entry computeMain -stage compute -target callable tests/hlsl-intrinsic/sampler-feedback/compute-sampler-feedback.slang
2-
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
32

43
// Not available on non PS shader
54
// dx.op.writeSamplerFeedback WriteSamplerFeedback

tests/ir/string-literal-hash.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//TEST(compute):COMPARE_COMPUTE: -shaderobj
22
//TEST(compute):COMPARE_COMPUTE: -vk -shaderobj
3-
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
3+
//TEST(compute):COMPARE_COMPUTE:-wgpu
44

55
// Note: disabled on CPU target until we can fill
66
// in a more correct/complete `String` and `getStringHash`

tests/language-feature/constants/constexpr-loop.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
//TEST(compute):COMPARE_COMPUTE_EX:-slang -compute -shaderobj -output-using-type
22
//TEST(compute, vulkan):COMPARE_COMPUTE_EX:-vk -compute -shaderobj -output-using-type
3-
//DISABLE_TEST(compute):COMPARE_COMPUTE_EX:-wgpu
3+
//TEST(compute):COMPARE_COMPUTE_EX: -wgpu -compute -output-using-type
44

55
//TEST_INPUT: set g_texture = Texture2D(size=8, content = one)
66
//TEST_INPUT: set g_sampler = Sampler

tests/library/linked.spirv

-816 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)