Skip to content

Commit 8cdbedc

Browse files
authored
[TF FE][MOC] Fix leftovers for Keras LSTM fusion transformation (#25268)
**Details:** Fix leftovers for Keras LSTM fusion transformation #25170 **Tickets:** TBD Signed-off-by: Kazantsev, Roman <roman.kazantsev@intel.com>
1 parent 59f1d69 commit 8cdbedc

File tree

2 files changed

+7
-17
lines changed

2 files changed

+7
-17
lines changed

src/common/transformations/src/transformations/op_conversions/convert_ti_to_sequences.cpp

+2-12
Original file line numberDiff line numberDiff line change
@@ -351,20 +351,10 @@ bool check_condition_true_pattern(const std::shared_ptr<op::v0::Result>& cond_re
351351
const auto& condition_map = condition_matcher.get_pattern_value_map();
352352
const auto& cond_const =
353353
ov::as_type_ptr<op::v0::Constant>(condition_map.at(cond_const_label).get_node_shared_ptr());
354-
if (!cond_const) {
354+
bool cond_value = false;
355+
if (!ov::op::util::get_constant_value(cond_const, cond_value) || !cond_value) {
355356
return false;
356357
}
357-
if (ov::shape_size(cond_const->get_shape()) != 1)
358-
return false;
359-
const auto& type = cond_const->get_output_element_type(0);
360-
if (type != ov::element::boolean) {
361-
return false;
362-
}
363-
bool cond_value = cond_const->cast_vector<bool>()[0];
364-
if (!cond_value) {
365-
return false;
366-
}
367-
368358
// number of iteration is retrieve from the first input port
369359
num_iters_output = loop->input_value(0);
370360

src/frontends/tensorflow_common/src/helper_transforms/tensor_list_ops_resolver.cpp

+5-5
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ uint64_t get_new_param_idx(const std::vector<uint64_t>& remove_parameter_idxs, u
172172
for (auto remove_idx : remove_parameter_idxs) {
173173
FRONT_END_GENERAL_CHECK(old_idx != remove_idx,
174174
"[TensorFlow Frontend] internal error: incorrect old_idx for "
175-
"TensorListSliceInputAndConcatOutputReplacer transformation");
175+
"TensorListInLoopOptimization transformation");
176176
if (remove_idx < old_idx) {
177177
++num_removed;
178178
}
@@ -181,7 +181,7 @@ uint64_t get_new_param_idx(const std::vector<uint64_t>& remove_parameter_idxs, u
181181
// compute shifted index
182182
FRONT_END_GENERAL_CHECK(num_removed <= old_idx,
183183
"[TensorFlow Frontend] internal error: incorrect new parameter index computation "
184-
"TensorListSliceInputAndConcatOutputReplacer transformation");
184+
"TensorListInLoopOptimization transformation");
185185
return old_idx - num_removed;
186186
}
187187
} // namespace
@@ -478,7 +478,7 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp
478478
std::dynamic_pointer_cast<TensorListSetItem>(body_result->get_input_node_shared_ptr(0));
479479
FRONT_END_GENERAL_CHECK(tensor_list_set_item,
480480
"[TensorFlow Frontend] internal error: tensor_list_set_item is nullptr in "
481-
"TensorListSliceInputAndConcatOutputReplacer");
481+
"TensorListInLoopOptimization");
482482
// unsqueeze newly generated data at this iteration
483483
// that will be concatenated
484484
auto new_data = tensor_list_set_item->input_value(2);
@@ -501,13 +501,13 @@ ov::frontend::tensorflow::pass::TensorListInLoopOptimization::TensorListInLoopOp
501501
const auto& body_param = body_params[param_idx];
502502
FRONT_END_GENERAL_CHECK(body_param->get_output_target_inputs(0).size() == 1,
503503
"[TensorFlow Frontend] internal error: tensor list must have only consumer "
504-
"TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer");
504+
"TensorListGetItem operation in TensorListInLoopOptimization");
505505
auto target_input = *(body_param->get_output_target_inputs(0).begin());
506506
auto tensor_list_get_item =
507507
std::dynamic_pointer_cast<TensorListGetItem>(target_input.get_node()->shared_from_this());
508508
FRONT_END_GENERAL_CHECK(tensor_list_get_item,
509509
"[TensorFlow Frontend] internal error: tensor list must have only consumer "
510-
"TensorListGetItem operation in TensorListSliceInputAndConcatOutputReplacer");
510+
"TensorListGetItem operation in TensorListInLoopOptimization");
511511

512512
auto new_shape = body_param->get_output_partial_shape(0);
513513
if (new_shape.rank().is_static() && new_shape.rank().get_length() > 0) {

0 commit comments

Comments
 (0)