|
| 1 | +//===----------------------------------------------------------------------===// |
| 2 | +// |
| 3 | +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. |
| 4 | +// See https://llvm.org/LICENSE.txt for license information. |
| 5 | +// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception |
| 6 | +// |
| 7 | +//===----------------------------------------------------------------------===// |
| 8 | +// |
| 9 | +// This pass performs structural hashing for Synth dialect operations |
| 10 | +// (AIG/MIG). Unlike MLIR's general CSE pass, this is domain-specific to |
| 11 | +// AIG/MIG operations, allowing it to reorder operands based on their |
| 12 | +// structural properties and take inversion flags into account for |
| 13 | +// canonicalization. |
| 14 | +// |
| 15 | +//===----------------------------------------------------------------------===// |
| 16 | + |
| 17 | +#include "circt/Dialect/HW/HWOps.h" |
| 18 | +#include "circt/Dialect/Synth/SynthOps.h" |
| 19 | +#include "circt/Dialect/Synth/Transforms/SynthPasses.h" |
| 20 | +#include "circt/Support/Naming.h" |
| 21 | +#include "circt/Support/UnusedOpPruner.h" |
| 22 | +#include "mlir/Analysis/TopologicalSortUtils.h" |
| 23 | +#include "mlir/IR/BuiltinAttributes.h" |
| 24 | +#include "mlir/IR/Operation.h" |
| 25 | +#include "mlir/IR/Visitors.h" |
| 26 | +#include "mlir/Support/LLVM.h" |
| 27 | +#include "llvm/ADT/DenseMap.h" |
| 28 | +#include "llvm/ADT/DenseMapInfo.h" |
| 29 | +#include "llvm/ADT/PointerIntPair.h" |
| 30 | +#include "llvm/Support/DebugLog.h" |
| 31 | +#include "llvm/Support/LogicalResult.h" |
| 32 | + |
| 33 | +#define DEBUG_TYPE "synth-structural-hash" |
| 34 | + |
| 35 | +namespace circt { |
| 36 | +namespace synth { |
| 37 | +#define GEN_PASS_DEF_STRUCTURALHASH |
| 38 | +#include "circt/Dialect/Synth/Transforms/SynthPasses.h.inc" |
| 39 | +} // namespace synth |
| 40 | +} // namespace circt |
| 41 | + |
| 42 | +using namespace circt; |
| 43 | +using namespace circt::synth; |
| 44 | + |
| 45 | +/// A struct that represents the key used for structural hashing. It contains |
| 46 | +/// the operation name and a sorted vector of pointer-integer pairs, which |
| 47 | +/// represent the inputs to the operation and their inversion status. |
| 48 | +/// This key is used to identify structurally equivalent operations for CSE. |
| 49 | +struct StructuralHashKey { |
| 50 | + OperationName opName; |
| 51 | + llvm::SmallVector<llvm::PointerIntPair<Value, 1>, 3> operandPairs; |
| 52 | + |
| 53 | + /// Constructor. |
| 54 | + StructuralHashKey(OperationName name, |
| 55 | + llvm::SmallVector<llvm::PointerIntPair<Value, 1>, 3> inps) |
| 56 | + : opName(name), operandPairs(std::move(inps)) {} |
| 57 | +}; |
| 58 | + |
| 59 | +// DenseMapInfo specialization for StructuralHashKey |
| 60 | +template <> |
| 61 | +struct llvm::DenseMapInfo<StructuralHashKey> { |
| 62 | + static StructuralHashKey getEmptyKey() { |
| 63 | + return StructuralHashKey(llvm::DenseMapInfo<OperationName>::getEmptyKey(), |
| 64 | + {}); |
| 65 | + } |
| 66 | + |
| 67 | + static StructuralHashKey getTombstoneKey() { |
| 68 | + return StructuralHashKey( |
| 69 | + llvm::DenseMapInfo<OperationName>::getTombstoneKey(), {}); |
| 70 | + } |
| 71 | + |
| 72 | + static unsigned getHashValue(const StructuralHashKey &key) { |
| 73 | + auto hash = hash_value(key.opName); |
| 74 | + for (const auto &operand : key.operandPairs) |
| 75 | + hash = llvm::hash_combine(hash, operand.getOpaqueValue()); |
| 76 | + return static_cast<unsigned>(hash); |
| 77 | + } |
| 78 | + |
| 79 | + static bool isEqual(const StructuralHashKey &lhs, |
| 80 | + const StructuralHashKey &rhs) { |
| 81 | + return llvm::DenseMapInfo<OperationName>::isEqual(lhs.opName, rhs.opName) && |
| 82 | + lhs.operandPairs == rhs.operandPairs; |
| 83 | + } |
| 84 | +}; |
| 85 | + |
| 86 | +namespace { |
| 87 | +/// Pass definition. |
| 88 | +struct StructuralHashPass |
| 89 | + : public impl::StructuralHashBase<StructuralHashPass> { |
| 90 | + void runOnOperation() override; |
| 91 | +}; |
| 92 | +} // namespace |
| 93 | + |
| 94 | +namespace { |
| 95 | +/// The main driver class that implements the structural hashing algorithm. |
| 96 | +/// This class manages the state for value numbering, inversion tracking, |
| 97 | +/// and the hash table for CSE. It processes operations in topological order |
| 98 | +/// and performs operand reordering and inversion propagation for |
| 99 | +/// canonicalization. |
| 100 | +class StructuralHashDriver { |
| 101 | +public: |
| 102 | + StructuralHashDriver() = default; |
| 103 | + void visitOp(Operation *op, ArrayRef<bool> inverted); |
| 104 | + void visitUnaryOp(Operation *op, bool inverted); |
| 105 | + void visitVariadicOp(Operation *op, ArrayRef<bool> inverted); |
| 106 | + uint64_t getNumber(Value v); |
| 107 | + |
| 108 | + /// Runs the structural hashing pass on the given module. |
| 109 | + /// Performs topological sorting, assigns value numbers to arguments, |
| 110 | + /// processes target operations, and cleans up unused operations. |
| 111 | + llvm::LogicalResult run(hw::HWModuleOp op); |
| 112 | + |
| 113 | +private: |
| 114 | + /// Maps values to unique numbers for deterministic operand sorting. |
| 115 | + DenseMap<Value, uint64_t> valueNumber; |
| 116 | + uint64_t constantCounter = 0; |
| 117 | + |
| 118 | + /// Pruner for managing unused operations that may be erased later. |
| 119 | + circt::UnusedOpPruner pruner; |
| 120 | + |
| 121 | + /// Hash table mapping structural keys to canonical operations for CSE. |
| 122 | + DenseMap<StructuralHashKey, Operation *> hashTable; |
| 123 | + |
| 124 | + /// Maps inverted values to their non-inverted equivalents for propagation. |
| 125 | + /// For example, if we have: |
| 126 | + /// ``` |
| 127 | + /// %b = synth.aig.and_inv not %a |
| 128 | + /// %c = synth.aig.and_inv not %b |
| 129 | + /// ``` |
| 130 | + /// Then `inversion[%b] = %a`, and when visiting `%c`, we can query |
| 131 | + /// `inversion[%b]` to directly obtain `%a`. |
| 132 | + DenseMap<Value, Value> inversion; |
| 133 | +}; |
| 134 | +} // namespace |
| 135 | + |
| 136 | +void StructuralHashDriver::visitOp(Operation *op, ArrayRef<bool> inverted) { |
| 137 | + /// Dispatches to the appropriate visitor based on the number of operands. |
| 138 | + /// For unary operations, calls visitUnaryOp; for variadic operations, |
| 139 | + /// calls visitVariadicOp. |
| 140 | + if (op->getNumOperands() == 1) { |
| 141 | + visitUnaryOp(op, inverted[0]); |
| 142 | + return; |
| 143 | + } |
| 144 | + visitVariadicOp(op, inverted); |
| 145 | +} |
| 146 | + |
| 147 | +/// Handles unary operations (single operand). |
| 148 | +/// If not inverted, replaces the operation with its operand. |
| 149 | +/// If inverted, attempts to propagate inversion through the inversion map |
| 150 | +/// or records the inversion for later propagation. |
| 151 | +void StructuralHashDriver::visitUnaryOp(Operation *op, bool inverted) { |
| 152 | + if (!inverted) { |
| 153 | + op->replaceAllUsesWith(ArrayRef<Value>{op->getOperand(0)}); |
| 154 | + op->erase(); |
| 155 | + return; |
| 156 | + } |
| 157 | + // Check if we can propagate inversion through the inversion map. |
| 158 | + auto operand = op->getOperand(0); |
| 159 | + auto it = inversion.find(operand); |
| 160 | + if (it != inversion.end()) { |
| 161 | + // Found, replace the operand with the mapped value |
| 162 | + op->replaceAllUsesWith(ArrayRef<Value>{it->second}); |
| 163 | + op->erase(); |
| 164 | + } else { |
| 165 | + // Not found, insert into the map |
| 166 | + inversion[op->getResult(0)] = operand; |
| 167 | + pruner.eraseLaterIfUnused(op); |
| 168 | + } |
| 169 | +} |
| 170 | + |
| 171 | +/// Computes a structural hash key, sorts operands for canonicalization, |
| 172 | +/// and performs CSE by checking the hash table for equivalent operations. |
| 173 | +void StructuralHashDriver::visitVariadicOp(Operation *op, |
| 174 | + ArrayRef<bool> inverted) { |
| 175 | + |
| 176 | + // Compute the structural hash key for the operation. |
| 177 | + StructuralHashKey key(op->getName(), {}); |
| 178 | + for (auto [input, inverted] : llvm::zip(op->getOperands(), inverted)) { |
| 179 | + bool isInverted = inverted; |
| 180 | + // Check if we can propagate inversion through the inversion map |
| 181 | + auto it = inversion.find(input); |
| 182 | + if (it != inversion.end()) { |
| 183 | + // Found, use the mapped value and flip the inversion status |
| 184 | + input = it->second; |
| 185 | + isInverted = !isInverted; |
| 186 | + } |
| 187 | + |
| 188 | + key.operandPairs.push_back( |
| 189 | + llvm::PointerIntPair<Value, 1>(input, isInverted)); |
| 190 | + // Ensure the operand has a number assigned, otherwise sorting might be |
| 191 | + // non-deterministic. |
| 192 | + (void)getNumber(input); |
| 193 | + } |
| 194 | + |
| 195 | + // Sort operands based on their assigned numbers. |
| 196 | + llvm::sort(key.operandPairs, [&](auto a, auto b) { |
| 197 | + size_t aNum = getNumber(a.getPointer()); |
| 198 | + size_t bNum = getNumber(b.getPointer()); |
| 199 | + if (aNum != bNum) |
| 200 | + return aNum < bNum; |
| 201 | + return a.getInt() < b.getInt(); |
| 202 | + }); |
| 203 | + |
| 204 | + // Insert the key into the hash table. |
| 205 | + auto [it, inserted] = hashTable.try_emplace(key, op); |
| 206 | + if (inserted) { |
| 207 | + // New entry, keep the operation and sort its operands. |
| 208 | + op->setOperands(llvm::to_vector<3>(llvm::map_range( |
| 209 | + key.operandPairs, [](auto p) { return p.getPointer(); }))); |
| 210 | + SmallVector<bool, 3> newInversion( |
| 211 | + llvm::map_range(key.operandPairs, [](auto p) { return p.getInt(); })); |
| 212 | + op->setAttr("inverted", |
| 213 | + mlir::DenseBoolArrayAttr::get(op->getContext(), newInversion)); |
| 214 | + // Assign a number to the result for future sorting. |
| 215 | + (void)getNumber(op->getResult(0)); |
| 216 | + } else { |
| 217 | + LDBG() << "Structural Hash: Replacing " << *op << " with " << *(it->second) |
| 218 | + << "\n"; |
| 219 | + // Existing entry, replace all uses and erase the operation. |
| 220 | + // Propagate namehints. |
| 221 | + auto name = circt::chooseName(op, it->second); |
| 222 | + if (name && !it->second->hasAttr("sv.namehint")) |
| 223 | + it->second->setAttr("sv.namehint", name); |
| 224 | + op->replaceAllUsesWith(it->second); |
| 225 | + op->erase(); |
| 226 | + } |
| 227 | +} |
| 228 | + |
| 229 | +/// Assigns or retrieves a unique number for a value. Used for deterministic |
| 230 | +/// operand sorting. |
| 231 | +uint64_t StructuralHashDriver::getNumber(Value v) { |
| 232 | + auto it = valueNumber.find(v); |
| 233 | + if (it != valueNumber.end()) |
| 234 | + return it->second; |
| 235 | + |
| 236 | + // Assign a new number. Constants get high numbers to make constants are |
| 237 | + // pushed to the back. |
| 238 | + if (auto *op = v.getDefiningOp(); |
| 239 | + op && op->hasTrait<mlir::OpTrait::ConstantLike>()) { |
| 240 | + auto [it, inserted] = valueNumber.try_emplace( |
| 241 | + v, std::numeric_limits<uint64_t>::max() - constantCounter++); |
| 242 | + return it->second; |
| 243 | + } |
| 244 | + |
| 245 | + return valueNumber.try_emplace(v, valueNumber.size() - constantCounter) |
| 246 | + .first->second; |
| 247 | +} |
| 248 | + |
| 249 | +llvm::LogicalResult StructuralHashDriver::run(hw::HWModuleOp moduleOp) { |
| 250 | + auto isOperationReady = [&](Value value, Operation *op) -> bool { |
| 251 | + // Otherthan target ops, all other ops are always ready. |
| 252 | + return !isa<circt::synth::aig::AndInverterOp, |
| 253 | + circt::synth::mig::MajorityInverterOp>(op); |
| 254 | + }; |
| 255 | + |
| 256 | + if (!mlir::sortTopologically(moduleOp.getBodyBlock(), isOperationReady)) |
| 257 | + return failure(); |
| 258 | + |
| 259 | + for (auto arg : moduleOp.getBodyBlock()->getArguments()) |
| 260 | + (void)getNumber(arg); |
| 261 | + |
| 262 | + // Process target ops. |
| 263 | + // NOTE: Don't use walk here since the pass currently doesn't handle nested |
| 264 | + // regions. |
| 265 | + for (auto &op : |
| 266 | + llvm::make_early_inc_range(moduleOp.getBodyBlock()->getOperations())) { |
| 267 | + mlir::TypeSwitch<Operation *>(&op) |
| 268 | + .Case<circt::synth::aig::AndInverterOp, |
| 269 | + circt::synth::mig::MajorityInverterOp>([&](auto invertibleOp) { |
| 270 | + visitOp(invertibleOp, invertibleOp.getInverted()); |
| 271 | + }) |
| 272 | + .Default([&](Operation *op) {}); |
| 273 | + } |
| 274 | + |
| 275 | + pruner.eraseNow(); |
| 276 | + return mlir::success(); |
| 277 | +} |
| 278 | + |
| 279 | +void StructuralHashPass::runOnOperation() { |
| 280 | + auto topOp = getOperation(); |
| 281 | + StructuralHashDriver driver; |
| 282 | + if (failed(driver.run(topOp))) |
| 283 | + return signalPassFailure(); |
| 284 | +} |
0 commit comments