Skip to content

Commit 941f070

Browse files
authored
Fix attribute reflection. (#5823)
* Fix attribute reflection. * Fix. * Fix.
1 parent e50aac1 commit 941f070

7 files changed

+123
-33
lines changed

include/slang.h

+29-17
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,7 @@ public: \
17761776
typedef struct SlangReflectionVariableLayout SlangReflectionVariableLayout;
17771777
typedef struct SlangReflectionTypeParameter SlangReflectionTypeParameter;
17781778
typedef struct SlangReflectionUserAttribute SlangReflectionUserAttribute;
1779+
typedef SlangReflectionUserAttribute SlangReflectionAttribute;
17791780
typedef struct SlangReflectionFunction SlangReflectionFunction;
17801781
typedef struct SlangReflectionGeneric SlangReflectionGeneric;
17811782

@@ -2140,46 +2141,48 @@ union GenericArgReflection
21402141
bool boolVal;
21412142
};
21422143

2143-
struct UserAttribute
2144+
struct Attribute
21442145
{
21452146
char const* getName()
21462147
{
2147-
return spReflectionUserAttribute_GetName((SlangReflectionUserAttribute*)this);
2148+
return spReflectionUserAttribute_GetName((SlangReflectionAttribute*)this);
21482149
}
21492150
uint32_t getArgumentCount()
21502151
{
21512152
return (uint32_t)spReflectionUserAttribute_GetArgumentCount(
2152-
(SlangReflectionUserAttribute*)this);
2153+
(SlangReflectionAttribute*)this);
21532154
}
21542155
TypeReflection* getArgumentType(uint32_t index)
21552156
{
21562157
return (TypeReflection*)spReflectionUserAttribute_GetArgumentType(
2157-
(SlangReflectionUserAttribute*)this,
2158+
(SlangReflectionAttribute*)this,
21582159
index);
21592160
}
21602161
SlangResult getArgumentValueInt(uint32_t index, int* value)
21612162
{
21622163
return spReflectionUserAttribute_GetArgumentValueInt(
2163-
(SlangReflectionUserAttribute*)this,
2164+
(SlangReflectionAttribute*)this,
21642165
index,
21652166
value);
21662167
}
21672168
SlangResult getArgumentValueFloat(uint32_t index, float* value)
21682169
{
21692170
return spReflectionUserAttribute_GetArgumentValueFloat(
2170-
(SlangReflectionUserAttribute*)this,
2171+
(SlangReflectionAttribute*)this,
21712172
index,
21722173
value);
21732174
}
21742175
const char* getArgumentValueString(uint32_t index, size_t* outSize)
21752176
{
21762177
return spReflectionUserAttribute_GetArgumentValueString(
2177-
(SlangReflectionUserAttribute*)this,
2178+
(SlangReflectionAttribute*)this,
21782179
index,
21792180
outSize);
21802181
}
21812182
};
21822183

2184+
typedef Attribute UserAttribute;
2185+
21832186
struct TypeReflection
21842187
{
21852188
enum class Kind
@@ -2320,13 +2323,15 @@ struct TypeReflection
23202323
return (UserAttribute*)spReflectionType_GetUserAttribute((SlangReflectionType*)this, index);
23212324
}
23222325

2323-
UserAttribute* findUserAttributeByName(char const* name)
2326+
UserAttribute* findAttributeByName(char const* name)
23242327
{
23252328
return (UserAttribute*)spReflectionType_FindUserAttributeByName(
23262329
(SlangReflectionType*)this,
23272330
name);
23282331
}
23292332

2333+
UserAttribute* findUserAttributeByName(char const* name) { return findAttributeByName(name); }
2334+
23302335
TypeReflection* applySpecializations(GenericReflection* generic)
23312336
{
23322337
return (TypeReflection*)spReflectionType_applySpecializations(
@@ -2777,21 +2782,26 @@ struct VariableReflection
27772782
return spReflectionVariable_GetUserAttributeCount((SlangReflectionVariable*)this);
27782783
}
27792784

2780-
UserAttribute* getUserAttributeByIndex(unsigned int index)
2785+
Attribute* getUserAttributeByIndex(unsigned int index)
27812786
{
27822787
return (UserAttribute*)spReflectionVariable_GetUserAttribute(
27832788
(SlangReflectionVariable*)this,
27842789
index);
27852790
}
27862791

2787-
UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
2792+
Attribute* findAttributeByName(SlangSession* globalSession, char const* name)
27882793
{
27892794
return (UserAttribute*)spReflectionVariable_FindUserAttributeByName(
27902795
(SlangReflectionVariable*)this,
27912796
globalSession,
27922797
name);
27932798
}
27942799

2800+
Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
2801+
{
2802+
return findAttributeByName(globalSession, name);
2803+
}
2804+
27952805
bool hasDefaultValue()
27962806
{
27972807
return spReflectionVariable_HasDefaultValue((SlangReflectionVariable*)this);
@@ -2908,20 +2918,22 @@ struct FunctionReflection
29082918
{
29092919
return spReflectionFunction_GetUserAttributeCount((SlangReflectionFunction*)this);
29102920
}
2911-
UserAttribute* getUserAttributeByIndex(unsigned int index)
2921+
Attribute* getUserAttributeByIndex(unsigned int index)
29122922
{
2913-
return (UserAttribute*)spReflectionFunction_GetUserAttribute(
2914-
(SlangReflectionFunction*)this,
2915-
index);
2923+
return (
2924+
Attribute*)spReflectionFunction_GetUserAttribute((SlangReflectionFunction*)this, index);
29162925
}
2917-
UserAttribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
2926+
Attribute* findAttributeByName(SlangSession* globalSession, char const* name)
29182927
{
2919-
return (UserAttribute*)spReflectionFunction_FindUserAttributeByName(
2928+
return (Attribute*)spReflectionFunction_FindUserAttributeByName(
29202929
(SlangReflectionFunction*)this,
29212930
globalSession,
29222931
name);
29232932
}
2924-
2933+
Attribute* findUserAttributeByName(SlangSession* globalSession, char const* name)
2934+
{
2935+
return findAttributeByName(globalSession, name);
2936+
}
29252937
Modifier* findModifier(Modifier::ID id)
29262938
{
29272939
return (Modifier*)spReflectionFunction_FindModifier(

source/slang/slang-check-modifier.cpp

+6-9
Original file line numberDiff line numberDiff line change
@@ -752,18 +752,15 @@ Modifier* SemanticsVisitor::validateAttribute(
752752
{
753753
auto& arg = attr->args[paramIndex];
754754
bool typeChecked = false;
755-
if (auto basicType = as<BasicExpressionType>(paramDecl->getType()))
755+
if (isValidCompileTimeConstantType(paramDecl->getType()))
756756
{
757-
if (basicType->getBaseType() == BaseType::Int)
757+
if (auto cint = checkConstantIntVal(arg))
758758
{
759-
if (auto cint = checkConstantIntVal(arg))
760-
{
761-
for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++)
762-
attr->intArgVals.add(nullptr);
763-
attr->intArgVals[(uint32_t)paramIndex] = cint;
764-
}
765-
typeChecked = true;
759+
for (Index ci = attr->intArgVals.getCount(); ci < paramIndex + 1; ci++)
760+
attr->intArgVals.add(nullptr);
761+
attr->intArgVals[(uint32_t)paramIndex] = cint;
766762
}
763+
typeChecked = true;
767764
}
768765
if (!typeChecked)
769766
{

source/slang/slang-reflection-api.cpp

+6-4
Original file line numberDiff line numberDiff line change
@@ -23,11 +23,11 @@ namespace Slang
2323

2424
// Conversion routines to help with strongly-typed reflection API
2525

26-
static inline UserDefinedAttribute* convert(SlangReflectionUserAttribute* attrib)
26+
static inline Attribute* convert(SlangReflectionUserAttribute* attrib)
2727
{
28-
return (UserDefinedAttribute*)attrib;
28+
return (Attribute*)attrib;
2929
}
30-
static inline SlangReflectionUserAttribute* convert(UserDefinedAttribute* attrib)
30+
static inline SlangReflectionUserAttribute* convert(Attribute* attrib)
3131
{
3232
return (SlangReflectionUserAttribute*)attrib;
3333
}
@@ -154,7 +154,9 @@ static SlangReflectionUserAttribute* findUserAttributeByName(
154154
const char* name)
155155
{
156156
auto nameObj = session->tryGetNameObj(name);
157-
for (auto x : decl->getModifiersOfType<UserDefinedAttribute>())
157+
if (!nameObj)
158+
return nullptr;
159+
for (auto x : decl->getModifiersOfType<Attribute>())
158160
{
159161
if (x->keywordName == nameObj)
160162
return (SlangReflectionUserAttribute*)(x);

tools/gfx/d3d12/d3d12-shader-object-layout.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@ bool ShaderObjectLayoutImpl::isBindingRangeRootParameter(
3939
{
4040
if (auto leafVariable = typeLayout->getBindingRangeLeafVariable(bindingRangeIndex))
4141
{
42-
if (leafVariable->findUserAttributeByName(globalSession, rootParameterAttributeName))
42+
if (leafVariable->findAttributeByName(globalSession, rootParameterAttributeName))
4343
{
4444
isRootParameter = true;
4545
}
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,79 @@
1+
// unit-test-translation-unit-import.cpp
2+
3+
#include "../../source/core/slang-io.h"
4+
#include "../../source/core/slang-process.h"
5+
#include "slang-com-ptr.h"
6+
#include "slang.h"
7+
#include "unit-test/slang-unit-test.h"
8+
9+
#include <stdio.h>
10+
#include <stdlib.h>
11+
12+
using namespace Slang;
13+
14+
// Test that the reflection API provides correct info about attributes.
15+
16+
SLANG_UNIT_TEST(attributeReflection)
17+
{
18+
const char* userSourceBody = R"(
19+
public enum E
20+
{
21+
V0,
22+
V1,
23+
};
24+
25+
[__AttributeUsage(_AttributeTargets.Struct)]
26+
public struct NormalTextureAttribute
27+
{
28+
public E Type;
29+
};
30+
31+
[COM("042BE50B-CB01-4DBB-8367-3A9CDCBE2F49")]
32+
interface IInterface { void f(); }
33+
34+
[NormalTexture(E.V1)]
35+
struct TS {};
36+
)";
37+
String userSource = userSourceBody;
38+
ComPtr<slang::IGlobalSession> globalSession;
39+
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
40+
slang::TargetDesc targetDesc = {};
41+
targetDesc.format = SLANG_HLSL;
42+
targetDesc.profile = globalSession->findProfile("sm_5_0");
43+
slang::SessionDesc sessionDesc = {};
44+
sessionDesc.targetCount = 1;
45+
sessionDesc.targets = &targetDesc;
46+
ComPtr<slang::ISession> session;
47+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
48+
49+
ComPtr<slang::IBlob> diagnosticBlob;
50+
auto module = session->loadModuleFromSourceString(
51+
"m",
52+
"m.slang",
53+
userSourceBody,
54+
diagnosticBlob.writeRef());
55+
SLANG_CHECK(module != nullptr);
56+
57+
auto reflection = module->getLayout();
58+
59+
auto interfaceType = reflection->findTypeByName("IInterface");
60+
SLANG_CHECK(interfaceType != nullptr);
61+
62+
auto comAttribute = interfaceType->findAttributeByName("COM");
63+
SLANG_CHECK(comAttribute != nullptr);
64+
65+
size_t size = 0;
66+
auto guid = comAttribute->getArgumentValueString(0, &size);
67+
UnownedStringSlice stringSlice = UnownedStringSlice(guid, size);
68+
SLANG_CHECK(stringSlice == "\"042BE50B-CB01-4DBB-8367-3A9CDCBE2F49\"");
69+
70+
auto testType = reflection->findTypeByName("TS");
71+
SLANG_CHECK(testType != nullptr);
72+
73+
auto normalTextureAttribute = testType->findAttributeByName("NormalTexture");
74+
SLANG_CHECK(normalTextureAttribute != nullptr);
75+
76+
int value = 0;
77+
normalTextureAttribute->getArgumentValueInt(0, &value);
78+
SLANG_CHECK(value == 1);
79+
}

tools/slang-unit-test/unit-test-decl-tree-reflection.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -178,7 +178,7 @@ SLANG_UNIT_TEST(declTreeReflection)
178178
SLANG_CHECK(result == SLANG_OK);
179179
SLANG_CHECK(val == 1024);
180180
SLANG_CHECK(
181-
funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") ==
181+
funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") ==
182182
userAttribute);
183183
}
184184

tools/slang-unit-test/unit-test-function-reflection.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -108,7 +108,7 @@ SLANG_UNIT_TEST(functionReflection)
108108
SLANG_CHECK(result == SLANG_OK);
109109
SLANG_CHECK(val == 1024);
110110
SLANG_CHECK(
111-
funcReflection->findUserAttributeByName(globalSession.get(), "MyFuncProperty") ==
111+
funcReflection->findAttributeByName(globalSession.get(), "MyFuncProperty") ==
112112
userAttribute);
113113

114114
// Check overloaded method resolution

0 commit comments

Comments
 (0)