2
2
// SPDX-License-Identifier: Apache-2.0
3
3
//
4
4
5
+ #include " common_test_utils/ov_tensor_utils.hpp"
5
6
#include " common_test_utils/test_common.hpp"
6
7
#include " common_test_utils/common_utils.hpp"
7
8
#include " common_test_utils/node_builders/activation.hpp"
@@ -238,4 +239,39 @@ TEST(VariablesTest, smoke_canSetStateTensor) {
238
239
239
240
ASSERT_NO_THROW (request.infer ());
240
241
}
241
- } // namespace
242
+
243
+ TEST (VariablesTest, smoke_set_get_state_with_convert) {
244
+ auto build_model = [](ov::element::Type type, const ov::PartialShape& shape) {
245
+ auto param = std::make_shared<ov::op::v0::Parameter>(type, shape);
246
+ const ov::op::util::VariableInfo variable_info { shape, type, " v0" };
247
+ auto variable = std::make_shared<ov::op::util::Variable>(variable_info);
248
+ auto read_value = std::make_shared<ov::op::v6::ReadValue>(param, variable);
249
+ auto add = std::make_shared<ov::op::v1::Add>(read_value, param);
250
+ auto assign = std::make_shared<ov::op::v6::Assign>(add, variable);
251
+ auto res = std::make_shared<ov::op::v0::Result>(add);
252
+ return std::make_shared<ov::Model>(ov::ResultVector { res }, ov::SinkVector { assign }, ov::ParameterVector{param}, " StateTestModel" );
253
+ };
254
+
255
+ auto ov = ov::Core ();
256
+ const ov::Shape virable_shape = {1 , 3 , 2 , 4 };
257
+ const ov::Shape input_shape = {1 , 3 , 2 , 4 };
258
+ const ov::element::Type et = ov::element::f32;
259
+ auto model = build_model (et, input_shape);
260
+ auto compiled_model = ov.compile_model (model, ov::test::utils::DEVICE_GPU, ov::hint::inference_precision (ov::element::f16));
261
+ auto request = compiled_model.create_infer_request ();
262
+
263
+ auto variables = request.query_state ();
264
+ ASSERT_EQ (variables.size (), 1 );
265
+ auto variable = variables.front ();
266
+ ASSERT_EQ (variable.get_name (), " v0" );
267
+ auto state_tensor = variable.get_state ();
268
+ ASSERT_EQ (state_tensor.get_shape (), virable_shape);
269
+ ASSERT_EQ (state_tensor.get_element_type (), et);
270
+
271
+ auto tensor_to_set = ov::test::utils::create_and_fill_tensor (et, state_tensor.get_shape ());
272
+ variable.set_state (tensor_to_set);
273
+ state_tensor = variable.get_state ();
274
+
275
+ ov::test::utils::compare (tensor_to_set, state_tensor, 1e-5f , 1e-5f );
276
+ }
277
+ } // namespace
0 commit comments