Commit b821dcc

[CPU] Support 4D matmul to FullyConnected conversion (#21491)
1 parent 985d6ec commit b821dcc

7 files changed (+116 -143)


src/plugins/intel_cpu/src/nodes/fullyconnected.cpp (+22 -39)
@@ -108,15 +108,9 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ov::Node>&
             errorMessage = "Only Constant operation on 'bias' input is supported";
             return false;
         }
-        const auto inRank = fc->get_input_partial_shape(DATA_ID).size();
         const auto weightRank = fc->get_input_partial_shape(WEIGHTS_ID).size();
-        if (!one_of(inRank, 2u, 3u, 4u)) {
-            errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
-            return false;
-        }
-        if ((one_of(inRank, 2u, 3u) && weightRank != 2) || (inRank == 4 && weightRank != 4)) {
-            errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank) +
-                           " and 'weight' input with rank: " + std::to_string(weightRank);
+        if (weightRank != 2) {
+            errorMessage = "Doesn't support 'weight' input with rank: " + std::to_string(weightRank);
             return false;
         }
     } catch (...) {
@@ -160,16 +154,9 @@ VectorDims FullyConnected::makeDummyInputDims() const {
 
     auto inMinDims = inShape.getMinDims();
     auto inMaxDims = inShape.getMaxDims();
+    inMinDims.back() = weightDims.back();
+    inMaxDims.back() = weightDims.back();
 
-    if (inMinDims.size() == 3) {
-        inMinDims.back() = weightDims.back();
-        inMaxDims.back() = weightDims.back();
-    } else {
-        for (size_t i = 1; i < inMinDims.size(); i++) {
-            inMinDims[i] = weightDims[i];
-            inMaxDims[i] = weightDims[i];
-        }
-    }
     return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
 }
 
@@ -394,6 +381,11 @@ createDescriptorInternalForConv(DnnlMemoryDescCPtr inputDescPtr,
     }
 }
 
+template <typename T>
+static std::vector<T> normalizeDims(const std::vector<T>& dims) {
+    return {std::accumulate(dims.begin(), dims.end() - 1, (T)1, std::multiplies<T>()), dims[dims.size() - 1]};
+}
+
 static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::engine& engine) {
     // use conv1x1 primitive for computation
     if (key.useConv1x1) {
@@ -407,17 +399,18 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
     // fallback to normal inner product primitive
     auto inDesc = key.inp0->getDnnlDesc();
     const auto& inDims = inDesc.get_dims(); // @TODO query + copy might be slow
-    if (inDims.size() == 3) {
-        auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
+    if (inDims.size() > 2) {
+        dnnl::memory::dims normalizedInDims = normalizeDims(inDims);
         inDesc = inDesc.reshape(normalizedInDims);
     }
+
     auto outDesc = key.out->getDnnlDesc();
     const auto& outDims = outDesc.get_dims(); // @TODO query + copy might be slow
-
-    if (outDims.size() == 3) {
-        auto normalizedOutDims = { outDims[0] * outDims[1], outDims[2] };
+    if (outDims.size() > 2) {
+        dnnl::memory::dims normalizedOutDims = normalizeDims(outDims);
         outDesc = outDesc.reshape(normalizedOutDims);
     }
+
     dnnl::memory::desc weiDesc;
     if (key.useSparseWeights) {
         weiDesc = key.inp1->getDnnlDesc();
@@ -673,10 +666,10 @@ void FullyConnected::execute(dnnl::stream strm) {
     auto updateMemoryPtr = [this](int argType) {
         auto param = primArgs.find(argType);
         if (param != primArgs.end()) {
-            if (argType == DNNL_ARG_SRC && (getInputShapeAtPort(DATA_ID).getRank() == 3 || useConv1x1)) {
+            if (argType == DNNL_ARG_SRC && (getInputShapeAtPort(DATA_ID).getRank() > 2 || useConv1x1)) {
                 primArgs.at(argType).set_data_handle(getParentEdgesAtPort(0)[0]->getMemoryPtr()->getData());
             }
-            if (argType == DNNL_ARG_DST && (getOutputShapeAtPort(0).getRank() == 3 || useConv1x1)) {
+            if (argType == DNNL_ARG_DST && (getOutputShapeAtPort(0).getRank() > 2 || useConv1x1)) {
                 primArgs.at(argType).set_data_handle(getChildEdgesAtPort(0)[0]->getMemoryPtr()->getData());
             }
         }
@@ -708,17 +701,7 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
     // 2D: [X,Y] [Y,Z] => [X,Z] with N=X,IC=Y,OC=Z
     // 3D: [B,X,Y] [Y,Z] => [B,X,Z] with N=B*X,IC=Y,OC=Z
 
-    VectorDims dims;
-    if (dims_ext.size() == 2) {
-        // 2D
-        dims = dims_ext;
-    } else if (dims_ext.size() == 3) {
-        // 3D
-        dims.push_back(dims_ext[0] * dims_ext[1]);
-        dims.push_back(dims_ext[2]);
-    } else {
-        OPENVINO_THROW("Unexpected rank(", dims_ext.size(), ") for output tensor of node: ", getName());
-    }
+    VectorDims dims = normalizeDims(dims_ext);
 
     DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, canBeExecutedInInt8(),
                                 1 << 0, getDQScales(), withBiases);
@@ -802,11 +785,11 @@ const std::vector<impl_desc_type>& FullyConnected::getDefaultImplPriority() {
 void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDesc,
                                               const dnnl::memory::desc &outputDesc) {
     auto create2Dcandidate = [](const dnnl::memory::desc &desc) {
-        if (desc.get_dims().size() != 3) // already 2D
+        if (desc.get_dims().size() == 2) // already 2D
            return desc;
 
        auto inDims = desc.get_dims();
-        auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
+        dnnl::memory::dims normalizedInDims = normalizeDims(inDims);
 
        return dnnl::memory::desc(normalizedInDims, desc.get_data_type(),
                                  DnnlExtensionUtils::GetPlainFormatByRank(normalizedInDims.size()));
@@ -967,7 +950,7 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
 std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
     auto desc = idx > 0 ? prim_desc.weights_desc(idx - 1) : prim_desc.src_desc(idx);
 
-    if (getInputShapeAtPort(idx).getRank() == 3
+    if (getInputShapeAtPort(idx).getRank() != 2
         // report original plain layout for weight since it needs to be reordered dynamically at runtime
         || (idx == 1 && !useSparseWeights)) {
         return std::make_shared<CpuBlockedMemoryDesc>(
@@ -984,7 +967,7 @@ std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_
 std::shared_ptr<MemoryDesc> FullyConnected::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
     auto desc = prim_desc.dst_desc(idx);
 
-    if (getOutputShapeAtPort(idx).getRank() == 3) {
+    if (getOutputShapeAtPort(idx).getRank() != 2) {
         return std::make_shared<CpuBlockedMemoryDesc>(
             DnnlExtensionUtils::DataTypeToElementType(desc.get_data_type()), getOutputShapeAtPort(idx));
     }
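
Note on the new helper: normalizeDims() folds every leading dimension into one, mapping [d0, ..., dn-2, dn-1] to [d0 * ... * dn-2, dn-1], which is why the scattered rank == 3 special cases above could become generic rank > 2 / rank != 2 checks. A minimal standalone sketch of the same folding (the main() driver and sample shapes are illustrative, not part of the commit):

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

// Same folding as the patch: [d0, d1, ..., dn-1] -> [d0 * d1 * ... * dn-2, dn-1].
template <typename T>
static std::vector<T> normalizeDims(const std::vector<T>& dims) {
    return {std::accumulate(dims.begin(), dims.end() - 1, (T)1, std::multiplies<T>()),
            dims[dims.size() - 1]};
}

int main() {
    const auto folded4d = normalizeDims(std::vector<int64_t>{2, 3, 4, 5});
    std::cout << folded4d[0] << " x " << folded4d[1] << "\n";  // prints: 24 x 5
    const auto folded2d = normalizeDims(std::vector<int64_t>{7, 5});
    std::cout << folded2d[0] << " x " << folded2d[1] << "\n";  // prints: 7 x 5 (2D is unchanged)
}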

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/convert_matmul_to_fc.cpp (+1 -2)
@@ -53,8 +53,7 @@ ov::intel_cpu::ConvertMatMulToFC::ConvertMatMulToFC() {
         auto rank_b = shape_b.rank().get_length();
 
         // Transformation to FC is not supported for 1D inputs
-        if (rank_a == 1 || rank_b == 1 ||
-            rank_a > 3 || rank_b > 3) {
+        if (rank_a == 1 || rank_b == 1) {
             return false;
         }
 
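Dropping the rank_a > 3 || rank_b > 3 guard is sound when the weights end up 2D because the MatMul contraction only reads the last axis: a [B1, ..., M, K] input against [N, K] transposed weights is the same GEMM as the input viewed as [B1*...*M, K]. A quick numeric check of that equivalence on tiny made-up shapes (illustrative only, not plugin code):

#include <cassert>
#include <cstddef>
#include <iostream>

// Batched matmul of [2, 3, 4] data against [5, 4]^T weights equals a single
// [6, 4] x [5, 4]^T GEMM: the contraction only reads the last axis, and the
// leading axes are plain row batching over the same contiguous buffer.
int main() {
    constexpr std::size_t B = 2, M = 3, K = 4, N = 5;
    float src[B][M][K], wei[N][K], batched[B][M][N], flat[B * M][N];
    for (std::size_t i = 0; i < B * M * K; ++i) (&src[0][0][0])[i] = float(i % 7);
    for (std::size_t i = 0; i < N * K; ++i) (&wei[0][0])[i] = float(i % 3);

    // Per-batch GEMMs: the "MatMul" view.
    for (std::size_t b = 0; b < B; ++b)
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n) {
                float acc = 0.f;
                for (std::size_t k = 0; k < K; ++k) acc += src[b][m][k] * wei[n][k];
                batched[b][m][n] = acc;
            }

    // One GEMM over the same buffer viewed as [B*M, K]: the "FullyConnected" view.
    const float* src2d = &src[0][0][0];
    for (std::size_t r = 0; r < B * M; ++r)
        for (std::size_t n = 0; n < N; ++n) {
            float acc = 0.f;
            for (std::size_t k = 0; k < K; ++k) acc += src2d[r * K + k] * wei[n][k];
            flat[r][n] = acc;
        }

    for (std::size_t b = 0; b < B; ++b)
        for (std::size_t m = 0; m < M; ++m)
            for (std::size_t n = 0; n < N; ++n)
                assert(batched[b][m][n] == flat[b * M + m][n]);
    std::cout << "batched == flattened\n";
}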
src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/reshape_fc_fusion.cpp (-76)

This file was deleted.

src/plugins/intel_cpu/src/transformations/cpu_opset/common/pass/reshape_fc_fusion.hpp (-19)

This file was deleted.

src/plugins/intel_cpu/src/transformations/cpu_opset/convert_to_cpu_specific_opset.hpp (-4)
@@ -5,7 +5,6 @@
 #include "openvino/pass/constant_folding.hpp"
 #include "openvino/op/fake_quantize.hpp"
 #include "openvino/pass/manager.hpp"
-#include "common/pass/reshape_fc_fusion.hpp"
 #include "common/pass/align_matmul_input_ranks.hpp"
 #include "transformations/common_optimizations/reshape_prelu.hpp"
 #include "common/pass/convert_broadcast_to_tiles.hpp"
@@ -42,9 +41,6 @@ inline void ConvertToCPUSpecificOpset(std::shared_ptr<ov::Model> &nGraphFunc) {
     CPU_REGISTER_PASS_COMMON(manager, ConvertToLeakyRelu);
     CPU_REGISTER_PASS_COMMON(manager, ConvertToSwishCPU);
     CPU_REGISTER_PASS_COMMON(manager, OptimizeSequenceTransposes);
-    if (!ov::op::util::has_op_with_type<ov::op::v0::FakeQuantize>(nGraphFunc)) {
-        CPU_REGISTER_PASS_COMMON(manager, ReshapeFullyConnectedFusion);
-    }
     // after transformation "MoveEltwiseUpThroughDataMov" there can be reshaped sequences that should be eliminated or fused
     CPU_REGISTER_PASS_COMMON(manager, ov::pass::ReshapeSequenceFusion);
     CPU_REGISTER_PASS_COMMON(manager, ov::pass::ConstantFolding);

src/plugins/intel_cpu/tests/functional/shared_tests_instances/low_precision_transformations/mat_mul_with_constant_transformation.cpp (+2 -2)
@@ -51,7 +51,7 @@ std::vector<MatMulWithConstantTransformationTestValues> testValues = {
         { std::vector<float>(4 * 2, 2.f), ngraph::element::f32, ngraph::Shape{ 2, 4 } },
         { 256ul, {{1}, {1}, {2, 1}, {2, 1}}, {-128.f}, {127.f}, {-128.f, -12.8f}, {127.f, 12.7f} },
         { {}, {}, {} },
-        "MatMul",
+        "FullyConnected",
         "u8"
     },
     // 4D with Dq on weights
@@ -61,7 +61,7 @@
         { std::vector<float>(4 * 2, 2.f), ngraph::element::i8, ngraph::Shape{ 2, 4 } },
         {},
         { ngraph::element::f32, {}, {{0.1f, 0.01f}, ngraph::element::f32, ngraph::Shape{ 2, 1 }} },
-        "MatMul",
+        "FullyConnected",
         "u8"
     },
     // 3D with the same values

src/plugins/intel_cpu/tests/unit/transformations/convert_matmul_test.cpp (+91 -1)
@@ -249,6 +249,96 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest14) {
     }
 }
 
+TEST_F(TransformationTestsF, ConvertMatMulToFCTest_4d_1) {
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{2, 3, 4, 5});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 6, 5 }, { 1 });
+        auto matmul = std::make_shared<ov::opset1::MatMul>(input1, input2, false, true);
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{input1});
+        manager.register_pass<ConvertMatMulToFC>();
+    }
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{2, 3, 4, 5});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{ 6, 5 }, { 1 });
+        auto fc = std::make_shared<FullyConnectedNode>(input1, input2, ov::Rank(4), ov::element::f32);
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{input1});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertMatMulToFCTest_4d_2) {
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 1, 5});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1, 10, 5}, {1});
+        auto fc = std::make_shared<ov::opset1::MatMul>(input1, input2, false, true);
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{input1});
+        manager.register_pass<ConvertMatMulToFC>();
+    }
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::PartialShape{-1, -1, 1, 5});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{10, 5}, {1});
+        auto fc = std::make_shared<FullyConnectedNode>(input1, input2, ov::Rank(4));
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{input1});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertMatMulToFCTest_4d_3) {
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{2, 4});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1, 1, 5, 4}, { 1 });
+        auto matmul = std::make_shared<ov::opset1::MatMul>(input1, input2, false, true);
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{input1});
+        manager.register_pass<ConvertMatMulToFC>();
+    }
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{2, 4});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{5, 4}, { 1 });
+        auto fc = std::make_shared<FullyConnectedNode>(input1, input2, ov::Rank(4), ov::element::f32);
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{input1});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertMatMulToFCTest_4d_4) {
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 2, 4});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1, 1, 5, 4}, { 1 });
+        auto matmul = std::make_shared<ov::opset1::MatMul>(input1, input2, false, true);
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{input1});
+        manager.register_pass<ConvertMatMulToFC>();
+    }
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{3, 2, 4});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{5, 4}, { 1 });
+        auto fc = std::make_shared<FullyConnectedNode>(input1, input2, ov::Rank(4), ov::element::f32);
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{input1});
+    }
+}
+
+TEST_F(TransformationTestsF, ConvertMatMulToFCTest_4d_5) {
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{2, 3, 2, 4});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{1, 1, 5, 4}, { 1 });
+        auto matmul = std::make_shared<ov::opset1::MatMul>(input1, input2, false, true);
+
+        model = std::make_shared<ov::Model>(ov::NodeVector{matmul}, ov::ParameterVector{input1});
+        manager.register_pass<ConvertMatMulToFC>();
+    }
+    {
+        auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{2, 3, 2, 4});
+        auto input2 = ov::opset1::Constant::create(ov::element::f32, ov::Shape{5, 4}, { 1 });
+        auto fc = std::make_shared<FullyConnectedNode>(input1, input2, ov::Rank(4), ov::element::f32);
+
+        model_ref = std::make_shared<ov::Model>(ov::NodeVector{fc}, ov::ParameterVector{input1});
+    }
+}
+
 TEST_F(TransformationTestsF, ConvertMatMulToFCTest_second_input_rank_adj_1) {
     {
         auto input1 = std::make_shared<ov::opset1::Parameter>(ov::element::f32, ov::Shape{5, 2, 3});
@@ -385,4 +475,4 @@ TEST_F(TransformationTestsF, ConvertMatMulToFCTest_compressed_u8_weights) {
 
         model_ref = std::make_shared<ov::Model>(ov::NodeVector{ matmul }, ov::ParameterVector{ data });
     }
-}
+}
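
One pattern worth noting in the reference models above: 4D weight constants with leading unit dimensions (ov::Shape{1, 1, 5, 4}) are squeezed to 2D (ov::Shape{5, 4}), while the data keeps its rank, so only the last output axis changes. A toy sketch of that output-shape rule (fcOutputShape is a hypothetical helper for illustration, not the plugin's shape inference):

#include <cassert>
#include <cstdint>
#include <vector>

// Hypothetical illustration: data [d0, ..., dn-2, K] against 2D weights [N, K]
// (transposed, as in the tests) yields [d0, ..., dn-2, N].
static std::vector<int64_t> fcOutputShape(std::vector<int64_t> data,
                                          const std::vector<int64_t>& weights) {
    assert(weights.size() == 2 && data.back() == weights[1]);
    data.back() = weights[0];  // K is contracted away; N takes its place
    return data;
}

int main() {
    // Mirrors ConvertMatMulToFCTest_4d_1: {2, 3, 4, 5} x {6, 5}^T -> {2, 3, 4, 6}.
    assert((fcOutputShape({2, 3, 4, 5}, {6, 5}) == std::vector<int64_t>{2, 3, 4, 6}));
    // Mirrors ConvertMatMulToFCTest_4d_3: 2D data stays 2D, {2, 4} x {5, 4}^T -> {2, 5}.
    assert((fcOutputShape({2, 4}, {5, 4}) == std::vector<int64_t>{2, 5}));
}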
