@@ -60,6 +60,15 @@ namespace Slang
60
60
case kIROp_AssociatedType :
61
61
case kIROp_InterfaceType :
62
62
return true ;
63
+ case kIROp_Specialize :
64
+ {
65
+ for (UInt i = 0 ; i < typeInst->getOperandCount (); i++)
66
+ {
67
+ if (isPolymorphicType (typeInst->getOperand (i)))
68
+ return true ;
69
+ }
70
+ return false ;
71
+ }
63
72
default :
64
73
break ;
65
74
}
@@ -124,6 +133,41 @@ namespace Slang
124
133
}
125
134
}
126
135
cloneInstDecorationsAndChildren (&cloneEnv, &sharedBuilderStorage, func, loweredFunc);
136
+
137
+ // If the function returns a generic typed value, we need to turn it
138
+ // into an `out` parameter, since only the caller can allocate space
139
+ // for it.
140
+ auto oldFuncType = cast<IRFuncType>(func->getDataType ());
141
+ if (isPolymorphicType (oldFuncType->getResultType ()))
142
+ {
143
+ builder.setInsertInto (loweredFunc->getFirstBlock ());
144
+ // We defer creation of the returnVal parameter until we see the first
145
+ // `return` instruction, because we can only obtain the cloned return type
146
+ // of this function by checking the type of the cloned return inst.
147
+ IRParam* retValParam = nullptr ;
148
+ // Translate all return insts to `store`s.
149
+ // Those `store`s will be processed and translated into `copy`s when we
150
+ // get to process them via workList.
151
+ for (auto bb : loweredFunc->getBlocks ())
152
+ {
153
+ auto retInst = as<IRReturnVal>(bb->getTerminator ());
154
+ if (!retInst)
155
+ continue ;
156
+ if (!retValParam)
157
+ {
158
+ // Now we have the return type, emit the returnVal parameter.
159
+ // The type of this parameter is still not translated to RawPointer yet,
160
+ // and will be processed together with all the other existing parameters.
161
+ retValParam = builder.emitParamAtHead (
162
+ builder.getOutType (retInst->getVal ()->getDataType ()));
163
+ }
164
+ builder.setInsertBefore (retInst);
165
+ builder.emitStore (retValParam, retInst->getVal ());
166
+ builder.emitReturn ();
167
+ retInst->removeAndDeallocate ();
168
+ }
169
+ }
170
+
127
171
auto block = as<IRBlock>(loweredFunc->getFirstChild ());
128
172
for (auto param : clonedParams)
129
173
{
@@ -139,7 +183,10 @@ namespace Slang
139
183
param = param->getNextInst ())
140
184
{
141
185
// Generic typed parameters have a type that is a param itself.
142
- if (auto rttiParam = as<IRParam>(param->getDataType ()))
186
+ auto paramType = param->getDataType ();
187
+ if (auto ptrType = as<IRPtrTypeBase>(paramType))
188
+ paramType = ptrType->getValueType ();
189
+ if (auto rttiParam = as<IRParam>(paramType))
143
190
{
144
191
SLANG_ASSERT (isPointerOfType (rttiParam->getDataType (), kIROp_RTTIType ));
145
192
// Lower into a function parameter of raw pointer type.
@@ -189,6 +236,14 @@ namespace Slang
189
236
{
190
237
auto loweredParamType = lowerParameterType (builder, paramType);
191
238
translated = translated || (loweredParamType != paramType);
239
+ if (translated && i == 0 )
240
+ {
241
+ // We are translating the return value, this means that
242
+ // the return value must be passed explicitly via an `out` parameter.
243
+ // In this case, the new return value will be `void`, and the
244
+ // translated return value type will be the first parameter type;
245
+ newOperands.add (builder->getVoidType ());
246
+ }
192
247
newOperands.add (loweredParamType);
193
248
}
194
249
}
@@ -382,110 +437,146 @@ namespace Slang
382
437
return result;
383
438
}
384
439
385
- void processInst (IRInst* inst )
440
+ void lowerCall (IRCall* callInst )
386
441
{
387
- if (auto callInst = as<IRCall>(inst))
442
+ // If we see a call(specialize(gFunc, Targs), args),
443
+ // translate it into call(gFunc, args, Targs).
444
+ auto funcOperand = callInst->getOperand (0 );
445
+ IRInst* loweredFunc = nullptr ;
446
+ auto specializeInst = as<IRSpecialize>(funcOperand);
447
+ if (!specializeInst)
448
+ return ;
449
+
450
+ auto funcToSpecialize = specializeInst->getOperand (0 );
451
+ List<IRType*> paramTypes;
452
+ IRFuncType* funcType = nullptr ;
453
+ if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
388
454
{
389
- // If we see a call(specialize(gFunc, Targs), args),
390
- // translate it into call(gFunc, args, Targs).
391
- auto funcOperand = callInst->getOperand (0 );
392
- IRInst* loweredFunc = nullptr ;
393
- if (auto specializeInst = as<IRSpecialize>(funcOperand))
455
+ // The callee is a result of witness table lookup, we will only
456
+ // translate the call.
457
+ IRInst* callee = nullptr ;
458
+ auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable ()->getDataType ());
459
+ auto interfaceType = maybeLowerInterfaceType (cast<IRInterfaceType>(witnessTableType->getConformanceType ()));
460
+ for (UInt i = 0 ; i < interfaceType->getOperandCount (); i++)
394
461
{
395
- auto funcToSpecialize = specializeInst->getOperand (0 );
396
- List<IRType*> paramTypes;
397
- if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
398
- {
399
- // The callee is a result of witness table lookup, we will only
400
- // translate the call.
401
- IRInst* callee = nullptr ;
402
- auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable ()->getDataType ());
403
- auto interfaceType = maybeLowerInterfaceType (cast<IRInterfaceType>(witnessTableType->getConformanceType ()));
404
- for (UInt i = 0 ; i < interfaceType->getOperandCount (); i++)
405
- {
406
- auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand (i));
407
- if (entry->getRequirementKey () == interfaceLookup->getOperand (1 ))
408
- {
409
- callee = entry->getRequirementVal ();
410
- break ;
411
- }
412
- }
413
- auto funcType = cast<IRFuncType>(callee);
414
- for (UInt i = 0 ; i < funcType->getParamCount (); i++)
415
- paramTypes.add (funcType->getParamType (i));
416
- loweredFunc = funcToSpecialize;
417
- }
418
- else
462
+ auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand (i));
463
+ if (entry->getRequirementKey () == interfaceLookup->getOperand (1 ))
419
464
{
420
- loweredFunc = lowerGenericFunction (specializeInst->getOperand (0 ));
421
- if (loweredFunc == specializeInst->getOperand (0 ))
422
- {
423
- // This is an intrinsic function, don't transform.
424
- return ;
425
- }
426
- for (auto param : as<IRFunc>(loweredFunc)->getParams ())
427
- paramTypes.add (param->getDataType ());
465
+ callee = entry->getRequirementVal ();
466
+ break ;
428
467
}
468
+ }
469
+ funcType = cast<IRFuncType>(callee);
470
+ loweredFunc = funcToSpecialize;
471
+ }
472
+ else
473
+ {
474
+ loweredFunc = lowerGenericFunction (specializeInst->getOperand (0 ));
475
+ if (loweredFunc == specializeInst->getOperand (0 ))
476
+ {
477
+ // This is an intrinsic function, don't transform.
478
+ return ;
479
+ }
480
+ funcType = cast<IRFuncType>(loweredFunc->getDataType ());
481
+ }
429
482
430
- IRBuilder builderStorage;
431
- auto builder = &builderStorage;
432
- builder->sharedBuilder = &sharedBuilderStorage;
433
- builder->setInsertBefore (inst);
434
- List<IRInst*> args;
435
- auto rawPtrType = builder->getRawPointerType ();
436
- for (UInt i = 0 ; i < callInst->getArgCount (); i++)
437
- {
438
- auto arg = callInst->getArg (i);
439
- if (as<IRRawPointerType>(paramTypes[i]) &&
440
- !as<IRRawPointerType>(arg->getDataType ()))
441
- {
442
- // We are calling a generic function that with an argument of
443
- // concrete type. We need to convert this argument to void*.
444
-
445
- // Ideally this should just be a GetElementAddress inst.
446
- // However the current code emitting logic for this instruction
447
- // doesn't truly respect the pointerness and does not produce
448
- // what we needed. For now we use another instruction here
449
- // to keep changes minimal.
450
- arg = builder->emitGetAddress (
451
- rawPtrType,
452
- arg);
453
- }
454
- args.add (arg);
455
- }
456
- for (UInt i = 0 ; i < specializeInst->getArgCount (); i++)
457
- {
458
- auto arg = specializeInst->getArg (i);
459
- // Translate Type arguments into RTTI object.
460
- if (as<IRType>(arg))
461
- {
462
- // We are using a simple type to specialize a callee.
463
- // Generate RTTI for this type.
464
- auto rttiObject = maybeEmitRTTIObject (arg);
465
- arg = builder->emitGetAddress (
466
- builder->getPtrType (builder->getRTTIType ()),
467
- rttiObject);
468
- }
469
- else if (arg->op == kIROp_Specialize )
470
- {
471
- // The type argument used to specialize a callee is itself a
472
- // specialization of some generic type.
473
- // TODO: generate RTTI object for specializations of generic types.
474
- SLANG_UNIMPLEMENTED_X (" RTTI object generation for generic types" );
475
- }
476
- else if (arg->op == kIROp_RTTIObject )
477
- {
478
- // We are inside a generic function and using a generic parameter
479
- // to specialize another callee. The generic parameter of the caller
480
- // has already been translated into an RTTI object, so we just need
481
- // to pass this object down.
482
- }
483
- args.add (arg);
484
- }
485
- auto newCall = builder->emitCallInst (callInst->getFullType (), loweredFunc, args);
486
- callInst->replaceUsesWith (newCall);
487
- callInst->removeAndDeallocate ();
483
+ for (UInt i = 0 ; i < funcType->getParamCount (); i++)
484
+ paramTypes.add (funcType->getParamType (i));
485
+
486
+ IRBuilder builderStorage;
487
+ auto builder = &builderStorage;
488
+ builder->sharedBuilder = &sharedBuilderStorage;
489
+ builder->setInsertBefore (callInst);
490
+
491
+ List<IRInst*> args;
492
+
493
+ // Indicates whether the caller should allocate space for return value.
494
+ // If the lowered callee returns void and this call inst has a type that is not void,
495
+ // then we are calling a transformed function that expects caller allocated return value
496
+ // as the first argument.
497
+ bool shouldCallerAllocateReturnValue = (funcType->getResultType ()->op == kIROp_VoidType &&
498
+ callInst->getDataType () != funcType->getResultType ());
499
+
500
+ IRVar* retVarInst = nullptr ;
501
+ int startParamIndex = 0 ;
502
+ if (shouldCallerAllocateReturnValue)
503
+ {
504
+ // Declare a var for the return value.
505
+ retVarInst = builder->emitVar (callInst->getFullType ());
506
+ args.add (retVarInst);
507
+ startParamIndex = 1 ;
508
+ }
509
+
510
+ for (UInt i = 0 ; i < callInst->getArgCount (); i++)
511
+ {
512
+ auto arg = callInst->getArg (i);
513
+ if (as<IRRawPointerType>(paramTypes[i] + startParamIndex) &&
514
+ !as<IRRawPointerType>(arg->getDataType ()))
515
+ {
516
+ // We are calling a generic function that with an argument of
517
+ // concrete type. We need to convert this argument to void*.
518
+
519
+ // Ideally this should just be a GetElementAddress inst.
520
+ // However the current code emitting logic for this instruction
521
+ // doesn't truly respect the pointerness and does not produce
522
+ // what we needed. For now we use another instruction here
523
+ // to keep changes minimal.
524
+ arg = builder->emitGetAddress (
525
+ builder->getRawPointerType (),
526
+ arg);
527
+ }
528
+ args.add (arg);
529
+ }
530
+ for (UInt i = 0 ; i < specializeInst->getArgCount (); i++)
531
+ {
532
+ auto arg = specializeInst->getArg (i);
533
+ // Translate Type arguments into RTTI object.
534
+ if (as<IRType>(arg))
535
+ {
536
+ // We are using a simple type to specialize a callee.
537
+ // Generate RTTI for this type.
538
+ auto rttiObject = maybeEmitRTTIObject (arg);
539
+ arg = builder->emitGetAddress (
540
+ builder->getPtrType (builder->getRTTIType ()),
541
+ rttiObject);
542
+ }
543
+ else if (arg->op == kIROp_Specialize )
544
+ {
545
+ // The type argument used to specialize a callee is itself a
546
+ // specialization of some generic type.
547
+ // TODO: generate RTTI object for specializations of generic types.
548
+ SLANG_UNIMPLEMENTED_X (" RTTI object generation for generic types" );
488
549
}
550
+ else if (arg->op == kIROp_RTTIObject )
551
+ {
552
+ // We are inside a generic function and using a generic parameter
553
+ // to specialize another callee. The generic parameter of the caller
554
+ // has already been translated into an RTTI object, so we just need
555
+ // to pass this object down.
556
+ }
557
+ args.add (arg);
558
+ }
559
+ auto callInstType = retVarInst ? builder->getVoidType () : callInst->getFullType ();
560
+ auto newCall = builder->emitCallInst (callInstType, loweredFunc, args);
561
+ if (retVarInst)
562
+ {
563
+ auto loadInst = builder->emitLoad (retVarInst);
564
+ callInst->replaceUsesWith (loadInst);
565
+ addToWorkList (loadInst);
566
+ addToWorkList (retVarInst);
567
+ }
568
+ else
569
+ {
570
+ callInst->replaceUsesWith (newCall);
571
+ }
572
+ callInst->removeAndDeallocate ();
573
+ }
574
+
575
+ void processInst (IRInst* inst)
576
+ {
577
+ if (auto callInst = as<IRCall>(inst))
578
+ {
579
+ lowerCall (callInst);
489
580
}
490
581
else if (auto witnessTable = as<IRWitnessTable>(inst))
491
582
{
0 commit comments