Skip to content

Commit 011d428

Browse files
authored
Cleanup builtin arithmetic interfaces. (shader-slang#3317)
* wip: clean up IArithmetic * wip. * Cleanup builtin arithmetic interfaces. * Fix. * Fixes. * Fix. * Fix. * Fix. --------- Co-authored-by: Yong He <yhe@nvidia.com>
1 parent bfd3f39 commit 011d428

12 files changed

+487
-233
lines changed

source/slang/core.meta.slang

+369-210
Large diffs are not rendered by default.

source/slang/slang-ast-modifier.h

+12
Original file line numberDiff line numberDiff line change
@@ -1075,6 +1075,18 @@ class AnyValueSizeAttribute : public Attribute
10751075
int32_t size;
10761076
};
10771077

1078+
/// This is a stop-gap solution to break overload ambiguity in stdlib.
1079+
/// When there is a function overload ambiguity, the compiler will pick the one with higher rank
1080+
/// specified by this attribute. An overload without this attribute will have a rank of 0.
1081+
/// In the future, we should enhance our type system to take into account the "specialized"-ness
1082+
/// of an overload, such that `T overload1<T:IDerived>()` is more specialized than `T overload2<T:IBase>()`
1083+
/// and preferred during overload resolution.
1084+
class OverloadRankAttribute : public Attribute
1085+
{
1086+
SLANG_AST_CLASS(OverloadRankAttribute)
1087+
int32_t rank;
1088+
};
1089+
10781090
/// An attribute that marks an interface for specialization use only. Any operation that triggers dynamic
10791091
/// dispatch through the interface is a compile-time error.
10801092
class SpecializeAttribute : public Attribute

source/slang/slang-ast-support-types.h

+3
Original file line numberDiff line numberDiff line change
@@ -1482,6 +1482,9 @@ namespace Slang
14821482

14831483
// Cached dictionary for looking up satisfying values.
14841484
SLANG_UNREFLECTED RequirementDictionary m_requirementDictionary;
1485+
1486+
RefPtr<WitnessTable> specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst);
1487+
14851488
};
14861489

14871490
struct SpecializationParam

source/slang/slang-check-expr.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -2225,11 +2225,15 @@ namespace Slang
22252225
{
22262226
// check the base expression first
22272227
expr->functionExpr = CheckTerm(expr->functionExpr);
2228+
2229+
auto treatAsDifferentiableExpr = m_treatAsDifferentiableExpr;
2230+
m_treatAsDifferentiableExpr = nullptr;
22282231
// Next check the argument expressions
22292232
for (auto & arg : expr->arguments)
22302233
{
22312234
arg = CheckTerm(arg);
22322235
}
2236+
m_treatAsDifferentiableExpr = treatAsDifferentiableExpr;
22332237

22342238
// If we are in a differentiable function, register differential witness tables involved in
22352239
// this call.

source/slang/slang-check-modifier.cpp

+13
Original file line numberDiff line numberDiff line change
@@ -349,6 +349,19 @@ namespace Slang
349349

350350
anyValueSizeAttr->size = int32_t(value->getValue());
351351
}
352+
else if (auto overloadRankAttr = as<OverloadRankAttribute>(attr))
353+
{
354+
if (attr->args.getCount() != 1)
355+
{
356+
return false;
357+
}
358+
auto rank = checkConstantIntVal(attr->args[0]);
359+
if (rank == nullptr)
360+
{
361+
return false;
362+
}
363+
overloadRankAttr->rank = int32_t(rank->getValue());
364+
}
352365
else if (auto bindingAttr = as<GLSLBindingAttribute>(attr))
353366
{
354367
// This must be vk::binding or gl::binding (as specified in core.meta.slang under vk_binding/gl_binding)

source/slang/slang-check-overload.cpp

+14
Original file line numberDiff line numberDiff line change
@@ -1070,6 +1070,15 @@ namespace Slang
10701070
return 0;
10711071
}
10721072

1073+
int getOverloadRank(DeclRef<Decl> declRef)
1074+
{
1075+
if (!declRef.getDecl())
1076+
return 0;
1077+
if (auto attr = declRef.getDecl()->findModifier<OverloadRankAttribute>())
1078+
return attr->rank;
1079+
return 0;
1080+
}
1081+
10731082
int SemanticsVisitor::CompareOverloadCandidates(
10741083
OverloadCandidate* left,
10751084
OverloadCandidate* right)
@@ -1142,6 +1151,11 @@ namespace Slang
11421151
auto specificityDiff = compareOverloadCandidateSpecificity(left->item, right->item);
11431152
if(specificityDiff)
11441153
return specificityDiff;
1154+
1155+
// If we reach here, we will attempt to use overload rank to break the ties.
1156+
auto overloadRankDiff = getOverloadRank(right->item.declRef) - getOverloadRank(left->item.declRef);
1157+
if (overloadRankDiff)
1158+
return overloadRankDiff;
11451159
}
11461160

11471161
return 0;

source/slang/slang-ir-constexpr.cpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -56,6 +56,9 @@ bool isConstExpr(IRInst* value)
5656
case kIROp_FloatLit:
5757
case kIROp_BoolLit:
5858
case kIROp_Func:
59+
case kIROp_StructKey:
60+
case kIROp_WitnessTable:
61+
case kIROp_Generic:
5962
return true;
6063

6164
default:
@@ -136,6 +139,8 @@ bool opCanBeConstExpr(IROp op)
136139
case kIROp_GetOptionalValue:
137140
case kIROp_DifferentialPairGetDifferential:
138141
case kIROp_DifferentialPairGetPrimal:
142+
case kIROp_LookupWitness:
143+
case kIROp_Specialize:
139144
// TODO: more cases
140145
return true;
141146

@@ -146,10 +151,8 @@ bool opCanBeConstExpr(IROp op)
146151

147152
bool opCanBeConstExprByForwardPass(IRInst* value)
148153
{
149-
// TODO: realistically need to special-case `call`
150-
// operations here, so that we check whether the
151-
// callee function is fixed/known, and if it is
152-
// whether it has been declared as constant-foldable
154+
// TODO: handle call inst here.
155+
153156
if (value->getOp() == kIROp_Param)
154157
return false;
155158
return opCanBeConstExpr(value->getOp());

source/slang/slang-ir-peephole.cpp

+3-3
Original file line numberDiff line numberDiff line change
@@ -317,7 +317,7 @@ struct PeepholeContext : InstPassBase
317317
}
318318
else
319319
{
320-
changed = tryFoldElementExtractFromUpdateInst(inst);
320+
changed |= tryFoldElementExtractFromUpdateInst(inst);
321321
}
322322
break;
323323
case kIROp_GetElement:
@@ -382,7 +382,7 @@ struct PeepholeContext : InstPassBase
382382
}
383383
else
384384
{
385-
changed = tryFoldElementExtractFromUpdateInst(inst);
385+
changed |= tryFoldElementExtractFromUpdateInst(inst);
386386
}
387387
break;
388388
case kIROp_UpdateElement:
@@ -806,7 +806,7 @@ struct PeepholeContext : InstPassBase
806806
case kIROp_Div:
807807
case kIROp_And:
808808
case kIROp_Or:
809-
changed = tryOptimizeArithmeticInst(inst);
809+
changed |= tryOptimizeArithmeticInst(inst);
810810
break;
811811
case kIROp_Param:
812812
{

source/slang/slang-syntax.cpp

+17-2
Original file line numberDiff line numberDiff line change
@@ -245,6 +245,22 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
245245
return m_obj.as<WitnessTable>();
246246
}
247247

248+
RefPtr<WitnessTable> WitnessTable::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst)
249+
{
250+
auto newBaseType = baseType->substitute(astBuilder, subst);
251+
auto newWitnessedType = witnessedType->substitute(astBuilder, subst);
252+
if (newBaseType == baseType && newWitnessedType == witnessedType)
253+
return this;
254+
RefPtr<WitnessTable> result = new WitnessTable();
255+
result->baseType = as<Type>(newBaseType);
256+
result->witnessedType = as<Type>(newWitnessedType);
257+
for (auto requirement : m_requirements)
258+
{
259+
auto newRequirement = requirement.value.specialize(astBuilder, subst);
260+
result->add(requirement.key, newRequirement);
261+
}
262+
return result;
263+
}
248264

249265
RequirementWitness RequirementWitness::specialize(ASTBuilder* astBuilder, SubstitutionSet const& subst)
250266
{
@@ -256,8 +272,7 @@ Index getFilterCountImpl(const ReflectClassInfo& clsInfo, MemberFilterStyle filt
256272
return RequirementWitness();
257273

258274
case RequirementWitness::Flavor::witnessTable:
259-
SLANG_ASSERT(!subst);
260-
return *this;
275+
return RequirementWitness(this->getWitnessTable()->specialize(astBuilder, subst));
261276

262277
case RequirementWitness::Flavor::declRef:
263278
{
+39
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
//TEST(compute):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-slang -compute -shaderobj -output-using-type
2+
//TEST(compute, vulkan):COMPARE_COMPUTE_EX(filecheck-buffer=CHECK):-vk -compute -shaderobj -output-using-type
3+
4+
//TEST_INPUT:ubuffer(data=[0 0 0 0 0], stride=4):out,name=outputBuffer
5+
RWStructuredBuffer<float> outputBuffer;
6+
7+
interface IFoo : IDifferentiable
8+
{
9+
[Differentiable]
10+
__init(Differential v);
11+
}
12+
13+
struct Impl : IFoo
14+
{
15+
float x;
16+
17+
[Differentiable]
18+
__init(Differential v)
19+
{
20+
x = v.x;
21+
}
22+
}
23+
24+
[Differentiable]
25+
float test(float x)
26+
{
27+
Impl.Differential v0 = { x };
28+
var v1 = Impl(v0);
29+
return v1.x * v1.x;
30+
}
31+
32+
[numthreads(1,1,1)]
33+
void computeMain(uint tid : SV_DispatchThreadID)
34+
{
35+
var p = diffPair(3.0, 0.0);
36+
bwd_diff(test)(p, 1.0);
37+
outputBuffer[tid] = p.d;
38+
// CHECK: 6.0
39+
}

tests/ir/loop-inversion.slang

+4-4
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@ RWStructuredBuffer<int> outputBuffer;
1919
// A standard loop
2020
// CHECK-LABEL: int a_{{.*}}()
2121
// CHECK-NOT: break;
22-
// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
22+
// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
2323
// CHECK: [[i]] + int(1);
2424
// CHECK: if(
2525
// CHECK: break;
@@ -35,7 +35,7 @@ int a()
3535
// A vanilla while loop
3636
// CHECK-LABEL: int b_{{.*}}()
3737
// CHECK-NOT: break;
38-
// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
38+
// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
3939
// CHECK: [[i]] + int(1);
4040
// CHECK: if(
4141
// CHECK: break;
@@ -55,7 +55,7 @@ int b()
5555
// A while loop with a break on the false branch
5656
// CHECK-LABEL: int c_{{.*}}()
5757
// CHECK-NOT: break;
58-
// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
58+
// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
5959
// CHECK: [[i]] + int(1);
6060
// CHECK: if(
6161
// CHECK: break;
@@ -79,7 +79,7 @@ int c()
7979
// A while loop with a break on the true branch
8080
// CHECK-LABEL: int d_{{.*}}()
8181
// CHECK-NOT: break;
82-
// CHECK: int j_{{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
82+
// CHECK: int {{.*}} = j_{{.*}} + [[i:i_[0-9]+]]
8383
// CHECK: [[i]] + int(1);
8484
// CHECK: if(
8585
// CHECK: break;

tests/language-feature/generics/iarray.slang

+2-10
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -shaderobj -output-using-type
22
//TEST(compute):COMPARE_COMPUTE(filecheck-buffer=CHECK): -vk -shaderobj -output-using-type
33

4-
T sum<T:__BuiltinArithmeticType>(IArray<T> array)
4+
T sum<T:IFloat>(IArray<T> array)
55
{
66
T result = T(0);
77
for (int i = 0; i < array.getCount(); i++)
@@ -10,15 +10,7 @@ T sum<T:__BuiltinArithmeticType>(IArray<T> array)
1010
}
1111
return result;
1212
}
13-
vector<T,N> sum<T:__BuiltinArithmeticType, let N:int>(IArray<vector<T,N>> array)
14-
{
15-
vector<T,N> result = vector<T,N>(T(0));
16-
for (int i = 0; i < array.getCount(); i++)
17-
{
18-
result = result + array[i];
19-
}
20-
return result;
21-
}
13+
2214
//TEST_INPUT:ubuffer(data=[0 0 0 0], stride=4):out,name=outputBuffer
2315
RWStructuredBuffer<float> outputBuffer;
2416

0 commit comments

Comments
 (0)