Skip to content

Commit e9d5ecb

Browse files
authored
Refactor lower-generics pass into separate subpasses. (shader-slang#1442)
1 parent 723c9b1 commit e9d5ecb

17 files changed

+1015
-662
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
//slang-ir-generics-lowering-context.cpp
2+
3+
#include "slang-ir-generics-lowering-context.h"
4+
5+
#include "slang-ir-layout.h"
6+
7+
namespace Slang
8+
{
9+
bool isPolymorphicType(IRInst* typeInst)
10+
{
11+
if (as<IRParam>(typeInst) && as<IRTypeType>(typeInst->getDataType()))
12+
return true;
13+
switch (typeInst->op)
14+
{
15+
case kIROp_ThisType:
16+
case kIROp_AssociatedType:
17+
case kIROp_InterfaceType:
18+
return true;
19+
case kIROp_Specialize:
20+
{
21+
for (UInt i = 0; i < typeInst->getOperandCount(); i++)
22+
{
23+
if (isPolymorphicType(typeInst->getOperand(i)))
24+
return true;
25+
}
26+
return false;
27+
}
28+
default:
29+
break;
30+
}
31+
if (auto ptrType = as<IRPtrTypeBase>(typeInst))
32+
{
33+
return isPolymorphicType(ptrType->getValueType());
34+
}
35+
return false;
36+
}
37+
38+
bool isTypeValue(IRInst* typeInst)
39+
{
40+
if (typeInst)
41+
{
42+
switch (typeInst->op)
43+
{
44+
case kIROp_TypeType:
45+
return true;
46+
case kIROp_lookup_interface_method:
47+
return typeInst->getDataType()->op == kIROp_TypeKind;
48+
default:
49+
return false;
50+
}
51+
}
52+
return false;
53+
}
54+
55+
IRInst* SharedGenericsLoweringContext::maybeEmitRTTIObject(IRInst* typeInst)
56+
{
57+
IRInst* result = nullptr;
58+
if (mapTypeToRTTIObject.TryGetValue(typeInst, result))
59+
return result;
60+
IRBuilder builderStorage;
61+
auto builder = &builderStorage;
62+
builder->sharedBuilder = &sharedBuilderStorage;
63+
builder->setInsertBefore(typeInst->next);
64+
65+
result = builder->emitMakeRTTIObject(typeInst);
66+
67+
// For now the only type info we encapsualte is type size.
68+
IRSizeAndAlignment sizeAndAlignment;
69+
getNaturalSizeAndAlignment((IRType*)typeInst, &sizeAndAlignment);
70+
builder->addRTTITypeSizeDecoration(result, sizeAndAlignment.size);
71+
72+
// Give a name to the rtti object.
73+
if (auto exportDecoration = typeInst->findDecoration<IRExportDecoration>())
74+
{
75+
String rttiObjName = String(exportDecoration->getMangledName()) + "_rtti";
76+
builder->addExportDecoration(result, rttiObjName.getUnownedSlice());
77+
}
78+
mapTypeToRTTIObject[typeInst] = result;
79+
return result;
80+
}
81+
82+
IRInst* SharedGenericsLoweringContext::findInterfaceRequirementVal(IRInterfaceType* interfaceType, IRInst* requirementKey)
83+
{
84+
if (auto dict = mapInterfaceRequirementKeyValue.TryGetValue(interfaceType))
85+
return (*dict)[requirementKey].GetValue();
86+
_builldInterfaceRequirementMap(interfaceType);
87+
return findInterfaceRequirementVal(interfaceType, requirementKey);
88+
}
89+
90+
void SharedGenericsLoweringContext::_builldInterfaceRequirementMap(IRInterfaceType* interfaceType)
91+
{
92+
mapInterfaceRequirementKeyValue.Add(interfaceType,
93+
Dictionary<IRInst*, IRInst*>());
94+
auto dict = mapInterfaceRequirementKeyValue.TryGetValue(interfaceType);
95+
for (UInt i = 0; i < interfaceType->getOperandCount(); i++)
96+
{
97+
auto entry = cast<IRInterfaceRequirementEntry>(interfaceType->getOperand(i));
98+
(*dict)[entry->getRequirementKey()] = entry->getRequirementVal();
99+
}
100+
}
101+
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,65 @@
1+
// slang-ir-generics-lowering-context.h
2+
#pragma once
3+
4+
#include "slang-ir.h"
5+
#include "slang-ir-insts.h"
6+
7+
namespace Slang
8+
{
9+
struct IRModule;
10+
11+
struct SharedGenericsLoweringContext
12+
{
13+
// For convenience, we will keep a pointer to the module
14+
// we are processing.
15+
IRModule* module;
16+
17+
// RTTI objects for each type used to call a generic function.
18+
Dictionary<IRInst*, IRInst*> mapTypeToRTTIObject;
19+
20+
Dictionary<IRInst*, IRInst*> loweredGenericFunctions;
21+
HashSet<IRInterfaceType*> loweredInterfaceTypes;
22+
23+
// Dictionaries for interface type requirement key-value lookups.
24+
// Used by `findInterfaceRequirementVal`.
25+
Dictionary<IRInterfaceType*, Dictionary<IRInst*, IRInst*>> mapInterfaceRequirementKeyValue;
26+
27+
SharedIRBuilder sharedBuilderStorage;
28+
29+
// We will use a single work list of instructions that need
30+
// to be considered for lowering.
31+
//
32+
List<IRInst*> workList;
33+
HashSet<IRInst*> workListSet;
34+
35+
void addToWorkList(
36+
IRInst* inst)
37+
{
38+
for (auto ii = inst->getParent(); ii; ii = ii->getParent())
39+
{
40+
if (as<IRGeneric>(ii))
41+
return;
42+
}
43+
44+
if (workListSet.Contains(inst))
45+
return;
46+
47+
workList.add(inst);
48+
workListSet.Add(inst);
49+
}
50+
51+
52+
void _builldInterfaceRequirementMap(IRInterfaceType* interfaceType);
53+
54+
IRInst* findInterfaceRequirementVal(IRInterfaceType* interfaceType, IRInst* requirementKey);
55+
56+
// Emits an IRRTTIObject containing type information for a given type.
57+
IRInst* maybeEmitRTTIObject(IRInst* typeInst);
58+
};
59+
60+
bool isPolymorphicType(IRInst* typeInst);
61+
62+
// Returns true if typeInst represents a type and should be lowered into
63+
// Ptr(RTTIType).
64+
bool isTypeValue(IRInst* typeInst);
65+
}

source/slang/slang-ir-inst-defs.h

+4
Original file line numberDiff line numberDiff line change
@@ -25,8 +25,12 @@ INST(Nop, nop, 0, 0)
2525
INST_RANGE(BasicType, VoidType, AfterBaseType)
2626

2727
INST(StringType, String, 0, 0)
28+
2829
INST(RawPointerType, RawPointerType, 0, 0)
2930
INST(RTTIPointerType, RTTIPointerType, 1, 0)
31+
INST(AfterRawPointerTypeBase, AfterRawPointerTypeBase, 0, 0)
32+
INST_RANGE(RawPointerTypeBase, RawPointerType, AfterRawPointerTypeBase)
33+
3034

3135
/* ArrayTypeBase */
3236
INST(ArrayType, Array, 2, 0)

source/slang/slang-ir-insts.h

+5
Original file line numberDiff line numberDiff line change
@@ -1483,6 +1483,11 @@ struct IRWitnessTable : IRInst
14831483
return getOperand(0);
14841484
}
14851485

1486+
void setConformanceType(IRInst* type)
1487+
{
1488+
setOperand(0, type);
1489+
}
1490+
14861491
IR_LEAF_ISA(WitnessTable)
14871492
};
14881493

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,185 @@
1+
// slang-ir-lower-generic-function.cpp
2+
#include "slang-ir-lower-generic-function.h"
3+
#include "slang-ir-generics-lowering-context.h"
4+
5+
namespace Slang
6+
{
7+
struct GenericCallLoweringContext
8+
{
9+
SharedGenericsLoweringContext* sharedContext;
10+
11+
// Translate `callInst` into a call of `newCallee`, and respect the new `funcType`.
12+
// If `funcType` involve lowered generic parameters or return values, this function
13+
// also translates the argument list to match with that.
14+
// If `newCallee` is a lowered generic function, `specializeInst` contains the type
15+
// arguments used to specialize the callee.
16+
void translateCallInst(
17+
IRCall* callInst,
18+
IRFuncType* funcType,
19+
IRInst* newCallee,
20+
IRSpecialize* specializeInst)
21+
{
22+
List<IRType*> paramTypes;
23+
for (UInt i = 0; i < funcType->getParamCount(); i++)
24+
paramTypes.add(funcType->getParamType(i));
25+
26+
IRBuilder builderStorage;
27+
auto builder = &builderStorage;
28+
builder->sharedBuilder = &sharedContext->sharedBuilderStorage;
29+
builder->setInsertBefore(callInst);
30+
31+
List<IRInst*> args;
32+
33+
// Indicates whether the caller should allocate space for return value.
34+
// If the lowered callee returns void and this call inst has a type that is not void,
35+
// then we are calling a transformed function that expects caller allocated return value
36+
// as the first argument.
37+
bool shouldCallerAllocateReturnValue = (funcType->getResultType()->op == kIROp_VoidType &&
38+
callInst->getDataType() != funcType->getResultType());
39+
40+
IRVar* retVarInst = nullptr;
41+
int startParamIndex = 0;
42+
if (shouldCallerAllocateReturnValue)
43+
{
44+
// Declare a var for the return value.
45+
retVarInst = builder->emitVar(callInst->getFullType());
46+
args.add(retVarInst);
47+
startParamIndex = 1;
48+
}
49+
50+
for (UInt i = 0; i < callInst->getArgCount(); i++)
51+
{
52+
auto arg = callInst->getArg(i);
53+
if (as<IRRawPointerTypeBase>(paramTypes[i] + startParamIndex) &&
54+
!as<IRRawPointerTypeBase>(arg->getDataType()) &&
55+
!as<IRPtrTypeBase>(arg->getDataType()))
56+
{
57+
// We are calling a generic function that with an argument of
58+
// some concrete value type. We need to convert this argument to void*.
59+
// We do so by defining a local variable, store the SSA value
60+
// in the variable, and use the pointer of this variable as argument.
61+
auto localVar = builder->emitVar(arg->getDataType());
62+
builder->emitStore(localVar, arg);
63+
arg = localVar;
64+
}
65+
args.add(arg);
66+
}
67+
if (specializeInst)
68+
{
69+
for (UInt i = 0; i < specializeInst->getArgCount(); i++)
70+
{
71+
auto arg = specializeInst->getArg(i);
72+
// Translate Type arguments into RTTI object.
73+
if (as<IRType>(arg))
74+
{
75+
// We are using a simple type to specialize a callee.
76+
// Generate RTTI for this type.
77+
auto rttiObject = sharedContext->maybeEmitRTTIObject(arg);
78+
arg = builder->emitGetAddress(
79+
builder->getPtrType(builder->getRTTIType()),
80+
rttiObject);
81+
}
82+
else if (arg->op == kIROp_Specialize)
83+
{
84+
// The type argument used to specialize a callee is itself a
85+
// specialization of some generic type.
86+
// TODO: generate RTTI object for specializations of generic types.
87+
SLANG_UNIMPLEMENTED_X("RTTI object generation for generic types");
88+
}
89+
else if (arg->op == kIROp_RTTIObject)
90+
{
91+
// We are inside a generic function and using a generic parameter
92+
// to specialize another callee. The generic parameter of the caller
93+
// has already been translated into an RTTI object, so we just need
94+
// to pass this object down.
95+
}
96+
args.add(arg);
97+
}
98+
}
99+
auto callInstType = retVarInst ? builder->getVoidType() : callInst->getFullType();
100+
auto newCall = builder->emitCallInst(callInstType, newCallee, args);
101+
if (retVarInst)
102+
{
103+
auto loadInst = builder->emitLoad(retVarInst);
104+
callInst->replaceUsesWith(loadInst);
105+
}
106+
else
107+
{
108+
callInst->replaceUsesWith(newCall);
109+
}
110+
callInst->removeAndDeallocate();
111+
}
112+
113+
void lowerCallToSpecializedFunc(IRCall* callInst, IRSpecialize* specializeInst)
114+
{
115+
// If we see a call(specialize(gFunc, Targs), args),
116+
// translate it into call(gFunc, args, Targs).
117+
auto loweredFunc = specializeInst->getBase();
118+
// All callees should have already been lowered in lower-generic-functions pass.
119+
// For intrinsic generic functions, they are left as is, and we also need to ignore
120+
// them here.
121+
if (loweredFunc->op == kIROp_Generic)
122+
{
123+
// This is an intrinsic function, don't transform.
124+
return;
125+
}
126+
IRFuncType* funcType = cast<IRFuncType>(loweredFunc->getDataType());
127+
translateCallInst(callInst, funcType, loweredFunc, specializeInst);
128+
}
129+
130+
void lowerCall(IRCall* callInst)
131+
{
132+
if (auto specializeInst = as<IRSpecialize>(callInst->getCallee()))
133+
lowerCallToSpecializedFunc(callInst, specializeInst);
134+
}
135+
136+
void processInst(IRInst* inst)
137+
{
138+
if (auto callInst = as<IRCall>(inst))
139+
{
140+
lowerCall(callInst);
141+
}
142+
}
143+
144+
void processModule()
145+
{
146+
// We start by initializing our shared IR building state,
147+
// since we will re-use that state for any code we
148+
// generate along the way.
149+
//
150+
SharedIRBuilder* sharedBuilder = &sharedContext->sharedBuilderStorage;
151+
sharedBuilder->module = sharedContext->module;
152+
sharedBuilder->session = sharedContext->module->session;
153+
154+
sharedContext->addToWorkList(sharedContext->module->getModuleInst());
155+
156+
while (sharedContext->workList.getCount() != 0)
157+
{
158+
// We will then iterate until our work list goes dry.
159+
//
160+
while (sharedContext->workList.getCount() != 0)
161+
{
162+
IRInst* inst = sharedContext->workList.getLast();
163+
164+
sharedContext->workList.removeLast();
165+
sharedContext->workListSet.Remove(inst);
166+
167+
processInst(inst);
168+
169+
for (auto child = inst->getLastChild(); child; child = child->getPrevInst())
170+
{
171+
sharedContext->addToWorkList(child);
172+
}
173+
}
174+
}
175+
}
176+
};
177+
178+
void lowerGenericCalls(SharedGenericsLoweringContext* sharedContext)
179+
{
180+
GenericCallLoweringContext context;
181+
context.sharedContext = sharedContext;
182+
context.processModule();
183+
}
184+
185+
}
+13
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
// slang-ir-lower-generic-call.h
2+
#pragma once
3+
4+
namespace Slang
5+
{
6+
struct SharedGenericsLoweringContext;
7+
8+
/// Lower generic and interface-based code to ordinary types and functions using
9+
/// dynamic dispatch mechanisms.
10+
void lowerGenericCalls(
11+
SharedGenericsLoweringContext* sharedContext);
12+
13+
}

0 commit comments

Comments
 (0)