Skip to content

Commit 114c976

Browse files
authored
Create DirectDeclRef when creating Decl to prevent invalid dedup. (shader-slang#5945)
* Create DirectDeclRef when creating Decl to prevent invalid dedup. * Fix test. * fix * update slang-rhi
1 parent 5df3a74 commit 114c976

File tree

6 files changed

+100
-13
lines changed

6 files changed

+100
-13
lines changed

source/slang/slang-ast-base.cpp

+1-9
Original file line numberDiff line numberDiff line change
@@ -23,15 +23,7 @@ void NodeBase::_initDebug(ASTNodeType inAstNodeType, ASTBuilder* inAstBuilder)
2323
}
2424
DeclRefBase* Decl::getDefaultDeclRef()
2525
{
26-
if (auto astBuilder = getCurrentASTBuilder())
27-
{
28-
const Index currentEpoch = astBuilder->getEpoch();
29-
if (currentEpoch != m_defaultDeclRefEpoch || !m_defaultDeclRef)
30-
{
31-
m_defaultDeclRef = astBuilder->getOrCreate<DirectDeclRef>(this);
32-
m_defaultDeclRefEpoch = currentEpoch;
33-
}
34-
}
26+
SLANG_ASSERT(m_defaultDeclRef);
3527
return m_defaultDeclRef;
3628
}
3729

source/slang/slang-ast-base.h

-1
Original file line numberDiff line numberDiff line change
@@ -793,7 +793,6 @@ class Decl : public DeclBase
793793

794794
private:
795795
SLANG_UNREFLECTED DeclRefBase* m_defaultDeclRef = nullptr;
796-
SLANG_UNREFLECTED Index m_defaultDeclRefEpoch = -1;
797796
};
798797

799798
class Expr : public SyntaxNode

source/slang/slang-ast-builder.h

+4-1
Original file line numberDiff line numberDiff line change
@@ -693,7 +693,10 @@ class ASTBuilder : public RefObject
693693
auto val = (Val*)(node);
694694
val->m_resolvedValEpoch = getEpoch();
695695
}
696-
696+
else if (node->getClassInfo().isSubClassOf(*ASTClassInfo::getInfo(Decl::kType)))
697+
{
698+
((Decl*)node)->m_defaultDeclRef = getOrCreate<DirectDeclRef>((Decl*)node);
699+
}
697700
return node;
698701
}
699702

source/slang/slang.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -4598,7 +4598,7 @@ void Module::_processFindDeclsExportSymbolsRec(Decl* decl)
45984598
if (_canExportDeclSymbol(decl->astNodeType))
45994599
{
46004600
// It's a reference to a declaration in another module, so first get the symbol name.
4601-
String mangledName = getMangledName(getASTBuilder(), decl);
4601+
String mangledName = getMangledName(getCurrentASTBuilder(), decl);
46024602

46034603
Index index = Index(m_mangledExportPool.add(mangledName));
46044604

source/slang/slang.natvis

+2-1
Original file line numberDiff line numberDiff line change
@@ -426,7 +426,8 @@
426426
<Type Name="Slang::Val" Inheritable="true">
427427
<DisplayString Optional="true" Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType#{_debugUID} {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
428428
<DisplayString Condition="astNodeType == Slang::ASTNodeType::DeclRefType">DeclRefType {*(Val*)(((Slang::DeclRefType*)this)->m_operands.m_buffer[0].values.nodeOperand)}</DisplayString>
429-
<DisplayString Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">DirectRef {*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
429+
<DisplayString Optional="true" Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">DirectRef#{_debugUID} {*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
430+
<DisplayString Condition="astNodeType == Slang::ASTNodeType::DirectDeclRef">DirectRef {*(Decl*)m_operands.m_buffer[0].values.nodeOperand}</DisplayString>
430431
<DisplayString Optional="true">{astNodeType,en} #{_debugUID}</DisplayString>
431432
<DisplayString>{astNodeType,en}</DisplayString>
432433

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,92 @@
1+
// unit-test-module-ptr.cpp
2+
3+
#include "core/slang-memory-file-system.h"
4+
#include "slang-com-ptr.h"
5+
#include "slang.h"
6+
#include "unit-test/slang-unit-test.h"
7+
8+
#include <stdio.h>
9+
#include <stdlib.h>
10+
11+
using namespace Slang;
12+
13+
SLANG_UNIT_TEST(modulePtr)
14+
{
15+
const char* testModuleSource = R"(
16+
module test_module;
17+
18+
public void atomicFunc(__ref Atomic<int> ptr) {
19+
ptr.add(1);
20+
}
21+
)";
22+
23+
const char* testSource = R"(
24+
import "test_module";
25+
26+
RWStructuredBuffer<Atomic<int>> input0;
27+
28+
[shader("compute")]
29+
[numthreads(1,1,1)]
30+
void computeMain(uint3 workGroup : SV_GroupID)
31+
{
32+
atomicFunc(input0[0]);
33+
}
34+
)";
35+
ComPtr<ISlangMutableFileSystem> memoryFileSystem =
36+
ComPtr<ISlangMutableFileSystem>(new Slang::MemoryFileSystem());
37+
38+
ComPtr<slang::IGlobalSession> globalSession;
39+
SLANG_CHECK(slang_createGlobalSession(SLANG_API_VERSION, globalSession.writeRef()) == SLANG_OK);
40+
slang::TargetDesc targetDesc = {};
41+
targetDesc.format = SLANG_SPIRV;
42+
targetDesc.profile = globalSession->findProfile("spirv_1_5");
43+
slang::SessionDesc sessionDesc = {};
44+
sessionDesc.targetCount = 1;
45+
sessionDesc.targets = &targetDesc;
46+
sessionDesc.compilerOptionEntryCount = 0;
47+
sessionDesc.fileSystem = memoryFileSystem;
48+
49+
// Precompile test_module to file.
50+
{
51+
ComPtr<slang::ISession> session;
52+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
53+
54+
ComPtr<slang::IBlob> diagnosticBlob;
55+
auto module = session->loadModuleFromSourceString(
56+
"test_module",
57+
"test_module.slang",
58+
testModuleSource,
59+
diagnosticBlob.writeRef());
60+
SLANG_CHECK(module != nullptr);
61+
62+
ComPtr<slang::IBlob> moduleBlob;
63+
module->serialize(moduleBlob.writeRef());
64+
memoryFileSystem->saveFile(
65+
"test_module.slang-module",
66+
moduleBlob->getBufferPointer(),
67+
moduleBlob->getBufferSize());
68+
}
69+
70+
// compile test.
71+
{
72+
ComPtr<slang::ISession> session;
73+
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);
74+
75+
ComPtr<slang::IBlob> diagnosticBlob;
76+
auto module = session->loadModuleFromSourceString(
77+
"test",
78+
"test.slang",
79+
testSource,
80+
diagnosticBlob.writeRef());
81+
SLANG_CHECK(module != nullptr);
82+
83+
ComPtr<slang::IComponentType> linkedProgram;
84+
module->link(linkedProgram.writeRef());
85+
86+
ComPtr<slang::IBlob> code;
87+
88+
linkedProgram->getTargetCode(0, code.writeRef(), diagnosticBlob.writeRef());
89+
90+
SLANG_CHECK(code->getBufferSize() > 0);
91+
}
92+
}

0 commit comments

Comments
 (0)