@@ -285,13 +285,13 @@ class CIRStoreOpLowering : public mlir::OpConversionPattern<cir::StoreOp> {
285
285
}
286
286
};
287
287
288
- // Lower cir.get_member
288
+ // Lower cir.get_member by aliasing the result memref to the member inside the
289
+ // flattened structure as a byte array. For example
289
290
//
290
291
// clang-format off
291
- //
292
292
// %5 = cir.get_member %1[1] {name = "b"} : !cir.ptr<!named_tuple.named_tuple<"s", [i32, f64, i8]>> -> !cir.ptr<!cir.double>
293
293
//
294
- // to something like
294
+ // is lowered to something like
295
295
//
296
296
// %1 = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
297
297
// %c8 = arith.constant 8 : index
@@ -321,37 +321,30 @@ class CIRGetMemberOpLowering
321
321
// concrete datalayout, both datalayouts are the same.
322
322
auto *structLayout = dataLayout.getStructLayout (structType);
323
323
324
- // Get the lowered type: memref<!named_tuple.named_tuple<>>
325
- auto memref = mlir::cast<mlir::MemRefType>(adaptor.getAddr ().getType ());
326
324
// Alias the memref of struct to a memref of an i8 array of the same size.
327
325
const std::array linearizedSize{
328
326
static_cast <std::int64_t >(dataLayout.getTypeStoreSize (structType))};
329
- auto flattenMemRef = mlir::MemRefType::get (
330
- linearizedSize, mlir::IntegerType::get (memref. getContext (), 8 ));
327
+ auto flattenedMemRef = mlir::MemRefType::get (
328
+ linearizedSize, mlir::IntegerType::get (getContext (), 8 ));
331
329
// Use a special cast because normal memref cast cannot do such an extreme
332
330
// cast.
333
331
auto bytesMemRef = rewriter.create <mlir::named_tuple::CastOp>(
334
- op.getLoc (), mlir::TypeRange{flattenMemRef },
332
+ op.getLoc (), mlir::TypeRange{flattenedMemRef },
335
333
mlir::ValueRange{adaptor.getAddr ()});
336
334
335
+ auto pointerToMemberTypeToLower = op.getResultTy ();
336
+ // The lowered type of the cir.ptr to the cir.struct member.
337
+ auto memrefToLoweredMemberType =
338
+ typeConverter->convertType (pointerToMemberTypeToLower);
339
+ // Synthesize the byte access to right lowered type.
337
340
auto memberIndex = op.getIndex ();
338
- auto namedTupleType =
339
- mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType ());
340
- // The lowered type of the element to access in the named_tuple.
341
- auto loweredMemberType = namedTupleType.getType (memberIndex);
342
- // memref.view can only cast to another memref. Wrap the target type if it
343
- // is not already a memref (like with a struct with an array member)
344
- mlir::MemRefType elementMemRefTy;
345
- if (mlir::isa<mlir::MemRefType>(loweredMemberType))
346
- elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
347
- else
348
- elementMemRefTy = mlir::MemRefType::get ({}, loweredMemberType);
349
341
auto offset = structLayout->getElementOffset (memberIndex);
350
- // Synthesize the byte access to right lowered type.
351
342
auto byteShift =
352
343
rewriter.create <mlir::arith::ConstantIndexOp>(op.getLoc (), offset);
344
+ // Create the memref pointing to the flattened member location.
353
345
rewriter.replaceOpWithNewOp <mlir::memref::ViewOp>(
354
- op, elementMemRefTy, bytesMemRef, byteShift, mlir::ValueRange{});
346
+ op, memrefToLoweredMemberType, bytesMemRef, byteShift,
347
+ mlir::ValueRange{});
355
348
return mlir::LogicalResult::success ();
356
349
}
357
350
};
@@ -1456,6 +1449,29 @@ void populateCIRToMLIRConversionPatterns(mlir::RewritePatternSet &patterns,
1456
1449
cirDataLayout);
1457
1450
}
1458
1451
1452
+ namespace {
1453
+ // Lower a cir.array either as a memref when it has a reference semantics or as
1454
+ // a tensor when it has a value semantics (like inside a struct or union)
1455
+ mlir::Type lowerArrayType (cir::ArrayType type, bool hasValueSemantics,
1456
+ mlir::TypeConverter &converter) {
1457
+ SmallVector<int64_t > shape;
1458
+ mlir::Type curType = type;
1459
+ while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1460
+ shape.push_back (arrayType.getSize ());
1461
+ curType = arrayType.getEltType ();
1462
+ }
1463
+ auto elementType = converter.convertType (curType);
1464
+ // FIXME: The element type might not be converted
1465
+ if (!elementType)
1466
+ return nullptr ;
1467
+ // Arrays in C/C++ have a reference semantics when not in a struct, so use
1468
+ // a memref
1469
+ if (hasValueSemantics)
1470
+ return mlir::RankedTensorType::get (shape, elementType);
1471
+ return mlir::MemRefType::get (shape, elementType);
1472
+ }
1473
+ } // namespace
1474
+
1459
1475
mlir::TypeConverter prepareTypeConverter (mlir::DataLayout &dataLayout) {
1460
1476
mlir::TypeConverter converter;
1461
1477
converter.addConversion ([&](cir::PointerType type) -> mlir::Type {
@@ -1464,6 +1480,7 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1464
1480
if (!ty)
1465
1481
return nullptr ;
1466
1482
if (isa<cir::ArrayType>(type.getPointee ()))
1483
+ // An array is already lowered as a memref with reference semantics
1467
1484
return ty;
1468
1485
return mlir::MemRefType::get ({}, ty);
1469
1486
});
@@ -1503,23 +1520,23 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1503
1520
return mlir::BFloat16Type::get (type.getContext ());
1504
1521
});
1505
1522
converter.addConversion ([&](cir::ArrayType type) -> mlir::Type {
1506
- SmallVector<int64_t > shape;
1507
- mlir::Type curType = type;
1508
- while (auto arrayType = dyn_cast<cir::ArrayType>(curType)) {
1509
- shape.push_back (arrayType.getSize ());
1510
- curType = arrayType.getEltType ();
1511
- }
1512
- auto elementType = converter.convertType (curType);
1513
- // FIXME: The element type might not be converted
1514
- if (!elementType)
1515
- return nullptr ;
1516
- return mlir::MemRefType::get (shape, elementType);
1523
+ // Arrays in C/C++ have a reference semantics when not in a
1524
+ // class/struct/union, so use a memref.
1525
+ return lowerArrayType (type, /* hasValueSemantics */ false , converter);
1517
1526
});
1518
1527
converter.addConversion ([&](cir::VectorType type) -> mlir::Type {
1519
1528
auto ty = converter.convertType (type.getEltType ());
1520
1529
return mlir::VectorType::get (type.getSize (), ty);
1521
1530
});
1522
1531
converter.addConversion ([&](cir::StructType type) -> mlir::Type {
1532
+ auto convertWithValueSemanticsArray = [&](mlir::Type t) {
1533
+ if (mlir::isa<cir::ArrayType>(t))
1534
+ // Inside a class/struct/union, an array has value semantics and is
1535
+ // lowered as a tensor.
1536
+ return lowerArrayType (mlir::cast<cir::ArrayType>(t),
1537
+ /* hasValueSemantics */ true , converter);
1538
+ return converter.convertType (t);
1539
+ };
1523
1540
// FIXME(cir): create separate unions, struct, and classes types.
1524
1541
// Convert struct members.
1525
1542
llvm::SmallVector<mlir::Type> mlirMembers;
@@ -1528,13 +1545,13 @@ mlir::TypeConverter prepareTypeConverter(mlir::DataLayout &dataLayout) {
1528
1545
// TODO(cir): This should be properly validated.
1529
1546
case cir::StructType::Struct:
1530
1547
for (auto ty : type.getMembers ())
1531
- mlirMembers.push_back (converter. convertType (ty));
1548
+ mlirMembers.push_back (convertWithValueSemanticsArray (ty));
1532
1549
break ;
1533
1550
// Unions are lowered as only the largest member.
1534
1551
case cir::StructType::Union: {
1535
1552
auto largestMember = type.getLargestMember (dataLayout);
1536
1553
if (largestMember)
1537
- mlirMembers.push_back (converter. convertType (largestMember));
1554
+ mlirMembers.push_back (convertWithValueSemanticsArray (largestMember));
1538
1555
break ;
1539
1556
}
1540
1557
}
0 commit comments