@@ -597,20 +597,101 @@ static void unexportNonEmbeddableIR(CodeGenTarget target, IRModule* irModule)
597
597
}
598
598
}
599
599
600
- static void validateMatrixDimensions (DiagnosticSink* sink, IRModule* module)
600
+ static void validateVectorOrMatrixElementType (
601
+ DiagnosticSink* sink,
602
+ SourceLoc sourceLoc,
603
+ IRType* elementType,
604
+ uint32_t allowedWidths,
605
+ const DiagnosticInfo& disallowedElementTypeEncountered)
606
+ {
607
+ if (!isFloatingType (elementType))
608
+ {
609
+ if (isIntegralType (elementType))
610
+ {
611
+ IntInfo info = getIntTypeInfo (elementType);
612
+ if (allowedWidths == 0U )
613
+ {
614
+ sink->diagnose (sourceLoc, disallowedElementTypeEncountered, elementType);
615
+ }
616
+ else
617
+ {
618
+ bool widthAllowed = false ;
619
+ SLANG_ASSERT ((allowedWidths & ~(0xfU << 3 )) == 0U );
620
+ for (uint32_t p = 3U ; p <= 6U ; p++)
621
+ {
622
+ uint32_t width = 1U << p;
623
+ if (!(allowedWidths & width))
624
+ continue ;
625
+ widthAllowed = widthAllowed || (info.width == width);
626
+ }
627
+ if (!widthAllowed)
628
+ {
629
+ sink->diagnose (sourceLoc, disallowedElementTypeEncountered, elementType);
630
+ }
631
+ }
632
+ }
633
+ else if (!as<IRBoolType>(elementType))
634
+ {
635
+ sink->diagnose (sourceLoc, disallowedElementTypeEncountered, elementType);
636
+ }
637
+ }
638
+ }
639
+
640
+ static void validateVectorsAndMatrices (
641
+ DiagnosticSink* sink,
642
+ IRModule* module,
643
+ TargetRequest* targetRequest)
601
644
{
602
645
for (auto globalInst : module->getGlobalInsts ())
603
646
{
604
647
if (auto matrixType = as<IRMatrixType>(globalInst))
605
648
{
606
- auto colCount = as<IRIntLit>(matrixType->getColumnCount ());
607
- auto rowCount = as<IRIntLit>(matrixType->getRowCount ());
608
-
609
- if ((rowCount && (rowCount->getValue () == 1 )) ||
610
- (colCount && (colCount->getValue () == 1 )))
649
+ // Matrices with row/col dimension 1 are only well-supported on D3D targets
650
+ if (!isD3DTarget (targetRequest))
611
651
{
612
- sink->diagnose (matrixType->sourceLoc , Diagnostics::matrixColumnOrRowCountIsOne);
652
+ // Verify that neither row nor col count is 1
653
+ auto colCount = as<IRIntLit>(matrixType->getColumnCount ());
654
+ auto rowCount = as<IRIntLit>(matrixType->getRowCount ());
655
+
656
+ if ((rowCount && (rowCount->getValue () == 1 )) ||
657
+ (colCount && (colCount->getValue () == 1 )))
658
+ {
659
+ sink->diagnose (matrixType->sourceLoc , Diagnostics::matrixColumnOrRowCountIsOne);
660
+ }
613
661
}
662
+
663
+ // Verify that the element type is a floating point type, or an allowed integral type
664
+ auto elementType = matrixType->getElementType ();
665
+ uint32_t allowedWidths = 0U ;
666
+ if (isCPUTarget (targetRequest))
667
+ allowedWidths = 8U | 16U | 32U | 64U ;
668
+ else if (isCUDATarget (targetRequest))
669
+ allowedWidths = 32U | 64U ;
670
+ else if (isD3DTarget (targetRequest))
671
+ allowedWidths = 16U | 32U ;
672
+ validateVectorOrMatrixElementType (
673
+ sink,
674
+ matrixType->sourceLoc ,
675
+ elementType,
676
+ allowedWidths,
677
+ Diagnostics::matrixWithDisallowedElementTypeEncountered);
678
+ }
679
+ else if (auto vectorType = as<IRVectorType>(globalInst))
680
+ {
681
+ // Verify that the element type is a floating point type, or an allowed integral type
682
+ auto elementType = vectorType->getElementType ();
683
+ uint32_t allowedWidths = 0U ;
684
+ if (isWGPUTarget (targetRequest))
685
+ allowedWidths = 32U ;
686
+ else
687
+ allowedWidths = 8U | 16U | 32U | 64U ;
688
+
689
+ validateVectorOrMatrixElementType (
690
+ sink,
691
+ vectorType->sourceLoc ,
692
+ elementType,
693
+ allowedWidths,
694
+ Diagnostics::vectorWithDisallowedElementTypeEncountered);
614
695
}
615
696
}
616
697
}
@@ -1602,9 +1683,8 @@ Result linkAndOptimizeIR(
1602
1683
#endif
1603
1684
validateIRModuleIfEnabled (codeGenContext, irModule);
1604
1685
1605
- // Make sure there are no matrices with 1 row/column, except for D3D targets where it's allowed.
1606
- if (!isD3DTarget (targetRequest))
1607
- validateMatrixDimensions (sink, irModule);
1686
+ // Validate vectors and matrices according to what the target allows
1687
+ validateVectorsAndMatrices (sink, irModule, targetRequest);
1608
1688
1609
1689
// The resource-based specialization pass above
1610
1690
// may create specialized versions of functions, but
0 commit comments