Skip to content

Commit 7365188

Browse files
authored
[SI] Removes unpredictable rt_info copying from partial value propagation (#29428)
### Details: - *Current rt_info propagation function doesn't do what it says it should do. And it works in a predictably bad manner.* ### Tickets: - *DeepSeek R1 related*
1 parent c303a8e commit 7365188

File tree

3 files changed

+5
-32
lines changed

3 files changed

+5
-32
lines changed

src/common/transformations/src/transformations/symbolic_transformations/dereshape_matmul.cpp

+3-1
Original file line numberDiff line numberDiff line change
@@ -113,6 +113,7 @@ ov::Output<ov::Node> get_target_shape_from_sources(const ov::Output<ov::Node>& b
113113
if (curr_is_const && next_is_const) {
114114
dims[curr_i] = nullptr;
115115
dims[curr_i + 1] = ov::op::util::make_try_fold<ov::op::v0::Concat>(ov::NodeVector{curr_node, next_node}, 0);
116+
ov::copy_runtime_info(copy_rt_info_from, dims[curr_i + 1]);
116117
}
117118
}
118119
dims.erase(std::remove_if(dims.begin(),
@@ -327,7 +328,8 @@ ov::pass::DeReshapeMatMul::DeReshapeMatMul() {
327328
auto other_input_reshape =
328329
op::util::make_try_fold<ov::op::v1::Reshape>(add_node->input_value(non_matmul_port), pattern, true);
329330
add_node->input(non_matmul_port).replace_source_output(other_input_reshape->output(0));
330-
ov::copy_runtime_info({in_reshape_0, in_reshape_1}, {first_batch_dim, minus_one, other_input_reshape});
331+
ov::copy_runtime_info({in_reshape_0, in_reshape_1},
332+
{first_batch_dim, minus_one, other_input_reshape, pattern});
331333
add_node->validate_and_infer_types();
332334
}
333335
ov::replace_output_update_name(out_reshape->output(0), out_reshape->input_value(0));

src/common/transformations/src/transformations/transpose_sinking/ts_slice.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,7 @@ TSSliceForward::TSSliceForward() {
4343
transpose_axis_order);
4444
const auto& indices = main_node->input_value(4);
4545
auto new_axis = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
46+
ov::copy_runtime_info(indices.get_node_shared_ptr(), new_axis);
4647

4748
main_node->input(4).replace_source_output(new_axis);
4849

@@ -96,6 +97,7 @@ TSSliceBackward::TSSliceBackward() {
9697
reversed_transpose_order);
9798
const auto& indices = main_node->input_value(4);
9899
auto new_axis = std::make_shared<ov::op::v8::Gather>(data, indices, axis);
100+
ov::copy_runtime_info(indices.get_node_shared_ptr(), new_axis);
99101
main_node->input(4).replace_source_output(new_axis);
100102

101103
main_node->validate_and_infer_types();

src/core/src/bound_evaluate.cpp

-31
Original file line numberDiff line numberDiff line change
@@ -18,36 +18,6 @@
1818
namespace {
1919
using namespace ov;
2020

21-
void propagate_rt_info(Node* node, const Output<Node>& final_port) {
22-
auto node_outputs = node->outputs();
23-
bool same_outputs = std::all_of(node_outputs.begin(), node_outputs.end(), [](const Output<Node>& output) {
24-
return output.get_tensor().has_and_set_bound();
25-
});
26-
if (same_outputs && op::util::is_constant(node)) // constant should not propagate it's rt_info
27-
{
28-
std::unordered_set<Node*> stop_nodes;
29-
for (const auto& in : final_port.get_target_inputs())
30-
stop_nodes.insert(in.get_node());
31-
32-
auto curr_node = node->shared_from_this();
33-
for (const auto& output : node_outputs) {
34-
if (output == final_port)
35-
continue;
36-
for (auto& in : output.get_target_inputs()) {
37-
if (stop_nodes.count(in.get_node()))
38-
continue;
39-
try {
40-
auto consumer = in.get_node()->shared_from_this();
41-
copy_runtime_info({curr_node, consumer}, consumer);
42-
} catch (const std::bad_weak_ptr&) {
43-
// Exception can be thrown, if `shared_from_this()` was called during node creation.
44-
// Continue propagation for other nodes.
45-
}
46-
}
47-
}
48-
}
49-
}
50-
5121
bool are_same_tensor(const ov::Tensor& lhs, const ov::Tensor& rhs) {
5222
return (lhs && rhs) && (lhs.get_element_type() == rhs.get_element_type()) && (lhs.get_shape() == rhs.get_shape()) &&
5323
(lhs.data() == rhs.data());
@@ -287,7 +257,6 @@ void evaluate_bound(const Output<Node>& output) {
287257
}
288258
bound_evaluator.set_bounds_and_symbols();
289259
invalidate_unused_values(node->input_values());
290-
propagate_rt_info(node, output);
291260
}
292261
}
293262
}

0 commit comments

Comments
 (0)