@@ -464,11 +464,33 @@ void FullyConnected::prepareWeightsUsingDummyShape() {
464
464
if (selected_pd == nullptr )
465
465
OPENVINO_THROW (" Preferable primitive descriptor is not set for node " , getName (), " ." );
466
466
467
- auto inDesc = MemoryDescUtils::convertToDnnlMemoryDesc ( MemoryDescUtils::makeDummyDesc (* getBaseMemDescAtInputPort (DATA_ID))) ;
467
+ DnnlMemoryDescPtr inDesc = nullptr ;
468
468
auto weightDesc = MemoryDescUtils::convertToDnnlMemoryDesc (weightDescIP);
469
469
auto biasDesc = withBiases ? MemoryDescUtils::convertToDnnlMemoryDesc (getBaseMemDescAtInputPort (BIAS_ID)) : nullptr ;
470
470
auto outDesc = MemoryDescUtils::convertToDnnlMemoryDesc (MemoryDescUtils::makeDummyDesc (*getBaseMemDescAtOutputPort (0 )));
471
471
472
+ Shape newInShape = getBaseMemDescAtInputPort (DATA_ID)->getShape ();
473
+ if (isDynamicNode ()) {
474
+ auto originalInDesc = getBaseMemDescAtInputPort (DATA_ID);
475
+ auto originalInDims = originalInDesc->getShape ().getDims ();
476
+ size_t dimIdx = originalInDims.size () == 3 ? 1 : 0 ;
477
+ // Propagate N dim from the output shape to the input shape
478
+ if (newInShape.getDims ()[dimIdx] == Shape::UNDEFINED_DIM &&
479
+ getBaseMemDescAtOutputPort (0 )->getShape ().getDims ()[dimIdx] != Shape::UNDEFINED_DIM) {
480
+ newInShape = cloneShapeWithNewDim (newInShape, getBaseMemDescAtOutputPort (0 )->getShape ().getDims ()[dimIdx], dimIdx);
481
+ }
482
+ // Propagate K dim from the weights shape to the input shape
483
+ if (newInShape.getDims ()[dimIdx+1 ] == Shape::UNDEFINED_DIM &&
484
+ weightDesc->getShape ().getDims ()[1 ] != Shape::UNDEFINED_DIM) {
485
+ newInShape = cloneShapeWithNewDim (newInShape, weightDesc->getShape ().getDims ()[1 ], dimIdx+1 );
486
+ }
487
+
488
+ auto newInDesc = DnnlBlockedMemoryDesc (originalInDesc->getPrecision (), MemoryDescUtils::makeDummyShape (newInShape));
489
+ inDesc = MemoryDescUtils::convertToDnnlMemoryDesc (MemoryDescUtils::makeDummyDesc (newInDesc));
490
+ } else {
491
+ inDesc = MemoryDescUtils::convertToDnnlMemoryDesc (MemoryDescUtils::makeDummyDesc (*getBaseMemDescAtInputPort (DATA_ID)));
492
+ }
493
+
472
494
const FCKey key = {inDesc,
473
495
weightDesc,
474
496
biasDesc,
0 commit comments