Skip to content

Commit a1e79e4

Browse files
authored
Fix the cuda left-hand swizzle issue (shader-slang#3538) (shader-slang#3691)
1 parent cc2a879 commit a1e79e4

File tree

4 files changed

+92
-26
lines changed

4 files changed

+92
-26
lines changed

source/slang/slang-emit-c-like.cpp

+83-21
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,9 @@
3434

3535
namespace Slang {
3636

37+
bool isCPUTarget(TargetRequest* targetReq);
38+
bool isCUDATarget(TargetRequest* targetReq);
39+
3740
struct CLikeSourceEmitter::ComputeEmitActionsContext
3841
{
3942
IRInst* moduleInst;
@@ -352,6 +355,43 @@ void CLikeSourceEmitter::_emitType(IRType* type, DeclaratorInfo* declarator)
352355
}
353356
}
354357

358+
void CLikeSourceEmitter::_emitSwizzleStorePerElement(IRInst* inst)
359+
{
360+
auto subscriptOuter = getInfo(EmitOp::General);
361+
auto subscriptPrec = getInfo(EmitOp::Postfix);
362+
363+
auto ii = cast<IRSwizzledStore>(inst);
364+
365+
UInt elementCount = ii->getElementCount();
366+
UInt dstIndex = 0;
367+
for (UInt ee = 0; ee < elementCount; ++ee)
368+
{
369+
bool needCloseSubscript = maybeEmitParens(subscriptOuter, subscriptPrec);
370+
371+
emitDereferenceOperand(ii->getDest(), leftSide(subscriptOuter, subscriptPrec));
372+
m_writer->emit(".");
373+
374+
IRInst* irElementIndex = ii->getElementIndex(ee);
375+
SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit);
376+
377+
IRConstant* irConst = (IRConstant*)irElementIndex;
378+
379+
UInt elementIndex = (UInt)irConst->value.intVal;
380+
SLANG_RELEASE_ASSERT(elementIndex < 4);
381+
382+
char const* kComponents[] = { "x", "y", "z", "w" };
383+
m_writer->emit(kComponents[elementIndex]);
384+
385+
maybeCloseParens(needCloseSubscript);
386+
387+
m_writer->emit(" = ");
388+
emitOperand(ii->getSource(), getInfo(EmitOp::General));
389+
m_writer->emit(".");
390+
m_writer->emit(kComponents[dstIndex++]);
391+
m_writer->emit(";\n");
392+
}
393+
}
394+
355395
void CLikeSourceEmitter::emitWitnessTable(IRWitnessTable* witnessTable)
356396
{
357397
SLANG_UNUSED(witnessTable);
@@ -1494,6 +1534,19 @@ bool CLikeSourceEmitter::shouldFoldInstIntoUseSites(IRInst* inst)
14941534
}
14951535
}
14961536
}
1537+
1538+
// For cuda and cpu targets don't support swizzle on the left-hand-side
1539+
// variable, e.g. vec4.xy = vec2 is not allowed.
1540+
// Therefore, we don't want to fold the right-hand-side expression.
1541+
// Instead, the right-hand-side expression should be generated as a separable
1542+
// statement and stored in a temporary varible, then assign to the left-hand-side
1543+
// variable per element. E.g. vec4.x = vec2.x; vec4.y = vec2.y.
1544+
if (as<IRSwizzledStore>(user))
1545+
{
1546+
if (isCPUTarget(getTargetReq()) || isCUDATarget(getTargetReq()))
1547+
return false;
1548+
}
1549+
14971550
// We'd like to figure out if it is safe to fold our instruction into `user`
14981551

14991552
// First, let's make sure they are in the same block/parent:
@@ -2760,32 +2813,41 @@ void CLikeSourceEmitter::_emitInst(IRInst* inst)
27602813

27612814
case kIROp_SwizzledStore:
27622815
{
2763-
auto subscriptOuter = getInfo(EmitOp::General);
2764-
auto subscriptPrec = getInfo(EmitOp::Postfix);
2765-
bool needCloseSubscript = maybeEmitParens(subscriptOuter, subscriptPrec);
2816+
// cpp and cuda target don't support swizzle on the left handside, so we
2817+
// have to assign the element one by one.
2818+
if (isCPUTarget(getTargetReq()) || isCUDATarget(getTargetReq()))
2819+
{
2820+
_emitSwizzleStorePerElement(inst);
2821+
}
2822+
else
2823+
{
27662824

2825+
auto subscriptOuter = getInfo(EmitOp::General);
2826+
auto subscriptPrec = getInfo(EmitOp::Postfix);
2827+
bool needCloseSubscript = maybeEmitParens(subscriptOuter, subscriptPrec);
27672828

2768-
auto ii = cast<IRSwizzledStore>(inst);
2769-
emitDereferenceOperand(ii->getDest(), leftSide(subscriptOuter, subscriptPrec));
2770-
m_writer->emit(".");
2771-
UInt elementCount = ii->getElementCount();
2772-
for (UInt ee = 0; ee < elementCount; ++ee)
2773-
{
2774-
IRInst* irElementIndex = ii->getElementIndex(ee);
2775-
SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit);
2776-
IRConstant* irConst = (IRConstant*)irElementIndex;
2829+
auto ii = cast<IRSwizzledStore>(inst);
2830+
emitDereferenceOperand(ii->getDest(), leftSide(subscriptOuter, subscriptPrec));
2831+
m_writer->emit(".");
2832+
UInt elementCount = ii->getElementCount();
2833+
for (UInt ee = 0; ee < elementCount; ++ee)
2834+
{
2835+
IRInst* irElementIndex = ii->getElementIndex(ee);
2836+
SLANG_RELEASE_ASSERT(irElementIndex->getOp() == kIROp_IntLit);
2837+
IRConstant* irConst = (IRConstant*)irElementIndex;
27772838

2778-
UInt elementIndex = (UInt)irConst->value.intVal;
2779-
SLANG_RELEASE_ASSERT(elementIndex < 4);
2839+
UInt elementIndex = (UInt)irConst->value.intVal;
2840+
SLANG_RELEASE_ASSERT(elementIndex < 4);
27802841

2781-
char const* kComponents[] = { "x", "y", "z", "w" };
2782-
m_writer->emit(kComponents[elementIndex]);
2783-
}
2784-
maybeCloseParens(needCloseSubscript);
2842+
char const* kComponents[] = { "x", "y", "z", "w" };
2843+
m_writer->emit(kComponents[elementIndex]);
2844+
}
2845+
maybeCloseParens(needCloseSubscript);
27852846

2786-
m_writer->emit(" = ");
2787-
emitOperand(ii->getSource(), getInfo(EmitOp::General));
2788-
m_writer->emit(";\n");
2847+
m_writer->emit(" = ");
2848+
emitOperand(ii->getSource(), getInfo(EmitOp::General));
2849+
m_writer->emit(";\n");
2850+
}
27892851
}
27902852
break;
27912853

source/slang/slang-emit-c-like.h

+4
Original file line numberDiff line numberDiff line change
@@ -531,6 +531,10 @@ class CLikeSourceEmitter: public SourceEmitterBase
531531
// Sort witnessTable entries according to the order defined in the witnessed interface type.
532532
List<IRWitnessTableEntry*> getSortedWitnessTableEntries(IRWitnessTable* witnessTable);
533533

534+
// Special handling for swizzleStore call, save the right-handside vector to a temporary variable
535+
// first, then assign the corresponding elements to the left-handside vector one by one.
536+
void _emitSwizzleStorePerElement(IRInst* inst);
537+
534538
CodeGenContext* m_codeGenContext = nullptr;
535539
IRModule* m_irModule = nullptr;
536540

tests/compute/half-vector-calc.slang

+1-1
Original file line numberDiff line numberDiff line change
@@ -29,7 +29,7 @@ void computeMain(int3 dispatchThreadID : SV_DispatchThreadID)
2929

3030
// Unary
3131
v2 = +v2.yxwz;
32-
v2 = -v2.zwxy;
32+
v2.xyz = -v2.zwx;
3333

3434
// Scalar vector
3535
v1 = v1 + v2.x;
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
type: float
2-
73.000000
3-
206.500000
4-
539.000000
5-
1070.000000
2+
75.000000
3+
220.500000
4+
565.000000
5+
1108.000000

0 commit comments

Comments
 (0)