Skip to content
This repository was archived by the owner on Apr 23, 2021. It is now read-only.

Commit 5614db0

Browse files
andydavis1 authored and
tensorflower-gardener committed
Update VectorContractionOp to take iterator types and index mapping attributes compatible with linalg ops.
PiperOrigin-RevId: 282412311
1 parent d61eb2d commit 5614db0

File tree

4 files changed

+320
-76
lines changed

4 files changed

+320
-76
lines changed

include/mlir/Dialect/VectorOps/VectorOps.td

+53-33
Original file line numberDiff line numberDiff line change
@@ -46,12 +46,11 @@ class Vector_Op<string mnemonic, list<OpTrait> traits = []> :
4646

4747
// TODO(andydavis, ntv) Add an attribute to specify a different algebra
4848
// with operators other than the current set: {*, +}.
49-
// TODO(andydavis) Consider using AffineMaps to express contracting, batch
50-
// and free dimension pairs.
5149
def Vector_ContractionOp :
5250
Vector_Op<"contract", [NoSideEffect]>,
5351
Arguments<(ins AnyVector:$lhs, AnyVector:$rhs, AnyVector:$acc,
54-
Variadic<TupleOf<[Index]>>:$masks)>,
52+
Variadic<TupleOf<[Index]>>:$masks,
53+
AffineMapArrayAttr:$indexing_maps, ArrayAttr:$iterator_types)>,
5554
Results<(outs AnyVector)> {
5655
let summary = "vector contraction operation";
5756
let description = [{
@@ -64,39 +63,59 @@ def Vector_ContractionOp :
6463
Optional vector mask arguments specify the dynamic dimension sizes of
6564
valid data within the lhs/rhs vector arguments.
6665

67-
Dimensions for the arguments and result type fall into three categories:
68-
*) Contracting: contracting dimensions are present in the lhs and rhs
66+
An iterator type attribute list must be specified, where each element of
67+
the list represents an iterator with one of the following types:
68+
69+
*) "reduction": reduction dimensions are present in the lhs and rhs
6970
arguments but not in the output (or optional accumulator
7071
argument). These are the dimensions along which the vector
71-
contraction op computes the sum of products, and contracting
72-
dimension pair dimension sizes must match between lhs/rhs.
73-
*) Batch: batch dimensions are non-contracting dimensions and so are
74-
present in the output and in the accumulator argument. The lhs
75-
and rhs co-iterate along the batch dimension and so dimension
76-
sizes must match across all arguments and result.
77-
*) Free: free dimensions are non-contraction, non-batch dimensions and
78-
are present in the output and accumulator argument. The lhs and
79-
rhs free dimensions are unrelated to each other and do not
80-
co-iterate.
81-
82-
Contracting and batch dimensions are specified as dimension pairs
83-
of logical dimension numbers: the first in the pair represents the lhs
84-
logical dimension number and the second in the pair represents the
85-
associated rhs logical dimension number. A dimension pair binds together
86-
logical dimension numbers from the lhs/rhs which co-iterate together, either
87-
as contracting or batch dimensions.
72+
contraction op computes the sum of products, and
73+
contracting dimension pair dimension sizes must match
74+
between lhs/rhs.
75+
*) "parallel": Batch dimensions are iterator type "parallel", and
76+
are non-contracting dimensions present in the lhs, rhs and
77+
output. The lhs/rhs co-iterate along the batch dimensions,
78+
which should be expressed in their indexing maps.
79+
80+
Free dimensions are iterator type "parallel", and are
81+
non-contraction, non-batch dimensions accessed by either the
82+
lhs or rhs (but not both). The lhs and rhs free dimensions
83+
are unrelated to each other and do not co-iterate, which
84+
should be expressed in their indexing maps.
85+
86+
An indexing map attribute list must be specified with an entry for lhs, rhs
87+
and acc arguments. An indexing map attribute specifies a mapping from each
88+
iterator in the iterator type list, to each dimension of an N-D vector.
8889

8990
Examples:
9091

9192
// 2D vector contraction with one contracting dimension (matmul).
92-
%3 = vector.contract %0, %1, %2
93-
{ contracting_dim_map = [[1, 0]] }
94-
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
93+
#contraction_accesses = [
94+
(i, j, k) -> (i, k),
95+
(i, j, k) -> (k, j),
96+
(i, j, k) -> (i, j)
97+
]
98+
#contraction_trait = {
99+
indexing_maps = #contraction_accesses,
100+
iterator_types = [parallel, parallel, reduction]
101+
}
102+
103+
%3 = vector.contract #contraction_trait %0, %1, %2
104+
: vector<4x3xf32>, vector<3x7xf32> into vector<4x7xf32>
95105

96106
// 4D to 3D vector contraction with two contracting dimensions and
97107
// one batch dimension.
98-
%4 = vector.contract %0, %1, %2
99-
{ batch_dim_map = [[1, 0]], contracting_dim_map = [[0, 2], [2, 1]] }
108+
#contraction_accesses = [
109+
(b0, f0, f1, c0, c1) -> (c0, b0, c1, f0),
110+
(b0, f0, f1, c0, c1) -> (b0, c1, c0, f1),
111+
(b0, f0, f1, c0, c1) -> (b0, f0, f1)
112+
]
113+
#contraction_trait = {
114+
indexing_maps = #contraction_accesses,
115+
iterator_types = [parallel, parallel, parallel, reduction, reduction]
116+
}
117+
118+
%4 = vector.contract #contraction_trait %0, %1, %2
100119
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x5xf32>
101120

102121
// 4D vector contraction with two contracting dimensions and optional
@@ -106,8 +125,7 @@ def Vector_ContractionOp :
106125
%rhs_mask = vector.make_tuple %size4, %size5, %size6, %size7
107126
: tuple<index, index, index, index>
108127

109-
%5 = vector.contract %0, %1, %2, %lhs_mask, %rhs_mask
110-
{ contracting_dim_map = [[0, 2], [2, 1]] }
128+
%5 = vector.contract #contraction_trait %0, %1, %2, %lhs_mask, %rhs_mask
111129
: vector<7x8x16x15xf32>, vector<8x16x7x5xf32> into vector<8x15x8x5xf32>
112130
}];
113131
let extraClassDeclaration = [{
@@ -131,11 +149,13 @@ def Vector_ContractionOp :
131149
VectorType getResultType() {
132150
return getResult()->getType().cast<VectorType>();
133151
}
134-
static StringRef getContractingDimMapAttrName() {
135-
return "contracting_dim_map";
152+
SmallVector<StringRef, 2> getTraitAttrNames();
153+
SmallVector<AffineMap, 4> getIndexingMaps();
154+
static StringRef getReductionIteratorTypeName() {
155+
return "reduction";
136156
}
137-
static StringRef getBatchDimMapAttrName() {
138-
return "batch_dim_map";
157+
static StringRef getParallelIteratorTypeName() {
158+
return "parallel";
139159
}
140160
std::vector<std::pair<int64_t, int64_t>> getContractingDimMap();
141161
std::vector<std::pair<int64_t, int64_t>> getBatchDimMap();

lib/Dialect/VectorOps/VectorOps.cpp

+87-17
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@
2727
#include "mlir/IR/OpImplementation.h"
2828
#include "mlir/IR/TypeUtilities.h"
2929
#include "mlir/Support/LLVM.h"
30+
#include "llvm/ADT/StringSet.h"
3031

3132
using namespace mlir;
3233
using namespace mlir::vector;
@@ -56,7 +57,10 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
5657
SmallVector<Type, 2> types;
5758
Type resultVectorType;
5859
auto loc = parser.getCurrentLocation();
59-
if (parser.parseOperand(lhsInfo) || parser.parseComma() ||
60+
DictionaryAttr dictAttr;
61+
// TODO(andydavis, ntv) Unify linalg op attribute parsing.
62+
if (parser.parseAttribute(dictAttr, "_", result.attributes) ||
63+
parser.parseOperand(lhsInfo) || parser.parseComma() ||
6064
parser.parseOperand(rhsInfo) || parser.parseComma() ||
6165
parser.parseOperand(accInfo) ||
6266
parser.parseTrailingOperandList(masksInfo) ||
@@ -68,7 +72,8 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
6872
parser.resolveOperand(accInfo, resultVectorType, result.operands) ||
6973
parser.addTypeToList(resultVectorType, result.types))
7074
return failure();
71-
75+
result.attributes.assign(dictAttr.getValue().begin(),
76+
dictAttr.getValue().end());
7277
if (masksInfo.empty())
7378
return success();
7479
if (masksInfo.size() != 2)
@@ -90,13 +95,23 @@ static ParseResult parseContractionOp(OpAsmParser &parser,
9095
}
9196

9297
static void print(OpAsmPrinter &p, ContractionOp op) {
93-
p << op.getOperationName() << " " << *op.lhs() << ", " << *op.rhs();
94-
p << ", " << *op.acc();
98+
// TODO(andydavis, ntv) Unify printing code with linalg ops.
99+
auto attrNames = op.getTraitAttrNames();
100+
llvm::StringSet<> traitAttrsSet;
101+
traitAttrsSet.insert(attrNames.begin(), attrNames.end());
102+
SmallVector<NamedAttribute, 8> attrs;
103+
for (auto attr : op.getAttrs()) {
104+
if (traitAttrsSet.count(attr.first.strref()) > 0)
105+
attrs.push_back(attr);
106+
}
107+
auto dictAttr = DictionaryAttr::get(attrs, op.getContext());
108+
p << op.getOperationName() << " " << dictAttr << " " << *op.lhs() << ", ";
109+
p << *op.rhs() << ", " << *op.acc();
95110
if (llvm::size(op.masks()) == 2) {
96111
p << ", " << **op.masks().begin();
97112
p << ", " << **(op.masks().begin() + 1);
98113
}
99-
p.printOptionalAttrDict(op.getAttrs());
114+
p.printOptionalAttrDict(op.getAttrs(), attrNames);
100115
p << " : " << op.lhs()->getType() << ", " << op.rhs()->getType() << " into "
101116
<< op.getResultType();
102117
}
@@ -159,6 +174,34 @@ static LogicalResult verify(ContractionOp op) {
159174
auto rhsType = op.getRhsType();
160175
auto accType = op.getAccType();
161176
auto resType = op.getResultType();
177+
178+
// Verify that an indexing map was specified for each vector operand.
179+
if (op.indexing_maps().size() != 3)
180+
return op.emitOpError("expected an indexing map for each vector operand");
181+
182+
// Verify that each index map has 'numIterators' inputs, no symbols, and
183+
// that the number of map outputs equals the rank of its associated
184+
// vector operand.
185+
unsigned numIterators = op.iterator_types().getValue().size();
186+
for (auto it : llvm::enumerate(op.indexing_maps())) {
187+
auto index = it.index();
188+
auto map = it.value().cast<AffineMapAttr>().getValue();
189+
if (map.getNumSymbols() != 0)
190+
return op.emitOpError("expected indexing map ")
191+
<< index << " to have no symbols";
192+
if (map.getNumDims() != numIterators)
193+
return op.emitOpError("expected indexing map ")
194+
<< index << " to have " << numIterators << " number of inputs";
195+
auto operandType = op.getOperand(index)->getType().cast<VectorType>();
196+
unsigned rank = operandType.getShape().size();
197+
if (map.getNumResults() != rank)
198+
return op.emitOpError("expected indexing map ")
199+
<< index << " to have " << rank << " number of outputs";
200+
if (!map.isProjectedPermutation())
201+
return op.emitOpError("expected indexing map ")
202+
<< index << " to be a projected permutation of its inputs";
203+
}
204+
162205
auto contractingDimMap = op.getContractingDimMap();
163206
auto batchDimMap = op.getBatchDimMap();
164207

@@ -198,27 +241,54 @@ static LogicalResult verify(ContractionOp op) {
198241
return success();
199242
}
200243

201-
static std::vector<std::pair<int64_t, int64_t>> getDimMap(Attribute attr) {
244+
SmallVector<StringRef, 2> ContractionOp::getTraitAttrNames() {
245+
return SmallVector<StringRef, 2>{"indexing_maps", "iterator_types"};
246+
}
247+
248+
static int64_t getResultIndex(AffineMap map, AffineExpr targetExpr) {
249+
for (int64_t i = 0, e = map.getNumResults(); i < e; ++i)
250+
if (targetExpr == map.getResult(i))
251+
return i;
252+
return -1;
253+
}
254+
255+
static std::vector<std::pair<int64_t, int64_t>>
256+
getDimMap(ArrayRef<AffineMap> indexingMaps, ArrayAttr iteratorTypes,
257+
StringRef targetIteratorTypeName, MLIRContext *context) {
202258
std::vector<std::pair<int64_t, int64_t>> dimMap;
203-
auto dimPairs = attr.dyn_cast_or_null<ArrayAttr>();
204-
if (!dimPairs)
205-
return dimMap;
206-
for (auto dimPairAttr : dimPairs) {
207-
auto dimPair = dimPairAttr.cast<ArrayAttr>();
208-
assert(dimPair.size() == 2);
209-
auto lhsDim = dimPair.begin()->cast<IntegerAttr>().getInt();
210-
auto rhsDim = std::prev(dimPair.end())->cast<IntegerAttr>().getInt();
211-
dimMap.push_back({lhsDim, rhsDim});
259+
for (auto it : llvm::enumerate(iteratorTypes)) {
260+
auto iteratorTypeName = it.value().cast<StringAttr>().getValue();
261+
if (iteratorTypeName != targetIteratorTypeName)
262+
continue;
263+
// Search lhs/rhs map results for 'targetExpr'.
264+
auto targetExpr = getAffineDimExpr(it.index(), context);
265+
int64_t lhsDim = getResultIndex(indexingMaps[0], targetExpr);
266+
int64_t rhsDim = getResultIndex(indexingMaps[1], targetExpr);
267+
if (lhsDim >= 0 && rhsDim >= 0)
268+
dimMap.push_back({lhsDim, rhsDim});
212269
}
213270
return dimMap;
214271
}
215272

216273
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getContractingDimMap() {
217-
return getDimMap(getAttr(getContractingDimMapAttrName()));
274+
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
275+
return getDimMap(indexingMaps, iterator_types(),
276+
getReductionIteratorTypeName(), getContext());
218277
}
219278

220279
std::vector<std::pair<int64_t, int64_t>> ContractionOp::getBatchDimMap() {
221-
return getDimMap(getAttr(getBatchDimMapAttrName()));
280+
SmallVector<AffineMap, 4> indexingMaps(getIndexingMaps());
281+
return getDimMap(indexingMaps, iterator_types(),
282+
getParallelIteratorTypeName(), getContext());
283+
}
284+
285+
SmallVector<AffineMap, 4> ContractionOp::getIndexingMaps() {
286+
SmallVector<AffineMap, 4> res;
287+
auto mapAttrs = indexing_maps().getValue();
288+
res.reserve(mapAttrs.size());
289+
for (auto mapAttr : mapAttrs)
290+
res.push_back(mapAttr.cast<AffineMapAttr>().getValue());
291+
return res;
222292
}
223293

224294
//===----------------------------------------------------------------------===//

0 commit comments

Comments (0)