@@ -108,15 +108,9 @@ bool FullyConnected::isSupportedOperation(const std::shared_ptr<const ov::Node>&
             errorMessage = "Only Constant operation on 'bias' input is supported";
             return false;
         }
-        const auto inRank = fc->get_input_partial_shape(DATA_ID).size();
         const auto weightRank = fc->get_input_partial_shape(WEIGHTS_ID).size();
-        if (!one_of(inRank, 2u, 3u, 4u)) {
-            errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank);
-            return false;
-        }
-        if ((one_of(inRank, 2u, 3u) && weightRank != 2) || (inRank == 4 && weightRank != 4)) {
-            errorMessage = "Doesn't support 'data' input with rank: " + std::to_string(inRank) +
-                           " and 'weight' input with rank: " + std::to_string(weightRank);
+        if (weightRank != 2) {
+            errorMessage = "Doesn't support 'weight' input with rank: " + std::to_string(weightRank);
             return false;
         }
     } catch (...) {
@@ -160,16 +154,9 @@ VectorDims FullyConnected::makeDummyInputDims() const {
 
     auto inMinDims = inShape.getMinDims();
     auto inMaxDims = inShape.getMaxDims();
+    inMinDims.back() = weightDims.back();
+    inMaxDims.back() = weightDims.back();
 
-    if (inMinDims.size() == 3) {
-        inMinDims.back() = weightDims.back();
-        inMaxDims.back() = weightDims.back();
-    } else {
-        for (size_t i = 1; i < inMinDims.size(); i++) {
-            inMinDims[i] = weightDims[i];
-            inMaxDims[i] = weightDims[i];
-        }
-    }
     return MemoryDescUtils::makeDummyShape(Shape(inMinDims, inMaxDims)).getStaticDims();
 }
 
@@ -394,6 +381,11 @@ createDescriptorInternalForConv(DnnlMemoryDescCPtr inputDescPtr,
     }
 }
 
+template <typename T>
+static std::vector<T> normalizeDims(const std::vector<T>& dims) {
+    return {std::accumulate(dims.begin(), dims.end() - 1, (T)1, std::multiplies<T>()), dims[dims.size() - 1]};
+}
+
 static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::engine& engine) {
     // use conv1x1 primitive for computation
     if (key.useConv1x1) {
@@ -407,17 +399,18 @@ static dnnl::primitive_desc createPrimitiveDesc(const FCKey& key, const dnnl::en
     // fallback to normal inner product primitive
     auto inDesc = key.inp0->getDnnlDesc();
     const auto& inDims = inDesc.get_dims(); // @TODO query + copy might be slow
-    if (inDims.size() == 3) {
-        auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
+    if (inDims.size() > 2) {
+        dnnl::memory::dims normalizedInDims = normalizeDims(inDims);
         inDesc = inDesc.reshape(normalizedInDims);
     }
+
     auto outDesc = key.out->getDnnlDesc();
     const auto& outDims = outDesc.get_dims(); // @TODO query + copy might be slow
-
-    if (outDims.size() == 3) {
-        auto normalizedOutDims = {outDims[0] * outDims[1], outDims[2]};
+    if (outDims.size() > 2) {
+        dnnl::memory::dims normalizedOutDims = normalizeDims(outDims);
         outDesc = outDesc.reshape(normalizedOutDims);
     }
+
     dnnl::memory::desc weiDesc;
     if (key.useSparseWeights) {
         weiDesc = key.inp1->getDnnlDesc();
@@ -673,10 +666,10 @@ void FullyConnected::execute(dnnl::stream strm) {
     auto updateMemoryPtr = [this](int argType) {
         auto param = primArgs.find(argType);
         if (param != primArgs.end()) {
-            if (argType == DNNL_ARG_SRC && (getInputShapeAtPort(DATA_ID).getRank() == 3 || useConv1x1)) {
+            if (argType == DNNL_ARG_SRC && (getInputShapeAtPort(DATA_ID).getRank() > 2 || useConv1x1)) {
                 primArgs.at(argType).set_data_handle(getParentEdgesAtPort(0)[0]->getMemoryPtr()->getData());
             }
-            if (argType == DNNL_ARG_DST && (getOutputShapeAtPort(0).getRank() == 3 || useConv1x1)) {
+            if (argType == DNNL_ARG_DST && (getOutputShapeAtPort(0).getRank() > 2 || useConv1x1)) {
                 primArgs.at(argType).set_data_handle(getChildEdgesAtPort(0)[0]->getMemoryPtr()->getData());
             }
         }
@@ -708,17 +701,7 @@ void FullyConnected::setPostOps(dnnl::primitive_attr& attr, const VectorDims& di
     // 2D: [X,Y] [Y,Z] => [X,Z] with N=X,IC=Y,OC=Z
     // 3D: [B,X,Y] [Y,Z] => [B,X,Z] with N=B*X,IC=Y,OC=Z
 
-    VectorDims dims;
-    if (dims_ext.size() == 2) {
-        // 2D
-        dims = dims_ext;
-    } else if (dims_ext.size() == 3) {
-        // 3D
-        dims.push_back(dims_ext[0] * dims_ext[1]);
-        dims.push_back(dims_ext[2]);
-    } else {
-        OPENVINO_THROW("Unexpected rank(", dims_ext.size(), ") for output tensor of node: ", getName());
-    }
+    VectorDims dims = normalizeDims(dims_ext);
 
     DnnlPostOpsComposer dnnlpoc(getEngine(), attr, ops, postOpsArgs, dims, dims.size() - 1, canBeExecutedInInt8(),
                                 1 << 0, getDQScales(), withBiases);
@@ -802,11 +785,11 @@ const std::vector<impl_desc_type>& FullyConnected::getDefaultImplPriority() {
 void FullyConnected::createDescriptorInternal(const dnnl::memory::desc &inputDesc,
                                               const dnnl::memory::desc &outputDesc) {
     auto create2Dcandidate = [](const dnnl::memory::desc &desc) {
-        if (desc.get_dims().size() != 3)  // already 2D
+        if (desc.get_dims().size() == 2)  // already 2D
            return desc;
 
        auto inDims = desc.get_dims();
-       auto normalizedInDims = {inDims[0] * inDims[1], inDims[2]};
+       dnnl::memory::dims normalizedInDims = normalizeDims(inDims);
 
        return dnnl::memory::desc(normalizedInDims, desc.get_data_type(),
                                  DnnlExtensionUtils::GetPlainFormatByRank(normalizedInDims.size()));
@@ -967,7 +950,7 @@ void FullyConnected::initSupportedPrimitiveDescriptors() {
 std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
     auto desc = idx > 0 ? prim_desc.weights_desc(idx - 1) : prim_desc.src_desc(idx);
 
-    if (getInputShapeAtPort(idx).getRank() == 3
+    if (getInputShapeAtPort(idx).getRank() != 2
         // report original plain layout for weight since it needs to be reordered dynamically at runtime
         || (idx == 1 && !useSparseWeights)) {
         return std::make_shared<CpuBlockedMemoryDesc>(
@@ -984,7 +967,7 @@ std::shared_ptr<MemoryDesc> FullyConnected::getSrcMemDesc(const dnnl::primitive_
 std::shared_ptr<MemoryDesc> FullyConnected::getDstMemDesc(const dnnl::primitive_desc &prim_desc, size_t idx) const {
     auto desc = prim_desc.dst_desc(idx);
 
-    if (getOutputShapeAtPort(idx).getRank() == 3) {
+    if (getOutputShapeAtPort(idx).getRank() != 2) {
         return std::make_shared<CpuBlockedMemoryDesc>(
             DnnlExtensionUtils::DataTypeToElementType(desc.get_data_type()), getOutputShapeAtPort(idx));
     }
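Note: the normalizeDims() helper introduced above is what lets every rank-specific branch collapse into a single path: it folds all leading dimensions into one and keeps the innermost dimension, so any N-D FullyConnected activation is treated as a 2D [N, IC] matrix for the inner-product primitive. A minimal standalone sketch of that flattening follows (not part of the change itself; int64_t and the example shape are illustrative stand-ins for dnnl::memory::dim and a real tensor shape):

// Standalone sketch of the normalizeDims() flattening used in this change.
// Collapses all leading dimensions into one, keeping the innermost dimension.
#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

template <typename T>
static std::vector<T> normalizeDims(const std::vector<T>& dims) {
    return {std::accumulate(dims.begin(), dims.end() - 1, (T)1, std::multiplies<T>()), dims[dims.size() - 1]};
}

int main() {
    // 4D activation [2, 3, 4, 64] -> [2*3*4, 64] = [24, 64]
    std::vector<int64_t> in{2, 3, 4, 64};
    auto out = normalizeDims(in);
    std::cout << out[0] << " x " << out[1] << "\n";  // prints: 24 x 64
    return 0;
}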