Commit b5a29d4

[CIR][Lowering] Lower arrays in class/struct/union as tensor
Arrays in C/C++ usually have reference semantics and can be lowered to memref. But inside a class/struct/union, arrays have value semantics and can be lowered as tensor.
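
For instance, as a hedged sketch (the float-array struct mirrors the test below; the standalone case follows the commit message):

    // Illustrative mapping, not taken verbatim from the patch:
    float standalone[5];   // reference semantics -> lowered to memref<5xf32>

    struct s {
      float member[5];     // value semantics inside the struct -> tensor<5xf32>
    };                     // s lowers to !named_tuple.named_tuple<"s", [tensor<5xf32>]>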
1 parent 9cd5f7b commit b5a29d4

2 files changed: +59 -42 lines

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp (+51 -34)
@@ -285,13 +285,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
   }
 };
 
-// Lower cir.get_member
+// Lower cir.get_member by aliasing the result memref to the member inside the
+// flattened structure as a byte array. For example
 //
 // clang-format off
-//
 // %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
 //
-// to something like
+// is lowered to something like
 //
 // %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
 // %c8 = arith.constant 8 : index
@@ -321,37 +321,30 @@ class CIRGetMemberOpLowering
     // concrete datalayout, both datalayouts are the same.
     auto *structLayout = dataLayout.getStructLayout(structType);
 
-    // Get the lowered type: memref<!named_tuple.named_tuple<>>
-    auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr().getType());
     // Alias the memref of struct to a memref of an i8 array of the same size.
     const std::array linearizedSize{
         static_cast<std::int64_t>(dataLayout.getTypeStoreSize(structType))};
-    auto flattenMemRef = mlir::MemRefType::get(
-        linearizedSize, mlir::IntegerType::get(memref.getContext(), 8));
+    auto flattenedMemRef = mlir::MemRefType::get(
+        linearizedSize, mlir::IntegerType::get(getContext(), 8));
     // Use a special cast because normal memref cast cannot do such an extreme
     // cast.
     auto bytesMemRef = rewriter.create<mlir::named_tuple::CastOp>(
-        op.getLoc(), mlir::TypeRange{flattenMemRef},
+        op.getLoc(), mlir::TypeRange{flattenedMemRef},
         mlir::ValueRange{adaptor.getAddr()});
 
+    auto pointerToMemberTypeToLower = op.getResultTy();
+    // The lowered type of the cir.ptr to the cir.struct member.
+    auto memrefToLoweredMemberType =
+        typeConverter->convertType(pointerToMemberTypeToLower);
+    // Synthesize the byte access to the right lowered type.
     auto memberIndex = op.getIndex();
-    auto namedTupleType =
-        mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
-    // The lowered type of the element to access in the named_tuple.
-    auto loweredMemberType = namedTupleType.getType(memberIndex);
-    // memref.view can only cast to another memref. Wrap the target type if it
-    // is not already a memref (like with a struct with an array member)
-    mlir::MemRefType elementMemRefTy;
-    if (mlir::isa<mlir::MemRefType>(loweredMemberType))
-      elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
-    else
-      elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
     auto offset = structLayout->getElementOffset(memberIndex);
-    // Synthesize the byte access to right lowered type.
     auto byteShift =
         rewriter.create<mlir::arith::ConstantIndexOp>(op.getLoc(), offset);
+    // Create the memref pointing to the flattened member location.
     rewriter.replaceOpWithNewOp<mlir::memref::ViewOp>(
-        op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
+        op, memrefToLoweredMemberType, bytesMemRef, byteShift,
+        mlir::ValueRange{});
     return mlir::LogicalResult::success();
   }
 };
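
The substance of this hunk: the memref.view result type now comes from converting the cir.ptr result type of the op, instead of reading the member's type out of the named_tuple and wrapping non-memref types. A minimal sketch of why this matters for an array member such as float d[5] (types illustrative):

    // cir.get_member on the array member yields
    //   !cir.ptr<!cir.array<!cir.float x 5>>
    // Old scheme: the member's lowered type is now tensor<5xf32>, so the
    // wrap-in-memref fallback would have produced memref<tensor<5xf32>>.
    // New scheme: convert the pointer type itself; the cir::PointerType
    // conversion below maps a pointer-to-array straight to a ranked memref:
    auto memrefToLoweredMemberType =
        typeConverter->convertType(op.getResultTy()); // -> memref<5xf32>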
@@ -1456,6 +1449,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
                                          cirDataLayout);
 }
 
+namespace {
+// Lower a cir.array either as a memref when it has reference semantics or as
+// a tensor when it has value semantics (like inside a class/struct/union).
+mlir::Type lowerArrayType(cir::ArrayType type, bool hasValueSemantics,
+                          mlir::TypeConverter &converter) {
+  SmallVector<int64_t> shape;
+  mlir::Type curType = type;
+  while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
+    shape.push_back(arrayType.getSize());
+    curType = arrayType.getEltType();
+  }
+  auto elementType = converter.convertType(curType);
+  // FIXME: The element type might not be converted
+  if (!elementType)
+    return nullptr;
+  // Inside a class/struct/union an array has value semantics and lowers to a
+  // tensor; otherwise it has reference semantics and lowers to a memref.
+  if (hasValueSemantics)
+    return mlir::RankedTensorType::get(shape, elementType);
+  return mlir::MemRefType::get(shape, elementType);
+}
+} // namespace
+
 mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
   mlir::TypeConverter converter;
   converter.addConversion([&](cir::PointerType type) -> mlir::Type {
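
How the new helper behaves on nested arrays, as a hedged example (the 2-D member is hypothetical, not in the patch):

    // For a C++ declaration `float m[3][4]`, CIR nests the array types:
    //   !cir.array<!cir.array<!cir.float x 4> x 3>
    // The while loop pushes 3 then 4 into `shape`, converts the innermost
    // element type once, and returns a single ranked type:
    //   lowerArrayType(ty, /*hasValueSemantics=*/true,  converter); // tensor<3x4xf32>
    //   lowerArrayType(ty, /*hasValueSemantics=*/false, converter); // memref<3x4xf32>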
@@ -1464,6 +1480,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     if (!ty)
       return nullptr;
     if (isa<cir::ArrayType>(type.getPointee()))
+      // An array is already lowered as a memref with reference semantics
       return ty;
     return mlir::MemRefType::get({}, ty);
   });
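
The new comment records an asymmetry worth spelling out; assuming the mappings used elsewhere in this converter, it amounts to (types illustrative):

    // Pointer-to-scalar gains a rank-0 memref wrapper; pointer-to-array is
    // already a memref after the cir::ArrayType conversion:
    //   !cir.ptr<32-bit int>          -> memref<i32>
    //   !cir.ptr<!cir.array<f32 x 5>> -> memref<5xf32>, not memref<memref<5xf32>>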
@@ -1503,23 +1520,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
     return mlir::BFloat16Type::get(type.getContext());
   });
   converter.addConversion([&](cir::ArrayType type) -> mlir::Type {
-    SmallVector<int64_t> shape;
-    mlir::Type curType = type;
-    while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
-      shape.push_back(arrayType.getSize());
-      curType = arrayType.getEltType();
-    }
-    auto elementType = converter.convertType(curType);
-    // FIXME: The element type might not be converted
-    if (!elementType)
-      return nullptr;
-    return mlir::MemRefType::get(shape, elementType);
+    // Arrays in C/C++ have reference semantics when not in a
+    // class/struct/union, so use a memref.
+    return lowerArrayType(type, /* hasValueSemantics */ false, converter);
   });
   converter.addConversion([&](cir::VectorType type) -> mlir::Type {
     auto ty = converter.convertType(type.getEltType());
     return mlir::VectorType::get(type.getSize(), ty);
   });
   converter.addConversion([&](cir::StructType type) -> mlir::Type {
+    auto convertWithValueSemanticsArray = [&](mlir::Type t) {
+      if (mlir::isa<cir::ArrayType>(t))
+        // Inside a class/struct/union, an array has value semantics and is
+        // lowered as a tensor.
+        return lowerArrayType(mlir::cast<cir::ArrayType>(t),
+                              /* hasValueSemantics */ true, converter);
+      return converter.convertType(t);
+    };
     // FIXME(cir): create separate unions, struct, and classes types.
     // Convert struct members.
     llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1528,13 +1545,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
       // TODO(cir): This should be properly validated.
     case cir::StructType::Struct:
       for (auto ty : type.getMembers())
-        mlirMembers.push_back(converter.convertType(ty));
+        mlirMembers.push_back(convertWithValueSemanticsArray(ty));
       break;
     // Unions are lowered as only the largest member.
     case cir::StructType::Union: {
      auto largestMember = type.getLargestMember(dataLayout);
      if (largestMember)
-        mlirMembers.push_back(converter.convertType(largestMember));
+        mlirMembers.push_back(convertWithValueSemanticsArray(largestMember));
      break;
    }
    }
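
Because unions lower to only their largest member, the lambda matters on that path too. A hedged example, not part of the patch, assuming the same pipeline:

    union u {
      int i;
      float f[8];  // 32 bytes, the largest member
    };
    // would presumably lower to something like
    //   !named_tuple.named_tuple<"u", [tensor<8xf32>]>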

clang/test/CIR/Lowering/ThroughMLIR/struct.cpp (+8 -8)
@@ -11,40 +11,40 @@ struct s {
 
 int main() {
   s v;
-  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>>
+  // CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>>
   v.a = 7;
   // CHECK: %[[C_7:.+]] = arith.constant 7 : i32
-  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
   // CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<40xi8> to memref<i32>
   // CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
 
   v.b = 3.;
   // CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
-  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
   // CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<40xi8> to memref<f64>
   // CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
 
   v.c = 'z';
   // CHECK: %[[C_122:.+]] = arith.constant 122 : i8
-  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<40xi8> to memref<i8>
   // memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
 
+  auto& a = v.d;
   v.d[4] = 6.f;
   // CHECK: %[[C_6:.+]] = arith.constant 6.000000e+00 : f32
-  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_D:.+]] = arith.constant 20 : index
-  // Do not lower to a memref of memref
-  // CHECK: %[[VIEW_D:.+]] = memref.view %3[%c20][] : memref<40xi8> to memref<5xf32>
+  // CHECK: %[[VIEW_D:.+]] = memref.view %[[I8_EQUIV_D]][%[[OFFSET_D]]][] : memref<40xi8> to memref<5xf32>
   // CHECK: %[[C_4:.+]] = arith.constant 4 : i32
   // CHECK: %[[I_D:.+]] = arith.index_cast %[[C_4]] : i32 to index
   // CHECK: memref.store %[[C_6]], %[[VIEW_D]][%[[I_D]]] : memref<5xf32>
 
   return v.c;
-  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
+  // CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, tensor<5xf32>]>> to memref<40xi8>
   // CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
   // CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<40xi8> to memref<i8>
   // CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
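
For context, a reconstruction of the struct under test (declared earlier in the file, outside this hunk); the members and offsets follow from the named_tuple type and the constant offsets in the CHECK lines:

    struct s {
      int a;       // i32 at offset 0
      double b;    // f64 at offset 8
      char c;      // i8 at offset 16
      float d[5];  // tensor<5xf32> in the lowered type, at offset 20
    };             // store size 40 bytes, alignment 8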
