Skip to content

Commit 0accc32

Browse files
committed
Make autodiff detection more accurate.
1 parent f427025 commit 0accc32

File tree

2 files changed

+24
-8
lines changed

2 files changed

+24
-8
lines changed

source/slang/slang-ir-link.cpp

+23-7
Original file line numberDiff line numberDiff line change
@@ -1910,8 +1910,22 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
19101910
convertAtomicToStorageBuffer(context, bindingToInstMapUnsorted);
19111911
}
19121912

1913+
bool isDiffPairType(IRInst* type)
1914+
{
1915+
for (;;)
1916+
{
1917+
auto type1 = (IRType*)unwrapAttributedType(type);
1918+
auto type2 = unwrapArray(type1);
1919+
if (type2 == type)
1920+
break;
1921+
type = type2;
1922+
}
1923+
return as<IRDifferentialPairTypeBase>(type) != nullptr;
1924+
}
1925+
19131926
bool doesModuleUseAutodiff(IRInst* inst)
19141927
{
1928+
return false;
19151929
switch (inst->getOp())
19161930
{
19171931
case kIROp_Call:
@@ -1929,11 +1943,13 @@ bool doesModuleUseAutodiff(IRInst* inst)
19291943
return false;
19301944
case kIROp_DifferentialPairGetDifferentialUserCode:
19311945
case kIROp_DifferentialPairGetPrimalUserCode:
1932-
case kIROp_DifferentialPairUserCodeType:
1933-
case kIROp_DifferentialPtrPairType:
19341946
case kIROp_DifferentialPtrPairGetPrimal:
19351947
case kIROp_DifferentialPtrPairGetDifferential:
19361948
return true;
1949+
case kIROp_StructField:
1950+
return isDiffPairType(as<IRStructField>(inst)->getFieldType());
1951+
case kIROp_Param:
1952+
return isDiffPairType(inst->getDataType());
19371953
default:
19381954
for (auto child : inst->getChildren())
19391955
{
@@ -2054,13 +2070,13 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
20542070
irModules.addRange(builtinModules);
20552071
ArrayView<IRModule*> userModules = irModules.getArrayView(0, userModuleCount);
20562072

2073+
// Check if any user module uses auto-diff, if so we will need to link
2074+
// additional witnesses and decorations.
20572075
for (IRModule* irModule : userModules)
20582076
{
2059-
// Check if the user module uses auto-diff.
2060-
if (!sharedContext->useAutodiff)
2061-
{
2062-
sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst());
2063-
}
2077+
if (sharedContext->useAutodiff)
2078+
break;
2079+
sharedContext->useAutodiff = doesModuleUseAutodiff(irModule->getModuleInst());
20642080
}
20652081

20662082
auto context = state->getContext();

tools/slang-unit-test/unit-test-compile-benchmark.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ void main(uint3 threadIdx : SV_DispatchThreadID)
7474
sessionDesc.targets = &targetDesc;
7575

7676
auto start = platform::PerformanceCounter::now();
77-
for (int pass = 0; pass < 100; pass++)
77+
for (int pass = 0; pass < 1000; pass++)
7878
{
7979
ComPtr<slang::ISession> session;
8080
SLANG_CHECK(globalSession->createSession(sessionDesc, session.writeRef()) == SLANG_OK);

0 commit comments

Comments
 (0)