diff --git a/source/slang/slang-ir-specialize.cpp b/source/slang/slang-ir-specialize.cpp index 40cd40758a..a9b0d44121 100644 --- a/source/slang/slang-ir-specialize.cpp +++ b/source/slang/slang-ir-specialize.cpp @@ -71,6 +71,42 @@ struct SpecializationContext module->getContainerPool().free(&cleanInsts); } + bool isUnsimplifiedArithmeticInst(IRInst* inst) + { + switch (inst->getOp()) + { + case kIROp_Add: + case kIROp_Sub: + case kIROp_Mul: + case kIROp_Div: + case kIROp_Neg: + case kIROp_Not: + case kIROp_Eql: + case kIROp_Neq: + case kIROp_Leq: + case kIROp_Geq: + case kIROp_Less: + case kIROp_IRem: + case kIROp_FRem: + case kIROp_Greater: + case kIROp_Lsh: + case kIROp_Rsh: + case kIROp_BitAnd: + case kIROp_BitOr: + case kIROp_BitXor: + case kIROp_BitNot: + case kIROp_BitCast: + case kIROp_CastIntToFloat: + case kIROp_CastFloatToInt: + case kIROp_IntCast: + case kIROp_FloatCast: + case kIROp_Select: + return true; + default: + return false; + } + } + // An instruction is then fully specialized if and only // if it is in our set. // @@ -133,6 +169,14 @@ struct SpecializationContext return areAllOperandsFullySpecialized(inst); } + if (isUnsimplifiedArithmeticInst(inst)) + { + // For arithmetic insts, we want to wait for simplification before specialization, + // since different insts can simplify to the same value. + // + return false; + } + // The default case is that a global value is always specialized. if (inst->getParent() == module->getModuleInst()) { @@ -1092,6 +1136,7 @@ struct SpecializationContext { this->changed = true; eliminateDeadCode(module->getModuleInst()); + applySparseConditionalConstantPropagationForGlobalScope(this->module, this->sink); } // Once the work list has gone dry, we should have the invariant diff --git a/source/slang/slang.cpp b/source/slang/slang.cpp index 5ec1996581..efc1c6fd11 100644 --- a/source/slang/slang.cpp +++ b/source/slang/slang.cpp @@ -1493,7 +1493,8 @@ DeclRef Linkage::specializeWithArgTypes( DiagnosticSink* sink) { SemanticsVisitor visitor(getSemanticsForReflection()); - visitor = visitor.withSink(sink); + SemanticsVisitor::ExprLocalScope scope; + visitor = visitor.withSink(sink).withExprLocalScope(&scope); SLANG_AST_BUILDER_RAII(getASTBuilder());