@@ -1910,8 +1910,22 @@ void GLSLReplaceAtomicUint(IRSpecContext* context, TargetProgram* targetProgram,
1910
1910
convertAtomicToStorageBuffer (context, bindingToInstMapUnsorted);
1911
1911
}
1912
1912
1913
+ bool isDiffPairType (IRInst* type)
1914
+ {
1915
+ for (;;)
1916
+ {
1917
+ auto type1 = (IRType*)unwrapAttributedType (type);
1918
+ auto type2 = unwrapArray (type1);
1919
+ if (type2 == type)
1920
+ break ;
1921
+ type = type2;
1922
+ }
1923
+ return as<IRDifferentialPairTypeBase>(type) != nullptr ;
1924
+ }
1925
+
1913
1926
bool doesModuleUseAutodiff (IRInst* inst)
1914
1927
{
1928
+ return false ;
1915
1929
switch (inst->getOp ())
1916
1930
{
1917
1931
case kIROp_Call :
@@ -1929,11 +1943,13 @@ bool doesModuleUseAutodiff(IRInst* inst)
1929
1943
return false ;
1930
1944
case kIROp_DifferentialPairGetDifferentialUserCode :
1931
1945
case kIROp_DifferentialPairGetPrimalUserCode :
1932
- case kIROp_DifferentialPairUserCodeType :
1933
- case kIROp_DifferentialPtrPairType :
1934
1946
case kIROp_DifferentialPtrPairGetPrimal :
1935
1947
case kIROp_DifferentialPtrPairGetDifferential :
1936
1948
return true ;
1949
+ case kIROp_StructField :
1950
+ return isDiffPairType (as<IRStructField>(inst)->getFieldType ());
1951
+ case kIROp_Param :
1952
+ return isDiffPairType (inst->getDataType ());
1937
1953
default :
1938
1954
for (auto child : inst->getChildren ())
1939
1955
{
@@ -2054,13 +2070,13 @@ LinkedIR linkIR(CodeGenContext* codeGenContext)
2054
2070
irModules.addRange (builtinModules);
2055
2071
ArrayView<IRModule*> userModules = irModules.getArrayView (0 , userModuleCount);
2056
2072
2073
+ // Check if any user module uses auto-diff, if so we will need to link
2074
+ // additional witnesses and decorations.
2057
2075
for (IRModule* irModule : userModules)
2058
2076
{
2059
- // Check if the user module uses auto-diff.
2060
- if (!sharedContext->useAutodiff )
2061
- {
2062
- sharedContext->useAutodiff = doesModuleUseAutodiff (irModule->getModuleInst ());
2063
- }
2077
+ if (sharedContext->useAutodiff )
2078
+ break ;
2079
+ sharedContext->useAutodiff = doesModuleUseAutodiff (irModule->getModuleInst ());
2064
2080
}
2065
2081
2066
2082
auto context = state->getContext ();
0 commit comments