From 4bfb87b9c29cf4874659e2d3d8a558a1e24e8c25 Mon Sep 17 00:00:00 2001 From: Zhen Peng Date: Fri, 4 Apr 2025 17:03:29 -0700 Subject: [PATCH] [3D] Fixed bug in lowering CSF format for 3D tensor (#80) --- .../Transforms/TensorDeclLowering.cpp | 4 +- lib/ExecutionEngine/SparseUtils.cpp | 64 +++++++++---------- 2 files changed, 34 insertions(+), 34 deletions(-) diff --git a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp index 2a6ab3fd..bd02aa78 100644 --- a/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp +++ b/lib/Dialect/TensorAlgebra/Transforms/TensorDeclLowering.cpp @@ -1110,7 +1110,7 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, Value sptensor; if (rank_size == 2) { - Value dims = rewriter.create(loc, ValueRange{array_sizes[9], array_sizes[10]}); + Value dims = rewriter.create(loc, ValueRange{array_sizes[9], array_sizes[10]}); /// I, J sptensor = rewriter.create(loc, ty, dims, /// Dim sizes ValueRange{ @@ -1134,7 +1134,7 @@ Value insertSparseTensorDeclOp(PatternRewriter & rewriter, } else if (rank_size == 3) { - Value dims = rewriter.create(loc, ValueRange{array_sizes[16], array_sizes[17], array_sizes[18]}); + Value dims = rewriter.create(loc, ValueRange{array_sizes[13], array_sizes[14], array_sizes[15]}); /// I, J, K sptensor = rewriter.create(loc, ty, dims, ValueRange { diff --git a/lib/ExecutionEngine/SparseUtils.cpp b/lib/ExecutionEngine/SparseUtils.cpp index aec14177..9c40a225 100644 --- a/lib/ExecutionEngine/SparseUtils.cpp +++ b/lib/ExecutionEngine/SparseUtils.cpp @@ -2118,22 +2118,22 @@ void read_input_sizes_3D(int32_t fileID, /// std::cout << "CSF format\n"; Csf3DTensor csf_3dtensor(FileReader.coo_3dtensor); - desc_sizes->data[0] = csf_3dtensor.A1pos_size; - desc_sizes->data[1] = csf_3dtensor.A1crd_size; - desc_sizes->data[2] = 0; - desc_sizes->data[3] = 0; - desc_sizes->data[4] = csf_3dtensor.A2pos_size; - desc_sizes->data[5] = csf_3dtensor.A2crd_size; - desc_sizes->data[6] = 0; - desc_sizes->data[7] = 0; - desc_sizes->data[8] = csf_3dtensor.A3pos_size; - desc_sizes->data[9] = csf_3dtensor.A3crd_size; - desc_sizes->data[10] = 0; - desc_sizes->data[11] = 0; - desc_sizes->data[12] = csf_3dtensor.Aval_size; - desc_sizes->data[13] = csf_3dtensor.num_index_i; - desc_sizes->data[14] = csf_3dtensor.num_index_j; - desc_sizes->data[15] = csf_3dtensor.num_index_k; + desc_sizes->data[0] = csf_3dtensor.A1pos_size; /// A1pos + desc_sizes->data[1] = csf_3dtensor.A1crd_size; /// A1crd + desc_sizes->data[2] = 0; /// A1_tile_pos + desc_sizes->data[3] = 0; /// A1_tile_crd + desc_sizes->data[4] = csf_3dtensor.A2pos_size; /// A2pos + desc_sizes->data[5] = csf_3dtensor.A2crd_size; /// A2crd + desc_sizes->data[6] = 0; /// A2_tile_pos + desc_sizes->data[7] = 0; /// A2_tile_crd + desc_sizes->data[8] = csf_3dtensor.A3pos_size; /// A3pos + desc_sizes->data[9] = csf_3dtensor.A3crd_size; /// A3crd + desc_sizes->data[10] = 0; /// A3_tile_pos + desc_sizes->data[11] = 0; /// A3_tile_crd + desc_sizes->data[12] = csf_3dtensor.Aval_size; /// Aval + desc_sizes->data[13] = csf_3dtensor.num_index_i; /// I + desc_sizes->data[14] = csf_3dtensor.num_index_j; /// J + desc_sizes->data[15] = csf_3dtensor.num_index_k; /// K } /// Mode-Generic else if (A1format == Compressed_nonunique && A2format == singleton && A3format == Dense) @@ -2141,22 +2141,22 @@ void read_input_sizes_3D(int32_t fileID, /// std::cout << "Mode-Generic format\n"; Mg3DTensor mg_3dtensor(FileReader.coo_3dtensor); - desc_sizes->data[0] = mg_3dtensor.A1pos_size; - desc_sizes->data[1] = mg_3dtensor.A1crd_size; - desc_sizes->data[2] = 0; - desc_sizes->data[3] = 0; - desc_sizes->data[4] = mg_3dtensor.A2pos_size; - desc_sizes->data[5] = mg_3dtensor.A2crd_size; - desc_sizes->data[6] = 0; - desc_sizes->data[7] = 0; - desc_sizes->data[8] = mg_3dtensor.A3pos_size; - desc_sizes->data[9] = mg_3dtensor.A3crd_size; - desc_sizes->data[10] = 0; - desc_sizes->data[11] = 0; - desc_sizes->data[12] = mg_3dtensor.Aval_size; - desc_sizes->data[13] = mg_3dtensor.num_index_i; - desc_sizes->data[14] = mg_3dtensor.num_index_j; - desc_sizes->data[15] = mg_3dtensor.num_index_k; + desc_sizes->data[0] = mg_3dtensor.A1pos_size; /// A1pos + desc_sizes->data[1] = mg_3dtensor.A1crd_size; /// A1crd + desc_sizes->data[2] = 0; /// A1_tile_pos + desc_sizes->data[3] = 0; /// A1_tile_crd + desc_sizes->data[4] = mg_3dtensor.A2pos_size; /// A2pos + desc_sizes->data[5] = mg_3dtensor.A2crd_size; /// A2crd + desc_sizes->data[6] = 0; /// A2_tile_pos + desc_sizes->data[7] = 0; /// A2_tile_crd + desc_sizes->data[8] = mg_3dtensor.A3pos_size; /// A3pos + desc_sizes->data[9] = mg_3dtensor.A3crd_size; /// A3crd + desc_sizes->data[10] = 0; /// A3_tile_pos + desc_sizes->data[11] = 0; /// A3_tile_crd + desc_sizes->data[12] = mg_3dtensor.Aval_size; /// Aval + desc_sizes->data[13] = mg_3dtensor.num_index_i; /// I + desc_sizes->data[14] = mg_3dtensor.num_index_j; /// J + desc_sizes->data[15] = mg_3dtensor.num_index_k; /// K } else {