Skip to content

Commit 9cd5f7b

Browse files
committed
[CIR] Lower to MLIR struct with array member
Do not go through a memref of memref.
1 parent 95a5485 commit 9cd5f7b

File tree

3 files changed

+30
-13
lines changed

3 files changed

+30
-13
lines changed

clang/lib/CIR/Lowering/ThroughMLIR/LowerCIRToMLIR.cpp

+8-2
Original file line numberDiff line numberDiff line change
@@ -339,7 +339,13 @@ class CIRGetMemberOpLowering
339339
mlir::cast<mlir::named_tuple::NamedTupleType>(memref.getElementType());
340340
// The lowered type of the element to access in the named_tuple.
341341
auto loweredMemberType = namedTupleType.getType(memberIndex);
342-
auto elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
342+
// memref.view can only cast to another memref. Wrap the target type if it
343+
// is not already a memref (like with a struct with an array member)
344+
mlir::MemRefType elementMemRefTy;
345+
if (mlir::isa<mlir::MemRefType>(loweredMemberType))
346+
elementMemRefTy = mlir::cast<mlir::MemRefType>(loweredMemberType);
347+
else
348+
elementMemRefTy = mlir::MemRefType::get({}, loweredMemberType);
343349
auto offset = structLayout->getElementOffset(memberIndex);
344350
// Synthesize the byte access to right lowered type.
345351
auto byteShift =
@@ -1368,7 +1374,7 @@ class CIRPtrStrideOpLowering
13681374

13691375
// Return true if all the PtrStrideOp users are load, store or cast
13701376
// with array_to_ptrdecay kind and they are in the same block.
1371-
inline bool isLoadStoreOrCastArrayToPtrProduer(cir::PtrStrideOp op) const {
1377+
inline bool isLoadStoreOrCastArrayToPtrProducer(cir::PtrStrideOp op) const {
13721378
if (op.use_empty())
13731379
return false;
13741380
for (auto *user : op->getUsers()) {
+22-10
Original file line numberDiff line numberDiff line change
@@ -1,40 +1,52 @@
11
// RUN: %clang_cc1 -triple x86_64-unknown-linux-gnu -fno-clangir-direct-lowering -emit-mlir %s -o %t.mlir
22
// RUN: FileCheck --input-file=%t.mlir %s
33

4+
// Check the MLIR lowering of struct and member accesses
45
struct s {
56
int a;
67
double b;
78
char c;
9+
float d[5];
810
};
911

1012
int main() {
1113
s v;
12-
// CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>>
14+
// CHECK: %[[ALLOCA:.+]] = memref.alloca() {alignment = 8 : i64} : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>>
1315
v.a = 7;
1416
// CHECK: %[[C_7:.+]] = arith.constant 7 : i32
15-
// CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
17+
// CHECK: %[[I8_EQUIV_A:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
1618
// CHECK: %[[OFFSET_A:.+]] = arith.constant 0 : index
17-
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<24xi8> to memref<i32>
19+
// CHECK: %[[VIEW_A:.+]] = memref.view %[[I8_EQUIV_A]][%[[OFFSET_A]]][] : memref<40xi8> to memref<i32>
1820
// CHECK: memref.store %[[C_7]], %[[VIEW_A]][] : memref<i32>
1921

2022
v.b = 3.;
2123
// CHECK: %[[C_3:.+]] = arith.constant 3.000000e+00 : f64
22-
// CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
24+
// CHECK: %[[I8_EQUIV_B:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
2325
// CHECK: %[[OFFSET_B:.+]] = arith.constant 8 : index
24-
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<24xi8> to memref<f64>
26+
// CHECK: %[[VIEW_B:.+]] = memref.view %[[I8_EQUIV_B]][%[[OFFSET_B]]][] : memref<40xi8> to memref<f64>
2527
// CHECK: memref.store %[[C_3]], %[[VIEW_B]][] : memref<f64>
2628

2729
v.c = 'z';
28-
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
29-
// CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
30+
// CHECK: %[[C_122:.+]] = arith.constant 122 : i8
31+
// CHECK: %[[I8_EQUIV_C:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
3032
// CHECK: %[[OFFSET_C:.+]] = arith.constant 16 : index
31-
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<24xi8> to memref<i8>
33+
// CHECK: %[[VIEW_C:.+]] = memref.view %[[I8_EQUIV_C]][%[[OFFSET_C]]][] : memref<40xi8> to memref<i8>
3234
// memref.store %[[C_122]], %[[VIEW_C]][] : memref<i8>
3335

36+
v.d[4] = 6.f;
37+
// CHECK: %[[C_6:.+]] = arith.constant 6.000000e+00 : f32
38+
// CHECK: %[[I8_EQUIV_D:.+]] = named_tuple.cast %alloca_0 : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
39+
// CHECK: %[[OFFSET_D:.+]] = arith.constant 20 : index
40+
// Do not lower to a memref of memref
41+
// CHECK: %[[VIEW_D:.+]] = memref.view %3[%c20][] : memref<40xi8> to memref<5xf32>
42+
// CHECK: %[[C_4:.+]] = arith.constant 4 : i32
43+
// CHECK: %[[I_D:.+]] = arith.index_cast %[[C_4]] : i32 to index
44+
// CHECK: memref.store %[[C_6]], %[[VIEW_D]][%[[I_D]]] : memref<5xf32>
45+
3446
return v.c;
35-
// CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8]>> to memref<24xi8>
47+
// CHECK: %[[I8_EQUIV_C_1:.+]] = named_tuple.cast %[[ALLOCA]] : memref<!named_tuple.named_tuple<"s", [i32, f64, i8, memref<5xf32>]>> to memref<40xi8>
3648
// CHECK: %[[OFFSET_C_1:.+]] = arith.constant 16 : index
37-
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<24xi8> to memref<i8>
49+
// CHECK: %[[VIEW_C_1:.+]] = memref.view %[[I8_EQUIV_C_1]][%[[OFFSET_C_1]]][] : memref<40xi8> to memref<i8>
3850
// CHECK: %[[VALUE_C:.+]] = memref.load %[[VIEW_C_1]][] : memref<i8>
3951
// CHECK: %[[VALUE_RET:.+]] = arith.extsi %[[VALUE_C]] : i8 to i32
4052
}

mlir/include/mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h

-1
Original file line numberDiff line numberDiff line change
@@ -17,5 +17,4 @@
1717

1818
#include "mlir/Dialect/NamedTuple/IR/NamedTupleDialect.h.inc"
1919

20-
2120
#endif // MLIR_DIALECT_NAMED_TUPLE_IR_NAMED_TUPLE_DIALECT_H

0 commit comments

Comments
 (0)