@@ -349,6 +349,138 @@ void CUDASourceEmitter::emitLoopControlDecorationImpl(IRLoopControlDecoration* d
349
349
}
350
350
}
351
351
352
+ static bool _areEquivalent (IRType* a, IRType* b)
353
+ {
354
+ if (a == b)
355
+ {
356
+ return true ;
357
+ }
358
+ if (a->op != b->op )
359
+ {
360
+ return false ;
361
+ }
362
+
363
+ switch (a->op )
364
+ {
365
+ case kIROp_VectorType :
366
+ {
367
+ IRVectorType* vecA = static_cast <IRVectorType*>(a);
368
+ IRVectorType* vecB = static_cast <IRVectorType*>(b);
369
+ return GetIntVal (vecA->getElementCount ()) == GetIntVal (vecB->getElementCount ()) &&
370
+ _areEquivalent (vecA->getElementType (), vecB->getElementType ());
371
+ }
372
+ case kIROp_MatrixType :
373
+ {
374
+ IRMatrixType* matA = static_cast <IRMatrixType*>(a);
375
+ IRMatrixType* matB = static_cast <IRMatrixType*>(b);
376
+ return GetIntVal (matA->getColumnCount ()) == GetIntVal (matB->getColumnCount ()) &&
377
+ GetIntVal (matA->getRowCount ()) == GetIntVal (matB->getRowCount ()) &&
378
+ _areEquivalent (matA->getElementType (), matB->getElementType ());
379
+ }
380
+ default :
381
+ {
382
+ return as<IRBasicType>(a) != nullptr ;
383
+ }
384
+ }
385
+ }
386
+
387
+ void CUDASourceEmitter::_emitInitializerListValue (IRType* dstType, IRInst* value)
388
+ {
389
+ // When constructing a matrix or vector from a single value this is handled by the default path
390
+
391
+ switch (value->op )
392
+ {
393
+ case kIROp_Construct :
394
+ case kIROp_MakeMatrix :
395
+ case kIROp_makeVector :
396
+ {
397
+ IRType* type = value->getDataType ();
398
+
399
+ // If the types are the same, we can can just break down and use
400
+ if (_areEquivalent (dstType, type))
401
+ {
402
+ if (auto vecType = as<IRVectorType>(type))
403
+ {
404
+ if (UInt (GetIntVal (vecType->getElementCount ())) == value->getOperandCount ())
405
+ {
406
+ _emitInitializerList (vecType->getElementType (), value->getOperands (), value->getOperandCount ());
407
+ return ;
408
+ }
409
+ }
410
+ else if (auto matType = as<IRMatrixType>(type))
411
+ {
412
+ const Index colCount = Index (GetIntVal (matType->getColumnCount ()));
413
+ const Index rowCount = Index (GetIntVal (matType->getRowCount ()));
414
+
415
+ // TODO(JS): If num cols = 1, then it *doesn't* actually return a vector.
416
+ // That could be argued is an error because we want swizzling or [] to work.
417
+ IRType* rowType = m_typeSet.addVectorType (matType->getElementType (), int (colCount));
418
+ IRVectorType* rowVectorType = as<IRVectorType>(rowType);
419
+ const Index operandCount = Index (value->getOperandCount ());
420
+
421
+ // Can init, with vectors.
422
+ // For now special case if the rowVectorType is not actually a vector (when elementSize == 1)
423
+ if (operandCount == rowCount || rowVectorType == nullptr )
424
+ {
425
+ // We have to output vectors
426
+
427
+ // Emit the braces for the Matrix struct, contains an row array.
428
+ m_writer->emit (" {\n " );
429
+ m_writer->indent ();
430
+ _emitInitializerList (rowType, value->getOperands (), rowCount);
431
+ m_writer->dedent ();
432
+ m_writer->emit (" \n }" );
433
+ return ;
434
+ }
435
+ else if (operandCount == rowCount * colCount)
436
+ {
437
+ // Handle if all are explicitly defined
438
+ IRType* elementType = matType->getElementType ();
439
+ IRUse* operands = value->getOperands ();
440
+
441
+ // Emit the braces for the Matrix struct, and the array of rows
442
+ m_writer->emit (" {\n " );
443
+ m_writer->indent ();
444
+ m_writer->emit (" {\n " );
445
+ m_writer->indent ();
446
+ for (Index i = 0 ; i < rowCount; ++i)
447
+ {
448
+ if (i != 0 ) m_writer->emit (" , " );
449
+ _emitInitializerList (elementType, operands, colCount);
450
+ operands += colCount;
451
+ }
452
+ m_writer->dedent ();
453
+ m_writer->emit (" \n }" );
454
+ m_writer->dedent ();
455
+ m_writer->emit (" \n }" );
456
+ return ;
457
+ }
458
+ }
459
+ }
460
+
461
+ break ;
462
+ }
463
+ }
464
+
465
+ // All other cases we just use the default emitting - might not work on arrays defined in global scope on CUDA though
466
+ emitOperand (value, getInfo (EmitOp::General));
467
+ }
468
+
469
+ void CUDASourceEmitter::_emitInitializerList (IRType* elementType, IRUse* operands, Index operandCount)
470
+ {
471
+ m_writer->emit (" {\n " );
472
+ m_writer->indent ();
473
+
474
+ for (Index i = 0 ; i < operandCount; ++i)
475
+ {
476
+ if (i != 0 ) m_writer->emit (" , " );
477
+ _emitInitializerListValue (elementType, operands[i].get ());
478
+ }
479
+
480
+ m_writer->dedent ();
481
+ m_writer->emit (" \n }" );
482
+ }
483
+
352
484
bool CUDASourceEmitter::tryEmitInstExprImpl (IRInst* inst, const EmitOpInfo& inOuterPrec)
353
485
{
354
486
switch (inst->op )
@@ -369,6 +501,23 @@ bool CUDASourceEmitter::tryEmitInstExprImpl(IRInst* inst, const EmitOpInfo& inOu
369
501
}
370
502
break ;
371
503
}
504
+ case kIROp_makeArray :
505
+ {
506
+ IRType* dataType = inst->getDataType ();
507
+ IRArrayType* arrayType = as<IRArrayType>(dataType);
508
+
509
+ IRType* elementType = arrayType->getElementType ();
510
+
511
+ // Emit braces for the FixedArray struct.
512
+ m_writer->emit (" {\n " );
513
+ m_writer->indent ();
514
+
515
+ _emitInitializerList (elementType, inst->getOperands (), Index (inst->getOperandCount ()));
516
+
517
+ m_writer->dedent ();
518
+ m_writer->emit (" \n }" );
519
+ return true ;
520
+ }
372
521
default : break ;
373
522
}
374
523
0 commit comments