27
27
#include " mlir/IR/OpImplementation.h"
28
28
#include " mlir/IR/TypeUtilities.h"
29
29
#include " mlir/Support/LLVM.h"
30
+ #include " llvm/ADT/StringSet.h"
30
31
31
32
using namespace mlir ;
32
33
using namespace mlir ::vector;
@@ -56,7 +57,10 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
56
57
SmallVector<Type, 2 > types;
57
58
Type resultVectorType;
58
59
auto loc = parser.getCurrentLocation ();
59
- if (parser.parseOperand (lhsInfo) || parser.parseComma () ||
60
+ DictionaryAttr dictAttr;
61
+ // TODO(andydavis, ntv) Unify linalg op attribute parsing.
62
+ if (parser.parseAttribute (dictAttr, " _" , result.attributes ) ||
63
+ parser.parseOperand (lhsInfo) || parser.parseComma () ||
60
64
parser.parseOperand (rhsInfo) || parser.parseComma () ||
61
65
parser.parseOperand (accInfo) ||
62
66
parser.parseTrailingOperandList (masksInfo) ||
@@ -68,7 +72,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
68
72
parser.resolveOperand (accInfo, resultVectorType, result.operands ) ||
69
73
parser.addTypeToList (resultVectorType, result.types ))
70
74
return failure ();
71
-
75
+ result.attributes .assign (dictAttr.getValue ().begin (),
76
+ dictAttr.getValue ().end ());
72
77
if (masksInfo.empty ())
73
78
return success ();
74
79
if (masksInfo.size () != 2 )
@@ -90,13 +95,23 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
90
95
}
91
96
92
97
static void print (OpAsmPrinter &p, ContractionOp op) {
93
- p << op.getOperationName () << " " << *op.lhs () << " , " << *op.rhs ();
94
- p << " , " << *op.acc ();
98
+ // TODO(andydavis, ntv) Unify printing code with linalg ops.
99
+ auto attrNames = op.getTraitAttrNames ();
100
+ llvm::StringSet<> traitAttrsSet;
101
+ traitAttrsSet.insert (attrNames.begin (), attrNames.end ());
102
+ SmallVector<NamedAttribute, 8 > attrs;
103
+ for (auto attr : op.getAttrs ()) {
104
+ if (traitAttrsSet.count (attr.first .strref ()) > 0 )
105
+ attrs.push_back (attr);
106
+ }
107
+ auto dictAttr = DictionaryAttr::get (attrs, op.getContext ());
108
+ p << op.getOperationName () << " " << dictAttr << " " << *op.lhs () << " , " ;
109
+ p << *op.rhs () << " , " << *op.acc ();
95
110
if (llvm::size (op.masks ()) == 2 ) {
96
111
p << " , " << **op.masks ().begin ();
97
112
p << " , " << **(op.masks ().begin () + 1 );
98
113
}
99
- p.printOptionalAttrDict (op.getAttrs ());
114
+ p.printOptionalAttrDict (op.getAttrs (), attrNames );
100
115
p << " : " << op.lhs ()->getType () << " , " << op.rhs ()->getType () << " into "
101
116
<< op.getResultType ();
102
117
}
@@ -159,6 +174,34 @@ static LogicalResult verify(ContractionOp op) {
159
174
auto rhsType = op.getRhsType ();
160
175
auto accType = op.getAccType ();
161
176
auto resType = op.getResultType ();
177
+
178
+ // Verify that an indexing map was specified for each vector operand.
179
+ if (op.indexing_maps ().size () != 3 )
180
+ return op.emitOpError (" expected an indexing map for each vector operand" );
181
+
182
+ // Verify that each index map has 'numIterators' inputs, no symbols, and
183
+ // that the number of map outputs equals the rank of its associated
184
+ // vector operand.
185
+ unsigned numIterators = op.iterator_types ().getValue ().size ();
186
+ for (auto it : llvm::enumerate (op.indexing_maps ())) {
187
+ auto index = it.index ();
188
+ auto map = it.value ().cast <AffineMapAttr>().getValue ();
189
+ if (map.getNumSymbols () != 0 )
190
+ return op.emitOpError (" expected indexing map " )
191
+ << index << " to have no symbols" ;
192
+ if (map.getNumDims () != numIterators)
193
+ return op.emitOpError (" expected indexing map " )
194
+ << index << " to have " << numIterators << " number of inputs" ;
195
+ auto operandType = op.getOperand (index )->getType ().cast <VectorType>();
196
+ unsigned rank = operandType.getShape ().size ();
197
+ if (map.getNumResults () != rank)
198
+ return op.emitOpError (" expected indexing map " )
199
+ << index << " to have " << rank << " number of outputs" ;
200
+ if (!map.isProjectedPermutation ())
201
+ return op.emitOpError (" expected indexing map " )
202
+ << index << " to be a projected permutation of its inputs" ;
203
+ }
204
+
162
205
auto contractingDimMap = op.getContractingDimMap ();
163
206
auto batchDimMap = op.getBatchDimMap ();
164
207
@@ -198,27 +241,54 @@ static LogicalResult verify(ContractionOp op) {
198
241
return success ();
199
242
}
200
243
201
- static std::vector<std::pair<int64_t , int64_t >> getDimMap (Attribute attr) {
244
+ SmallVector<StringRef, 2 > ContractionOp::getTraitAttrNames () {
245
+ return SmallVector<StringRef, 2 >{" indexing_maps" , " iterator_types" };
246
+ }
247
+
248
+ static int64_t getResultIndex (AffineMap map, AffineExpr targetExpr) {
249
+ for (int64_t i = 0 , e = map.getNumResults (); i < e; ++i)
250
+ if (targetExpr == map.getResult (i))
251
+ return i;
252
+ return -1 ;
253
+ }
254
+
255
+ static std::vector<std::pair<int64_t , int64_t >>
256
+ getDimMap (ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
257
+ StringRef targetIteratorTypeName, MLIRContext *context) {
202
258
std::vector<std::pair<int64_t , int64_t >> dimMap;
203
- auto dimPairs = attr.dyn_cast_or_null <ArrayAttr>();
204
- if (!dimPairs)
205
- return dimMap;
206
- for (auto dimPairAttr : dimPairs) {
207
- auto dimPair = dimPairAttr.cast <ArrayAttr>();
208
- assert (dimPair.size () == 2 );
209
- auto lhsDim = dimPair.begin ()->cast <IntegerAttr>().getInt ();
210
- auto rhsDim = std::prev (dimPair.end ())->cast <IntegerAttr>().getInt ();
211
- dimMap.push_back ({lhsDim, rhsDim});
259
+ for (auto it : llvm::enumerate (iteratorTypes)) {
260
+ auto iteratorTypeName = it.value ().cast <StringAttr>().getValue ();
261
+ if (iteratorTypeName != targetIteratorTypeName)
262
+ continue ;
263
+ // Search lhs/rhs map results for 'targetExpr'.
264
+ auto targetExpr = getAffineDimExpr (it.index (), context);
265
+ int64_t lhsDim = getResultIndex (indexingMaps[0 ], targetExpr);
266
+ int64_t rhsDim = getResultIndex (indexingMaps[1 ], targetExpr);
267
+ if (lhsDim >= 0 && rhsDim >= 0 )
268
+ dimMap.push_back ({lhsDim, rhsDim});
212
269
}
213
270
return dimMap;
214
271
}
215
272
216
273
std::vector<std::pair<int64_t , int64_t >> ContractionOp::getContractingDimMap () {
217
- return getDimMap (getAttr (getContractingDimMapAttrName ()));
274
+ SmallVector<AffineMap, 4 > indexingMaps (getIndexingMaps ());
275
+ return getDimMap (indexingMaps, iterator_types (),
276
+ getReductionIteratorTypeName (), getContext ());
218
277
}
219
278
220
279
std::vector<std::pair<int64_t , int64_t >> ContractionOp::getBatchDimMap () {
221
- return getDimMap (getAttr (getBatchDimMapAttrName ()));
280
+ SmallVector<AffineMap, 4 > indexingMaps (getIndexingMaps ());
281
+ return getDimMap (indexingMaps, iterator_types (),
282
+ getParallelIteratorTypeName (), getContext ());
283
+ }
284
+
285
+ SmallVector<AffineMap, 4 > ContractionOp::getIndexingMaps () {
286
+ SmallVector<AffineMap, 4 > res;
287
+ auto mapAttrs = indexing_maps ().getValue ();
288
+ res.reserve (mapAttrs.size ());
289
+ for (auto mapAttr : mapAttrs)
290
+ res.push_back (mapAttr.cast <AffineMapAttr>().getValue ());
291
+ return res;
222
292
}
223
293
224
294
// ===----------------------------------------------------------------------===//
0 commit comments