@@ -112,16 +112,17 @@ struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl
112
112
return arrayInfo;
113
113
}
114
114
115
- SimpleLayoutInfo GetVectorLayout (SimpleLayoutInfo elementInfo, size_t elementCount) override
115
+ SimpleLayoutInfo GetVectorLayout (BaseType elementType, SimpleLayoutInfo elementInfo, size_t elementCount) override
116
116
{
117
+ SLANG_UNUSED (elementType);
117
118
SimpleLayoutInfo vectorInfo;
118
119
vectorInfo.kind = elementInfo.kind ;
119
120
vectorInfo.size = elementInfo.size * elementCount;
120
121
vectorInfo.alignment = elementInfo.alignment ;
121
122
return vectorInfo;
122
123
}
123
124
124
- SimpleArrayLayoutInfo GetMatrixLayout (SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override
125
+ SimpleArrayLayoutInfo GetMatrixLayout (BaseType elementType, SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override
125
126
{
126
127
// The default behavior here is to lay out a matrix
127
128
// as an array of row vectors (that is row-major).
@@ -131,7 +132,7 @@ struct DefaultLayoutRulesImpl : SimpleLayoutRulesImpl
131
132
// to get layouts with a different convention.
132
133
//
133
134
return GetArrayLayout (
134
- GetVectorLayout (elementInfo, columnCount),
135
+ GetVectorLayout (elementType, elementInfo, columnCount),
135
136
rowCount);
136
137
}
137
138
@@ -204,8 +205,9 @@ struct GLSLBaseLayoutRulesImpl : DefaultLayoutRulesImpl
204
205
{
205
206
typedef DefaultLayoutRulesImpl Super;
206
207
207
- SimpleLayoutInfo GetVectorLayout (SimpleLayoutInfo elementInfo, size_t elementCount) override
208
+ SimpleLayoutInfo GetVectorLayout (BaseType elementType, SimpleLayoutInfo elementInfo, size_t elementCount) override
208
209
{
210
+ SLANG_UNUSED (elementType);
209
211
// The `std140` and `std430` rules require vectors to be aligned to the next power of
210
212
// two up from their size (so a `float2` is 8-byte aligned, and a `float3` is
211
213
// 16-byte aligned).
@@ -224,7 +226,7 @@ struct GLSLBaseLayoutRulesImpl : DefaultLayoutRulesImpl
224
226
return vectorInfo;
225
227
}
226
228
227
- SimpleArrayLayoutInfo GetArrayLayout ( SimpleLayoutInfo elementInfo, LayoutSize elementCount) override
229
+ SimpleArrayLayoutInfo GetArrayLayout (SimpleLayoutInfo elementInfo, LayoutSize elementCount) override
228
230
{
229
231
// The size of an array must be rounded up to be a multiple of its alignment.
230
232
//
@@ -376,7 +378,7 @@ struct CPULayoutRulesImpl : DefaultLayoutRulesImpl
376
378
377
379
// So it is actually a Array<T> on CPU which is a pointer and a size
378
380
info.size = sizeof (void *) * 2 ;
379
- info.alignment = sizeof (void *);
381
+ info.alignment = SLANG_ALIGN_OF (void *);
380
382
381
383
return info;
382
384
}
@@ -398,12 +400,115 @@ struct CPULayoutRulesImpl : DefaultLayoutRulesImpl
398
400
}
399
401
};
400
402
401
- // TODO(JS): Most likely wrong. For layout for CUDA, we'll just do the default to get things up and running
402
403
struct CUDALayoutRulesImpl : DefaultLayoutRulesImpl
403
404
{
404
405
typedef DefaultLayoutRulesImpl Super;
405
- };
406
406
407
+ SimpleLayoutInfo GetScalarLayout (BaseType baseType) override
408
+ {
409
+ switch (baseType)
410
+ {
411
+ case BaseType::Bool:
412
+ {
413
+ // In memory a bool is a byte. BUT when in a vector or matrix it will actually be a int32_t
414
+ return SimpleLayoutInfo (LayoutResourceKind::Uniform, sizeof (uint8_t ), SLANG_ALIGN_OF (uint8_t ));
415
+ }
416
+
417
+ default : return Super::GetScalarLayout (baseType);
418
+ }
419
+ }
420
+
421
+ SimpleArrayLayoutInfo GetArrayLayout (SimpleLayoutInfo elementInfo, LayoutSize elementCount) override
422
+ {
423
+ SLANG_RELEASE_ASSERT (elementInfo.size .isFinite ());
424
+ auto elementSize = elementInfo.size .getFiniteValue ();
425
+ auto elementAlignment = elementInfo.alignment ;
426
+ auto elementStride = RoundToAlignment (elementSize, elementAlignment);
427
+
428
+ if (elementCount.isInfinite ())
429
+ {
430
+ // This is an unsized array, get information for element
431
+ auto info = Super::GetArrayLayout (elementInfo, LayoutSize (1 ));
432
+
433
+ // So it is actually a Array<T> on CUDA which is a pointer and a size
434
+ info.size = sizeof (void *) * 2 ;
435
+ info.alignment = SLANG_ALIGN_OF (void *);
436
+ return info;
437
+ }
438
+
439
+ // An array with no elements will have zero size.
440
+ //
441
+ LayoutSize arraySize = 0 ;
442
+ //
443
+ // Any array with a non-zero number of elements will need
444
+ // to have space for N elements of size `elementSize`, with
445
+ // the constraints that there must be `elementStride` bytes
446
+ // between consecutive elements.
447
+ //
448
+ if (elementCount > 0 )
449
+ {
450
+ // We can think of this as either allocating (N-1)
451
+ // chunks of size `elementStride` (for most of the elements)
452
+ // and then one final chunk of size `elementSize` for
453
+ // the last element, or equivalently as allocating
454
+ // N chunks of size `elementStride` and then "giving back"
455
+ // the final `elementStride - elementSize` bytes.
456
+ //
457
+ arraySize = (elementStride * (elementCount - 1 )) + elementSize;
458
+ }
459
+
460
+ SimpleArrayLayoutInfo arrayInfo;
461
+ arrayInfo.kind = elementInfo.kind ;
462
+ arrayInfo.size = arraySize;
463
+ arrayInfo.alignment = elementAlignment;
464
+ arrayInfo.elementStride = elementStride;
465
+ return arrayInfo;
466
+ }
467
+
468
+ SimpleLayoutInfo GetVectorLayout (BaseType elementType, SimpleLayoutInfo elementInfo, size_t elementCount) override
469
+ {
470
+ // Special case bool
471
+ if (elementType == BaseType::Bool)
472
+ {
473
+ SimpleLayoutInfo fixInfo (elementInfo);
474
+ fixInfo.size = sizeof (int32_t );
475
+ fixInfo.alignment = SLANG_ALIGN_OF (int32_t );
476
+ return GetVectorLayout (BaseType::Int, fixInfo, elementCount);
477
+ }
478
+
479
+ SimpleLayoutInfo vectorInfo;
480
+ vectorInfo.kind = elementInfo.kind ;
481
+ vectorInfo.size = elementInfo.size * elementCount;
482
+ vectorInfo.alignment = elementInfo.alignment ;
483
+
484
+ return vectorInfo;
485
+ }
486
+
487
+ SimpleArrayLayoutInfo GetMatrixLayout (BaseType elementType, SimpleLayoutInfo elementInfo, size_t rowCount, size_t columnCount) override
488
+ {
489
+ // Special case bool
490
+ if (elementType == BaseType::Bool)
491
+ {
492
+ SimpleLayoutInfo fixInfo (elementInfo);
493
+ fixInfo.size = sizeof (int32_t );
494
+ fixInfo.alignment = SLANG_ALIGN_OF (int32_t );
495
+ return GetMatrixLayout (BaseType::Int, fixInfo, rowCount, columnCount);
496
+ }
497
+
498
+ return Super::GetMatrixLayout (elementType, elementInfo, rowCount, columnCount);
499
+ }
500
+
501
+ UniformLayoutInfo BeginStructLayout () override
502
+ {
503
+ return Super::BeginStructLayout ();
504
+ }
505
+
506
+ void EndStructLayout (UniformLayoutInfo* ioStructInfo) override
507
+ {
508
+ // Conform to CUDA/C/C++ size is adjusted to the largest alignment
509
+ ioStructInfo->size = RoundToAlignment (ioStructInfo->size , ioStructInfo->alignment );
510
+ }
511
+ };
407
512
408
513
struct HLSLStructuredBufferLayoutRulesImpl : DefaultLayoutRulesImpl
409
514
{
@@ -436,8 +541,9 @@ struct DefaultVaryingLayoutRulesImpl : DefaultLayoutRulesImpl
436
541
1 );
437
542
}
438
543
439
- SimpleLayoutInfo GetVectorLayout (SimpleLayoutInfo, size_t ) override
544
+ SimpleLayoutInfo GetVectorLayout (BaseType elementType, SimpleLayoutInfo, size_t ) override
440
545
{
546
+ SLANG_UNUSED (elementType);
441
547
// Vectors take up one slot by default
442
548
//
443
549
// TODO: some platforms may decide that vectors of `double` need
@@ -479,8 +585,9 @@ struct GLSLSpecializationConstantLayoutRulesImpl : DefaultLayoutRulesImpl
479
585
1 );
480
586
}
481
587
482
- SimpleLayoutInfo GetVectorLayout (SimpleLayoutInfo, size_t elementCount) override
588
+ SimpleLayoutInfo GetVectorLayout (BaseType elementType, SimpleLayoutInfo, size_t elementCount) override
483
589
{
590
+ SLANG_UNUSED (elementType);
484
591
// GLSL doesn't support vectors of specialization constants,
485
592
// but we will assume that, if supported, they would use one slot per element.
486
593
return SimpleLayoutInfo (
@@ -3052,7 +3159,13 @@ static TypeLayoutResult _createTypeLayout(
3052
3159
context,
3053
3160
elementType);
3054
3161
3055
- auto info = rules->GetVectorLayout (element.info , elementCount);
3162
+ BaseType elementBaseType = BaseType::Void;
3163
+ if (auto elementBasicType = as<BasicExpressionType>(elementType))
3164
+ {
3165
+ elementBaseType = elementBasicType->baseType ;
3166
+ }
3167
+
3168
+ auto info = rules->GetVectorLayout (elementBaseType, element.info , elementCount);
3056
3169
3057
3170
RefPtr<VectorTypeLayout> typeLayout = new VectorTypeLayout ();
3058
3171
typeLayout->type = type;
@@ -3078,6 +3191,12 @@ static TypeLayoutResult _createTypeLayout(
3078
3191
auto elementTypeLayout = elementResult.layout ;
3079
3192
auto elementInfo = elementResult.info ;
3080
3193
3194
+ BaseType elementBaseType = BaseType::Void;
3195
+ if (auto elementBasicType = as<BasicExpressionType>(elementType))
3196
+ {
3197
+ elementBaseType = elementBasicType->baseType ;
3198
+ }
3199
+
3081
3200
// The `GetMatrixLayout` implementation in the layout rules
3082
3201
// currently defaults to assuming row-major layout,
3083
3202
// so if we want column-major layout we achieve it here by
@@ -3092,6 +3211,7 @@ static TypeLayoutResult _createTypeLayout(
3092
3211
layoutMinorCount = tmp;
3093
3212
}
3094
3213
auto info = rules->GetMatrixLayout (
3214
+ elementBaseType,
3095
3215
elementInfo,
3096
3216
layoutMajorCount,
3097
3217
layoutMinorCount);
@@ -3100,6 +3220,7 @@ static TypeLayoutResult _createTypeLayout(
3100
3220
RefPtr<VectorTypeLayout> rowTypeLayout = new VectorTypeLayout ();
3101
3221
3102
3222
auto rowInfo = rules->GetVectorLayout (
3223
+ elementBaseType,
3103
3224
elementInfo,
3104
3225
colCount);
3105
3226
@@ -3680,7 +3801,7 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout(
3680
3801
{
3681
3802
auto varyingRuleSet = varyingRules[rr];
3682
3803
auto elementInfo = varyingRuleSet->GetScalarLayout (elementBaseType);
3683
- auto info = varyingRuleSet->GetVectorLayout (elementInfo, elementCount);
3804
+ auto info = varyingRuleSet->GetVectorLayout (elementBaseType, elementInfo, elementCount);
3684
3805
typeLayout->addResourceUsage (info.kind , info.size );
3685
3806
}
3686
3807
@@ -3735,14 +3856,14 @@ RefPtr<TypeLayout> getSimpleVaryingParameterTypeLayout(
3735
3856
auto varyingRuleSet = varyingRules[rr];
3736
3857
auto elementInfo = varyingRuleSet->GetScalarLayout (elementBaseType);
3737
3858
3738
- auto info = varyingRuleSet->GetMatrixLayout (elementInfo, layoutMajorCount, layoutMinorCount);
3859
+ auto info = varyingRuleSet->GetMatrixLayout (elementBaseType, elementInfo, layoutMajorCount, layoutMinorCount);
3739
3860
typeLayout->addResourceUsage (info.kind , info.size );
3740
3861
3741
3862
if (context.matrixLayoutMode == kMatrixLayoutMode_RowMajor )
3742
3863
{
3743
3864
// For row-major matrices only, we can compute an effective
3744
3865
// resource usage for the row type.
3745
- auto rowInfo = varyingRuleSet->GetVectorLayout (elementInfo, colCount);
3866
+ auto rowInfo = varyingRuleSet->GetVectorLayout (elementBaseType, elementInfo, colCount);
3746
3867
rowTypeLayout->addResourceUsage (rowInfo.kind , rowInfo.size );
3747
3868
}
3748
3869
}
0 commit comments