Skip to content

Commit 2507d89

Browse files
authored
Handle dynamic rank in TSUnsqueezeBackward transformation (#26786)
### Details: Handle dynamic rank in TSUnsqueezeBackward transformation ### Tickets: - *CVS-152373*
1 parent 9b0d209 commit 2507d89

File tree

2 files changed

+54
-3
lines changed

2 files changed

+54
-3
lines changed

src/common/transformations/src/transformations/transpose_sinking/ts_unsqueeze.cpp

+13-3
Original file line numberDiff line numberDiff line change
@@ -190,9 +190,19 @@ TSUnsqueezeBackward::TSUnsqueezeBackward() {
190190
return false;
191191
}
192192
} else {
193-
auto rank = main_node->get_output_partial_shape(0).rank();
194-
non_negative_axes =
195-
util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
193+
const auto& axes = unsqueeze_axes->cast_vector<int64_t>();
194+
if (std::all_of(axes.begin(), axes.end(), [](int64_t axis) {
195+
return axis >= 0;
196+
})) {
197+
non_negative_axes = std::vector<size_t>(axes.begin(), axes.end());
198+
} else {
199+
auto rank = main_node->get_output_partial_shape(0).rank();
200+
if (rank.is_dynamic()) {
201+
return false;
202+
}
203+
non_negative_axes =
204+
util::try_get_normalized_axis_vector(unsqueeze_axes->get_tensor_view(), rank, *main_node);
205+
}
196206
}
197207

198208
auto transpose_order_values = transpose_order->cast_vector<size_t>();

src/common/transformations/tests/transpose_sinking/ts_common_test.cpp

+41
Original file line numberDiff line numberDiff line change
@@ -1636,6 +1636,47 @@ auto test_backward_reshape_unsqueeze = []() {
16361636
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonReshapeUnsqueezeBackward,
16371637
TSTestFixture,
16381638
test_backward_reshape_unsqueeze());
1639+
1640+
auto test_backward_unsqueeze_dyn_rank = []() {
1641+
TestCase test_case;
1642+
1643+
// Initialize common attributes
1644+
test_case.transformation = CREATE_PASS_FACTORY(TSUnsqueezeBackward);
1645+
test_case.num_main_ops = {1};
1646+
test_case.inputs_to_main = {
1647+
parameter(element::f32, PartialShape::dynamic()),
1648+
constant<int64_t>(element::i32, {2}, {-1}),
1649+
};
1650+
1651+
auto dyn_transpose = [](const vector<size_t>& idxs, const OutputVector& out_vec) -> OutputVector {
1652+
OutputVector result = out_vec;
1653+
for (const auto& idx : idxs) {
1654+
const auto& out = out_vec[idx];
1655+
1656+
// fill the order const with the stub values {-1, -2}
1657+
auto order = make_shared<Constant>(element::i32, Shape{2}, vector<int64_t>{-1, -2});
1658+
auto transpose = make_shared<Transpose>(out, order);
1659+
result[idx] = transpose;
1660+
}
1661+
return result;
1662+
};
1663+
1664+
// Test model description:
1665+
test_case.model.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
1666+
test_case.model.preprocess_outputs_of_main = {{dyn_transpose}, {{0}}};
1667+
test_case.model.model_template = create_model;
1668+
1669+
// Ref model description, the same as the original model, the transformation is not applied
1670+
// it's expected.
1671+
test_case.model_ref.main_op = {CREATE_BINARY_FACTORY(Unsqueeze)};
1672+
test_case.model_ref.preprocess_outputs_of_main = {{dyn_transpose}, {{0}}};
1673+
test_case.model_ref.model_template = create_model;
1674+
return wrapper(test_case);
1675+
};
1676+
1677+
INSTANTIATE_TEST_SUITE_P(TransposeSinkingCommonUnsqueezeBackwardDynRank,
1678+
TSTestFixture,
1679+
test_backward_unsqueeze_dyn_rank());
16391680
} // namespace common
16401681
} // namespace testing
16411682
} // namespace transpose_sinking

0 commit comments

Comments
 (0)