
Commit 0e48570

fixed dynamic shape test cases
1 parent 7681bf9 commit 0e48570

2 files changed: +40 −1 lines changed

src/plugins/intel_cpu/src/nodes/fullyconnected.cpp

+23-1
@@ -464,11 +464,33 @@ void FullyConnected::prepareWeightsUsingDummyShape() {
     if (selected_pd == nullptr)
         OPENVINO_THROW("Preferable primitive descriptor is not set for node ", getName(), ".");

-    auto inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtInputPort(DATA_ID)));
+    DnnlMemoryDescPtr inDesc = nullptr;
     auto weightDesc = MemoryDescUtils::convertToDnnlMemoryDesc(weightDescIP);
     auto biasDesc = withBiases ? MemoryDescUtils::convertToDnnlMemoryDesc(getBaseMemDescAtInputPort(BIAS_ID)) : nullptr;
     auto outDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtOutputPort(0)));

+    Shape newInShape = getBaseMemDescAtInputPort(DATA_ID)->getShape();
+    if (isDynamicNode()) {
+        auto originalInDesc = getBaseMemDescAtInputPort(DATA_ID);
+        auto originalInDims = originalInDesc->getShape().getDims();
+        size_t dimIdx = originalInDims.size() == 3 ? 1 : 0;
+        // Propagate N dim from the output shape to the input shape
+        if (newInShape.getDims()[dimIdx] == Shape::UNDEFINED_DIM &&
+            getBaseMemDescAtOutputPort(0)->getShape().getDims()[dimIdx] != Shape::UNDEFINED_DIM) {
+            newInShape = cloneShapeWithNewDim(newInShape, getBaseMemDescAtOutputPort(0)->getShape().getDims()[dimIdx], dimIdx);
+        }
+        // Propagate K dim from the weights shape to the input shape
+        if (newInShape.getDims()[dimIdx+1] == Shape::UNDEFINED_DIM &&
+            weightDesc->getShape().getDims()[1] != Shape::UNDEFINED_DIM) {
+            newInShape = cloneShapeWithNewDim(newInShape, weightDesc->getShape().getDims()[1], dimIdx+1);
+        }
+
+        auto newInDesc = DnnlBlockedMemoryDesc(originalInDesc->getPrecision(), MemoryDescUtils::makeDummyShape(newInShape));
+        inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(newInDesc));
+    } else {
+        inDesc = MemoryDescUtils::convertToDnnlMemoryDesc(MemoryDescUtils::makeDummyDesc(*getBaseMemDescAtInputPort(DATA_ID)));
+    }
+
     const FCKey key = {inDesc,
                        weightDesc,
                        biasDesc,
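
The new branch above fills any still-undefined N/K dims of the activation shape from tensors whose corresponding dims are already known at this point: N from the output shape, K from the weights shape. Only afterwards does MemoryDescUtils::makeDummyShape() substitute dummy values for whatever remains undefined. Below is a minimal standalone sketch of that propagation rule, using plain std::vector<int64_t> with -1 standing in for Shape::UNDEFINED_DIM; the commit itself works on the plugin's Shape/Dim types and the helper added in cpu_utils.hpp further down.

#include <cstdint>
#include <vector>

// -1 plays the role of Shape::UNDEFINED_DIM in this sketch.
constexpr int64_t kUndefined = -1;

// Copy donor[donorIdx] into dims[idx] when dims[idx] is still undefined
// and the donor dimension is known.
inline void propagateDim(std::vector<int64_t>& dims, size_t idx,
                         const std::vector<int64_t>& donor, size_t donorIdx) {
    if (dims[idx] == kUndefined && donor[donorIdx] != kUndefined) {
        dims[idx] = donor[donorIdx];
    }
}

int main() {
    // 2D FullyConnected: activations [N, K], weights [OC, K], output [N, OC].
    std::vector<int64_t> input   = {kUndefined, kUndefined};
    std::vector<int64_t> output  = {8, 320};    // N is known from the output shape
    std::vector<int64_t> weights = {320, 512};  // K is known from the weights shape

    propagateDim(input, 0, output, 0);   // N <- output dim 0
    propagateDim(input, 1, weights, 1);  // K <- weights dim 1
    // input is now {8, 512}; in the node, any dim left undefined at this point
    // would then be replaced by MemoryDescUtils::makeDummyShape().
    return 0;
}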

src/plugins/intel_cpu/src/utils/cpu_utils.hpp

+17
@@ -48,6 +48,23 @@ inline std::vector<size_t> getNormalizedDimsBySize(const VectorDims &dims, size_
     return normalizedDims;
 }

+/**
+ * @brief Clones the passed shape and replaces one of its dimensions.
+ * @param originalShape
+ * shape to clone
+ * @param newDimValue
+ * new dimension value
+ * @param dim
+ * dimension index
+ * @return cloned shape
+ */
+inline Shape cloneShapeWithNewDim(Shape originalShape, Dim newDimValue, size_t dim) {
+    VectorDims newDims = originalShape.getDims();
+    assert(dim < newDims.size());
+    newDims[dim] = newDimValue;
+    return Shape(originalShape.getMinDims(), newDims);
+}
+
 /**
  * @brief Checked that secondInputDims unidirectional broadcastable per tensor or per channel to firstInputDims
  * @param firstInputDims
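
For reference, a plain-vector analogue of the helper's behavior, as a sketch: copy the dims and overwrite a single entry. The real helper additionally preserves the original shape's minimum dims by rebuilding the Shape from (minDims, newDims), and because originalShape is taken by value the caller's shape is never modified.

#include <cassert>
#include <cstdint>
#include <vector>

// Plain-vector analogue of cloneShapeWithNewDim(): copy the dims and overwrite
// the requested entry; the input vector is taken by value, so the caller's
// dims are left untouched.
std::vector<int64_t> cloneDimsWithNewDim(std::vector<int64_t> dims, int64_t newDimValue, size_t dim) {
    assert(dim < dims.size());
    dims[dim] = newDimValue;
    return dims;
}

int main() {
    std::vector<int64_t> dims = {-1, -1, 320};      // -1 ~ Shape::UNDEFINED_DIM
    auto pinned = cloneDimsWithNewDim(dims, 8, 0);  // pin the batch dim to 8
    // pinned == {8, -1, 320}; `dims` itself is unchanged.
    return 0;
}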
