Skip to content

Commit 48f26ef

Browse files
authored
Dynamic code gen for functions returning generic types. (shader-slang#1439)
* Dynamic code gen for functions returning generic types. * Add expected test result.
1 parent 249f48d commit 48f26ef

6 files changed

+281
-98
lines changed

source/slang/slang-ir-insts.h

+2
Original file line numberDiff line numberDiff line change
@@ -1956,6 +1956,8 @@ struct IRBuilder
19561956
IRType* type);
19571957
IRParam* emitParam(
19581958
IRType* type);
1959+
IRParam* emitParamAtHead(
1960+
IRType* type);
19591961

19601962
IRVar* emitVar(
19611963
IRType* type);

source/slang/slang-ir-lower-generics.cpp

+189-98
Original file line numberDiff line numberDiff line change
@@ -60,6 +60,15 @@ namespace Slang
6060
case kIROp_AssociatedType:
6161
case kIROp_InterfaceType:
6262
return true;
63+
case kIROp_Specialize:
64+
{
65+
for (UInt i = 0; i < typeInst->getOperandCount(); i++)
66+
{
67+
if (isPolymorphicType(typeInst->getOperand(i)))
68+
return true;
69+
}
70+
return false;
71+
}
6372
default:
6473
break;
6574
}
@@ -124,6 +133,41 @@ namespace Slang
124133
}
125134
}
126135
cloneInstDecorationsAndChildren(&cloneEnv, &sharedBuilderStorage, func, loweredFunc);
136+
137+
// If the function returns a generic typed value, we need to turn it
138+
// into an `out` parameter, since only the caller can allocate space
139+
// for it.
140+
auto oldFuncType = cast<IRFuncType>(func->getDataType());
141+
if (isPolymorphicType(oldFuncType->getResultType()))
142+
{
143+
builder.setInsertInto(loweredFunc->getFirstBlock());
144+
// We defer creation of the returnVal parameter until we see the first
145+
// `return` instruction, because we can only obtain the cloned return type
146+
// of this function by checking the type of the cloned return inst.
147+
IRParam* retValParam = nullptr;
148+
// Translate all return insts to `store`s.
149+
// Those `store`s will be processed and translated into `copy`s when we
150+
// get to process them via workList.
151+
for (auto bb : loweredFunc->getBlocks())
152+
{
153+
auto retInst = as<IRReturnVal>(bb->getTerminator());
154+
if (!retInst)
155+
continue;
156+
if (!retValParam)
157+
{
158+
// Now we have the return type, emit the returnVal parameter.
159+
// The type of this parameter is still not translated to RawPointer yet,
160+
// and will be processed together with all the other existing parameters.
161+
retValParam = builder.emitParamAtHead(
162+
builder.getOutType(retInst->getVal()->getDataType()));
163+
}
164+
builder.setInsertBefore(retInst);
165+
builder.emitStore(retValParam, retInst->getVal());
166+
builder.emitReturn();
167+
retInst->removeAndDeallocate();
168+
}
169+
}
170+
127171
auto block = as<IRBlock>(loweredFunc->getFirstChild());
128172
for (auto param : clonedParams)
129173
{
@@ -139,7 +183,10 @@ namespace Slang
139183
param = param->getNextInst())
140184
{
141185
// Generic typed parameters have a type that is a param itself.
142-
if (auto rttiParam = as<IRParam>(param->getDataType()))
186+
auto paramType = param->getDataType();
187+
if (auto ptrType = as<IRPtrTypeBase>(paramType))
188+
paramType = ptrType->getValueType();
189+
if (auto rttiParam = as<IRParam>(paramType))
143190
{
144191
SLANG_ASSERT(isPointerOfType(rttiParam->getDataType(), kIROp_RTTIType));
145192
// Lower into a function parameter of raw pointer type.
@@ -189,6 +236,14 @@ namespace Slang
189236
{
190237
auto loweredParamType = lowerParameterType(builder, paramType);
191238
translated = translated || (loweredParamType != paramType);
239+
if (translated && i == 0)
240+
{
241+
// We are translating the return value, this means that
242+
// the return value must be passed explicitly via an `out` parameter.
243+
// In this case, the new return value will be `void`, and the
244+
// translated return value type will be the first parameter type;
245+
newOperands.add(builder->getVoidType());
246+
}
192247
newOperands.add(loweredParamType);
193248
}
194249
}
@@ -382,110 +437,146 @@ namespace Slang
382437
return result;
383438
}
384439

385-
void processInst(IRInst* inst)
440+
void lowerCall(IRCall* callInst)
386441
{
387-
if (auto callInst = as<IRCall>(inst))
442+
// If we see a call(specialize(gFunc, Targs), args),
443+
// translate it into call(gFunc, args, Targs).
444+
auto funcOperand = callInst->getOperand(0);
445+
IRInst* loweredFunc = nullptr;
446+
auto specializeInst = as<IRSpecialize>(funcOperand);
447+
if (!specializeInst)
448+
return;
449+
450+
auto funcToSpecialize = specializeInst->getOperand(0);
451+
List<IRType*> paramTypes;
452+
IRFuncType* funcType = nullptr;
453+
if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
388454
{
389-
// If we see a call(specialize(gFunc, Targs), args),
390-
// translate it into call(gFunc, args, Targs).
391-
auto funcOperand = callInst->getOperand(0);
392-
IRInst* loweredFunc = nullptr;
393-
if (auto specializeInst = as<IRSpecialize>(funcOperand))
455+
// The callee is a result of witness table lookup, we will only
456+
// translate the call.
457+
IRInst* callee = nullptr;
458+
auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getDataType());
459+
auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
460+
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
394461
{
395-
auto funcToSpecialize = specializeInst->getOperand(0);
396-
List<IRType*> paramTypes;
397-
if (auto interfaceLookup = as<IRLookupWitnessMethod>(funcToSpecialize))
398-
{
399-
// The callee is a result of witness table lookup, we will only
400-
// translate the call.
401-
IRInst* callee = nullptr;
402-
auto witnessTableType = cast<IRWitnessTableType>(interfaceLookup->getWitnessTable()->getDataType());
403-
auto interfaceType = maybeLowerInterfaceType(cast<IRInterfaceType>(witnessTableType->getConformanceType()));
404-
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
405-
{
406-
auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
407-
if (entry->getRequirementKey() == interfaceLookup->getOperand(1))
408-
{
409-
callee = entry->getRequirementVal();
410-
break;
411-
}
412-
}
413-
auto funcType = cast<IRFuncType>(callee);
414-
for (UInt i = 0; i < funcType->getParamCount(); i++)
415-
paramTypes.add(funcType->getParamType(i));
416-
loweredFunc = funcToSpecialize;
417-
}
418-
else
462+
auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
463+
if (entry->getRequirementKey() == interfaceLookup->getOperand(1))
419464
{
420-
loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
421-
if (loweredFunc == specializeInst->getOperand(0))
422-
{
423-
// This is an intrinsic function, don't transform.
424-
return;
425-
}
426-
for (auto param : as<IRFunc>(loweredFunc)->getParams())
427-
paramTypes.add(param->getDataType());
465+
callee = entry->getRequirementVal();
466+
break;
428467
}
468+
}
469+
funcType = cast<IRFuncType>(callee);
470+
loweredFunc = funcToSpecialize;
471+
}
472+
else
473+
{
474+
loweredFunc = lowerGenericFunction(specializeInst->getOperand(0));
475+
if (loweredFunc == specializeInst->getOperand(0))
476+
{
477+
// This is an intrinsic function, don't transform.
478+
return;
479+
}
480+
funcType = cast<IRFuncType>(loweredFunc->getDataType());
481+
}
429482

430-
IRBuilder builderStorage;
431-
auto builder = &builderStorage;
432-
builder->sharedBuilder = &sharedBuilderStorage;
433-
builder->setInsertBefore(inst);
434-
List<IRInst*> args;
435-
auto rawPtrType = builder->getRawPointerType();
436-
for (UInt i = 0; i < callInst->getArgCount(); i++)
437-
{
438-
auto arg = callInst->getArg(i);
439-
if (as<IRRawPointerType>(paramTypes[i]) &&
440-
!as<IRRawPointerType>(arg->getDataType()))
441-
{
442-
// We are calling a generic function that with an argument of
443-
// concrete type. We need to convert this argument to void*.
444-
445-
// Ideally this should just be a GetElementAddress inst.
446-
// However the current code emitting logic for this instruction
447-
// doesn't truly respect the pointerness and does not produce
448-
// what we needed. For now we use another instruction here
449-
// to keep changes minimal.
450-
arg = builder->emitGetAddress(
451-
rawPtrType,
452-
arg);
453-
}
454-
args.add(arg);
455-
}
456-
for (UInt i = 0; i < specializeInst->getArgCount(); i++)
457-
{
458-
auto arg = specializeInst->getArg(i);
459-
// Translate Type arguments into RTTI object.
460-
if (as<IRType>(arg))
461-
{
462-
// We are using a simple type to specialize a callee.
463-
// Generate RTTI for this type.
464-
auto rttiObject = maybeEmitRTTIObject(arg);
465-
arg = builder->emitGetAddress(
466-
builder->getPtrType(builder->getRTTIType()),
467-
rttiObject);
468-
}
469-
else if (arg->op == kIROp_Specialize)
470-
{
471-
// The type argument used to specialize a callee is itself a
472-
// specialization of some generic type.
473-
// TODO: generate RTTI object for specializations of generic types.
474-
SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
475-
}
476-
else if (arg->op == kIROp_RTTIObject)
477-
{
478-
// We are inside a generic function and using a generic parameter
479-
// to specialize another callee. The generic parameter of the caller
480-
// has already been translated into an RTTI object, so we just need
481-
// to pass this object down.
482-
}
483-
args.add(arg);
484-
}
485-
auto newCall = builder->emitCallInst(callInst->getFullType(), loweredFunc, args);
486-
callInst->replaceUsesWith(newCall);
487-
callInst->removeAndDeallocate();
483+
for (UInt i = 0; i < funcType->getParamCount(); i++)
484+
paramTypes.add(funcType->getParamType(i));
485+
486+
IRBuilder builderStorage;
487+
auto builder = &builderStorage;
488+
builder->sharedBuilder = &sharedBuilderStorage;
489+
builder->setInsertBefore(callInst);
490+
491+
List<IRInst*> args;
492+
493+
// Indicates whether the caller should allocate space for return value.
494+
// If the lowered callee returns void and this call inst has a type that is not void,
495+
// then we are calling a transformed function that expects caller allocated return value
496+
// as the first argument.
497+
bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType &&
498+
callInst->getDataType() != funcType->getResultType());
499+
500+
IRVar* retVarInst = nullptr;
501+
int startParamIndex = 0;
502+
if (shouldCallerAllocateReturnValue)
503+
{
504+
// Declare a var for the return value.
505+
retVarInst = builder->emitVar(callInst->getFullType());
506+
args.add(retVarInst);
507+
startParamIndex = 1;
508+
}
509+
510+
for (UInt i = 0; i < callInst->getArgCount(); i++)
511+
{
512+
auto arg = callInst->getArg(i);
513+
if (as<IRRawPointerType>(paramTypes[i] + startParamIndex) &&
514+
!as<IRRawPointerType>(arg->getDataType()))
515+
{
516+
// We are calling a generic function that with an argument of
517+
// concrete type. We need to convert this argument to void*.
518+
519+
// Ideally this should just be a GetElementAddress inst.
520+
// However the current code emitting logic for this instruction
521+
// doesn't truly respect the pointerness and does not produce
522+
// what we needed. For now we use another instruction here
523+
// to keep changes minimal.
524+
arg = builder->emitGetAddress(
525+
builder->getRawPointerType(),
526+
arg);
527+
}
528+
args.add(arg);
529+
}
530+
for (UInt i = 0; i < specializeInst->getArgCount(); i++)
531+
{
532+
auto arg = specializeInst->getArg(i);
533+
// Translate Type arguments into RTTI object.
534+
if (as<IRType>(arg))
535+
{
536+
// We are using a simple type to specialize a callee.
537+
// Generate RTTI for this type.
538+
auto rttiObject = maybeEmitRTTIObject(arg);
539+
arg = builder->emitGetAddress(
540+
builder->getPtrType(builder->getRTTIType()),
541+
rttiObject);
542+
}
543+
else if (arg->op == kIROp_Specialize)
544+
{
545+
// The type argument used to specialize a callee is itself a
546+
// specialization of some generic type.
547+
// TODO: generate RTTI object for specializations of generic types.
548+
SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
488549
}
550+
else if (arg->op == kIROp_RTTIObject)
551+
{
552+
// We are inside a generic function and using a generic parameter
553+
// to specialize another callee. The generic parameter of the caller
554+
// has already been translated into an RTTI object, so we just need
555+
// to pass this object down.
556+
}
557+
args.add(arg);
558+
}
559+
auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType();
560+
auto newCall = builder->emitCallInst(callInstType, loweredFunc, args);
561+
if (retVarInst)
562+
{
563+
auto loadInst = builder->emitLoad(retVarInst);
564+
callInst->replaceUsesWith(loadInst);
565+
addToWorkList(loadInst);
566+
addToWorkList(retVarInst);
567+
}
568+
else
569+
{
570+
callInst->replaceUsesWith(newCall);
571+
}
572+
callInst->removeAndDeallocate();
573+
}
574+
575+
void processInst(IRInst* inst)
576+
{
577+
if (auto callInst = as<IRCall>(inst))
578+
{
579+
lowerCall(callInst);
489580
}
490581
else if (auto witnessTable = as<IRWitnessTable>(inst))
491582
{

source/slang/slang-ir.cpp

+30
Original file line numberDiff line numberDiff line change
@@ -343,6 +343,25 @@ namespace Slang
343343
}
344344
}
345345

346+
// Similar to addParam, but instead of appending `param` to the end
347+
// of the parameter list, this function inserts `param` before the
348+
// head of the list.
349+
void IRBlock::insertParamAtHead(IRParam* param)
350+
{
351+
if (auto firstParam = getFirstParam())
352+
{
353+
param->insertBefore(firstParam);
354+
}
355+
else if (auto firstOrdinary = getFirstOrdinaryInst())
356+
{
357+
param->insertBefore(firstOrdinary);
358+
}
359+
else
360+
{
361+
param->insertAtEnd(this);
362+
}
363+
}
364+
346365
IRInst* IRBlock::getFirstOrdinaryInst()
347366
{
348367
// Find the last parameter (if any) of the block
@@ -3030,6 +3049,17 @@ namespace Slang
30303049
return param;
30313050
}
30323051

3052+
IRParam* IRBuilder::emitParamAtHead(
3053+
IRType* type)
3054+
{
3055+
auto param = createParam(type);
3056+
if (auto bb = getBlock())
3057+
{
3058+
bb->insertParamAtHead(param);
3059+
}
3060+
return param;
3061+
}
3062+
30333063
IRVar* IRBuilder::emitVar(
30343064
IRType* type)
30353065
{

source/slang/slang-ir.h

+1
Original file line numberDiff line numberDiff line change
@@ -822,6 +822,7 @@ struct IRBlock : IRInst
822822
}
823823

824824
void addParam(IRParam* param);
825+
void insertParamAtHead(IRParam* param);
825826

826827
// The "ordinary" instructions come after the parameters
827828
IRInst* getFirstOrdinaryInst();

0 commit comments

Comments
 (0)