Skip to content

Commit f573c15

Browse files
Fix anyvalue marshalling for matrix and 64 bit types. (#5827)
* Fix anyvalue marshalling for matrix types. * Add support for 64bit types marshalling. --------- Co-authored-by: Ellie Hermaszewska <ellieh@nvidia.com>
1 parent f687688 commit f573c15

File tree

4 files changed

+186
-14
lines changed

4 files changed

+186
-14
lines changed

source/slang/slang-emit-spirv.cpp

+16
Original file line numberDiff line numberDiff line change
@@ -3285,6 +3285,19 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
32853285
return nullptr;
32863286
}
32873287

3288+
// Emits SPIR-V for a `kIROp_MakeUInt64` instruction.
//
// The two operands of `inst` are first packed into a 2-component uint vector
// with OpCompositeConstruct, and that vector is then bitcast to the data type
// of `inst`. The returned SpvInst is the bitcast, which is the instruction
// registered for `inst`.
//
// NOTE(review): presumably operand 0 carries the low 32 bits and operand 1 the
// high 32 bits (the any-value unmarshalling code passes lowBits first) — confirm.
SpvInst* emitMakeUInt64(SpvInstParent* parent, IRInst* inst)
3289+
{
3290+
IRBuilder builder(inst);
3291+
builder.setInsertBefore(inst);
3292+
// The intermediate vector has no IR counterpart, hence the null IRInst.
auto vec = emitOpCompositeConstruct(
3293+
parent,
3294+
nullptr,
3295+
builder.getVectorType(builder.getUIntType(), 2),
3296+
inst->getOperand(0),
3297+
inst->getOperand(1));
3298+
return emitOpBitcast(parent, inst, inst->getDataType(), vec);
3299+
}
3300+
32883301
// The instructions that appear inside the basic blocks of
32893302
// functions are what we will call "local" instructions.
32903303
//
@@ -3391,6 +3404,9 @@ struct SPIRVEmitContext : public SourceEmitterBase, public SPIRVEmitSharedContex
33913404
case kIROp_BitCast:
33923405
result = emitOpBitcast(parent, inst, inst->getDataType(), inst->getOperand(0));
33933406
break;
3407+
case kIROp_MakeUInt64:
3408+
result = emitMakeUInt64(parent, inst);
3409+
break;
33943410
case kIROp_Add:
33953411
case kIROp_Sub:
33963412
case kIROp_Mul:

source/slang/slang-ir-any-value-marshalling.cpp

+96-14
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,12 @@ struct AnyValueMarshallingContext
103103
intraFieldOffset = 0;
104104
}
105105
}
106+
// Advances the marshalling cursor so that the next value starts on an 8-byte
// boundary: first align to a 4-byte boundary, then skip one more field if
// `fieldOffset` is odd.
// NOTE(review): this relies on `fieldOffset` counting 4-byte fields, so an even
// index is 8-byte aligned — confirm against the any-value struct layout.
void ensureOffsetAt8ByteBoundary()
107+
{
108+
ensureOffsetAt4ByteBoundary();
109+
if ((fieldOffset & 1) != 0)
110+
fieldOffset++;
111+
}
106112
void ensureOffsetAt2ByteBoundary()
107113
{
108114
if (intraFieldOffset == 0)
@@ -146,6 +152,7 @@ struct AnyValueMarshallingContext
146152
case kIROp_BoolType:
147153
case kIROp_IntPtrType:
148154
case kIROp_UIntPtrType:
155+
case kIROp_PtrType:
149156
context->marshalBasicType(builder, dataType, concreteTypedVar);
150157
break;
151158
case kIROp_VectorType:
@@ -166,17 +173,36 @@ struct AnyValueMarshallingContext
166173
auto matrixType = static_cast<IRMatrixType*>(dataType);
167174
auto colCount = getIntVal(matrixType->getColumnCount());
168175
auto rowCount = getIntVal(matrixType->getRowCount());
169-
for (IRIntegerValue i = 0; i < colCount; i++)
176+
if (getIntVal(matrixType->getLayout()) == SLANG_MATRIX_LAYOUT_COLUMN_MAJOR)
170177
{
171-
auto col = builder->emitElementAddress(
172-
concreteTypedVar,
173-
builder->getIntValue(builder->getIntType(), i));
174-
for (IRIntegerValue j = 0; j < rowCount; j++)
178+
for (IRIntegerValue i = 0; i < colCount; i++)
179+
{
180+
for (IRIntegerValue j = 0; j < rowCount; j++)
181+
{
182+
auto row = builder->emitElementAddress(
183+
concreteTypedVar,
184+
builder->getIntValue(builder->getIntType(), j));
185+
auto element = builder->emitElementAddress(
186+
row,
187+
builder->getIntValue(builder->getIntType(), i));
188+
emitMarshallingCode(builder, context, element);
189+
}
190+
}
191+
}
192+
else
193+
{
194+
for (IRIntegerValue i = 0; i < rowCount; i++)
175195
{
176-
auto element = builder->emitElementAddress(
177-
col,
178-
builder->getIntValue(builder->getIntType(), j));
179-
emitMarshallingCode(builder, context, element);
196+
auto row = builder->emitElementAddress(
197+
concreteTypedVar,
198+
builder->getIntValue(builder->getIntType(), i));
199+
for (IRIntegerValue j = 0; j < colCount; j++)
200+
{
201+
auto element = builder->emitElementAddress(
202+
row,
203+
builder->getIntValue(builder->getIntType(), j));
204+
emitMarshallingCode(builder, context, element);
205+
}
180206
}
181207
}
182208
break;
@@ -348,11 +374,39 @@ struct AnyValueMarshallingContext
348374
case kIROp_UInt64Type:
349375
case kIROp_Int64Type:
350376
case kIROp_DoubleType:
377+
case kIROp_PtrType:
351378
#if SLANG_PTR_IS_64
352379
case kIROp_UIntPtrType:
353380
case kIROp_IntPtrType:
354381
#endif
355-
SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
382+
ensureOffsetAt8ByteBoundary();
383+
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
384+
{
385+
auto srcVal = builder->emitLoad(concreteVar);
386+
auto dstVal = builder->emitBitCast(builder->getUInt64Type(), srcVal);
387+
auto lowBits = builder->emitCast(builder->getUIntType(), dstVal);
388+
auto highBits = builder->emitShr(
389+
builder->getUInt64Type(),
390+
dstVal,
391+
builder->getIntValue(builder->getIntType(), 32));
392+
highBits = builder->emitCast(builder->getUIntType(), highBits);
393+
394+
auto dstAddr = builder->emitFieldAddress(
395+
uintPtrType,
396+
anyValueVar,
397+
anyValInfo->fieldKeys[fieldOffset]);
398+
builder->emitStore(dstAddr, lowBits);
399+
fieldOffset++;
400+
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
401+
{
402+
dstAddr = builder->emitFieldAddress(
403+
uintPtrType,
404+
anyValueVar,
405+
anyValInfo->fieldKeys[fieldOffset]);
406+
// The second field receives the upper 32 bits; the lower 32 bits were already
// written to the preceding field. Storing lowBits here again would drop the
// high half of every marshalled 64-bit value.
builder->emitStore(dstAddr, highBits);
407+
fieldOffset++;
408+
}
409+
}
356410
break;
357411
default:
358412
SLANG_UNREACHABLE("unknown basic type");
@@ -545,7 +599,34 @@ struct AnyValueMarshallingContext
545599
case kIROp_DoubleType:
546600
// NOTE(review): these two 8-bit cases used to end in SLANG_UNIMPLEMENTED_X.
// They now fall through into the 64-bit unmarshalling path below (8-byte
// alignment, two 32-bit loads, MakeUInt64, then a bitcast to `dataType`),
// which does not look meaningful for 8-bit types — confirm whether they should
// keep an explicit "unimplemented" path instead.
case kIROp_Int8Type:
547601
case kIROp_UInt8Type:
548-
SLANG_UNIMPLEMENTED_X("AnyValue type packing for non 32-bit elements");
602+
case kIROp_PtrType:
603+
#if SLANG_PTR_IS_64
604+
case kIROp_IntPtrType:
605+
case kIROp_UIntPtrType:
606+
#endif
607+
ensureOffsetAt8ByteBoundary();
608+
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
609+
{
610+
auto srcAddr = builder->emitFieldAddress(
611+
uintPtrType,
612+
anyValueVar,
613+
anyValInfo->fieldKeys[fieldOffset]);
614+
auto lowBits = builder->emitLoad(srcAddr);
615+
fieldOffset++;
616+
if (fieldOffset < static_cast<uint32_t>(anyValInfo->fieldKeys.getCount()))
617+
{
618+
auto srcAddr1 = builder->emitFieldAddress(
619+
uintPtrType,
620+
anyValueVar,
621+
anyValInfo->fieldKeys[fieldOffset]);
622+
fieldOffset++;
623+
auto highBits = builder->emitLoad(srcAddr1);
624+
auto combinedBits = builder->emitMakeUInt64(lowBits, highBits);
625+
if (dataType->getOp() != kIROp_UInt64Type)
626+
combinedBits = builder->emitBitCast(dataType, combinedBits);
627+
builder->emitStore(concreteVar, combinedBits);
628+
}
629+
}
549630
break;
550631
default:
551632
SLANG_UNREACHABLE("unknown basic type");
@@ -735,7 +816,8 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
735816
case kIROp_UInt64Type:
736817
case kIROp_Int64Type:
737818
case kIROp_DoubleType:
738-
return -1;
819+
case kIROp_PtrType:
820+
return alignUp(offset, 8) + 8;
739821
case kIROp_Int16Type:
740822
case kIROp_UInt16Type:
741823
case kIROp_HalfType:
@@ -762,9 +844,9 @@ SlangInt _getAnyValueSizeRaw(IRType* type, SlangInt offset)
762844
auto elementType = matrixType->getElementType();
763845
auto colCount = getIntVal(matrixType->getColumnCount());
764846
auto rowCount = getIntVal(matrixType->getRowCount());
765-
for (IRIntegerValue i = 0; i < colCount; i++)
847+
for (IRIntegerValue i = 0; i < rowCount; i++)
766848
{
767-
for (IRIntegerValue j = 0; j < rowCount; j++)
849+
for (IRIntegerValue j = 0; j < colCount; j++)
768850
{
769851
offset = _getAnyValueSizeRaw(elementType, offset);
770852
if (offset < 0)
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -dx12 -use-dxil -profile cs_6_1 -output-using-type
2+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -output-using-type
3+
4+
interface IFoo
5+
{
6+
float getVal();
7+
uint64_t getPtrVal();
8+
}
9+
10+
// Concrete IFoo implementation that is read back through the interface-typed
// buffer. Its fields cover the cases this test targets: a column_major matrix,
// then a 32-bit int, then a 64-bit integer.
// NOTE(review): with six floats and one int ahead of it, `ptr` presumably needs
// one word of padding to reach 8-byte alignment, which matches the extra 0 in
// the gFoo input data — confirm.
struct Foo : IFoo
11+
{
12+
column_major float3x2 m;
13+
int x;
14+
uint64_t ptr;
15+
// Reads row 2, column 0 of the matrix.
float getVal()
16+
{
17+
return m[2][0];
18+
}
19+
// Returns the raw 64-bit field; the caller splits it into low and high halves.
uint64_t getPtrVal()
20+
{
21+
return ptr;
22+
}
23+
}
24+
25+
//TEST_INPUT: type_conformance Foo:IFoo = 0
26+
27+
//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0 0 0 1 2], stride=4)
28+
RWStructuredBuffer<IFoo> gFoo;
29+
30+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
31+
RWStructuredBuffer<float> outputBuffer;
32+
33+
[numthreads(1,1,1)]
34+
void computeMain()
35+
{
36+
// CHECK: 3.0
37+
outputBuffer[0] = gFoo[0].getVal();
38+
39+
// CHECK: 1.0
40+
outputBuffer[1] = gFoo[0].getPtrVal()&0xFFFFFFFF;
41+
42+
// CHECK: 2.0
43+
outputBuffer[2] = gFoo[0].getPtrVal()>>32;
44+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
//TEST:COMPARE_COMPUTE(filecheck-buffer=CHECK): -output-using-type
2+
3+
interface IFoo
4+
{
5+
float getVal();
6+
}
7+
8+
// Concrete IFoo implementation whose only field is a column_major matrix, so
// the value is read back through the interface-typed buffer element by element.
struct Foo : IFoo
9+
{
10+
column_major float3x2 m;
11+
// Reads row 2, column 0 of the matrix.
float getVal()
12+
{
13+
return m[2][0];
14+
}
15+
}
16+
17+
//TEST_INPUT: type_conformance Foo:IFoo = 0
18+
19+
//TEST_INPUT: set gFoo = ubuffer(data=[0 0 0 0 1.0 2.0 3.0 4.0 5.0 6.0], stride=4)
20+
RWStructuredBuffer<IFoo> gFoo;
21+
22+
//TEST_INPUT: set outputBuffer = out ubuffer(data=[0 0 0 0], stride=4)
23+
RWStructuredBuffer<float> outputBuffer;
24+
25+
[numthreads(1,1,1)]
26+
void computeMain()
27+
{
28+
// CHECK: 3.0
29+
outputBuffer[0] = gFoo[0].getVal();
30+
}

0 commit comments

Comments
 (0)