Skip to content

Commit 9dad09a

Browse files
author
Vladimir Paramuzov
authored
[GPU] Fix set_state() when precision conversion is needed (#21874)
1 parent b821dcc commit 9dad09a

File tree

4 files changed

+60
-10
lines changed

4 files changed

+60
-10
lines changed

src/plugins/intel_gpu/include/intel_gpu/plugin/common_utils.hpp

+1
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,7 @@ inline void ForceExit() {
103103
std::_Exit(-1);
104104
}
105105

106+
/// Copies the contents of the tensor `src` into the device memory `dst` using `stream`.
/// NOTE(review): per the definition in common_utils.cpp, an element-type conversion is performed
/// when the tensor's element type differs from the data type of `dst`'s layout.
void convert_and_copy(const ov::ITensor* src, cldnn::memory::ptr dst, cldnn::stream& stream);
106107
/// Overload taking device memory as the source and a tensor as the destination.
/// NOTE(review): presumably the reverse direction of the overload above - confirm against the definition.
void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor const* dst, const cldnn::stream& stream);
107108
/// Overload taking tensors for both source and destination.
/// NOTE(review): definition is not visible here; conversion semantics to be confirmed.
void convert_and_copy(const ov::ITensor* src, ov::ITensor const* dst, const cldnn::stream& stream);
108109

src/plugins/intel_gpu/src/plugin/common_utils.cpp

+21
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,27 @@ void convert_and_copy(const void* src_ptr, ov::element::Type src_et, void* dst_p
9494
namespace ov {
9595
namespace intel_gpu {
9696

97+
/// Copies the contents of `src` into the device memory `dst`, converting the element type on the
/// host when the tensor's element type differs from the data type of `dst`'s layout.
///
/// @param src    Source tensor. When it is a RemoteTensorImpl and no conversion is needed, its
///               original memory object is copied directly instead of going through data().
/// @param dst    Destination memory; the data type of its layout is the target element type.
/// @param stream Stream handed to copy_from(); every copy is issued with blocking = true.
void convert_and_copy(const ov::ITensor* src, cldnn::memory::ptr dst, cldnn::stream& stream) {
    const bool blocking = true;
    auto src_et = src->get_element_type();
    auto dst_et = dst->get_layout().data_type;

    if (dst_et == src_et) {
        if (auto remote = dynamic_cast<const ov::intel_gpu::RemoteTensorImpl*>(src)) {
            auto mem = remote->get_original_memory();
            dst->copy_from(stream, *mem, blocking);
        } else {
            dst->copy_from(stream, src->data(), blocking);
        }
        // Same-type copies are complete for BOTH branches. Returning only from the host branch
        // let the remote-tensor case fall through into the conversion path below, which reads
        // src->data() and copies a second time.
        return;
    }

    // Element types differ: convert into a temporary host tensor of the destination type first.
    // NOTE(review): this path reads src->data(); presumably a remote tensor with a mismatching
    // element type is not expected here - confirm against callers.
    size_t size = ov::shape_size(src->get_shape());
    ov::Tensor tmp_tensor(dst_et, src->get_shape());
    ::convert_and_copy(src->data(), src_et, tmp_tensor.data(), dst_et, size, cldnn::layout({}, ov::element::undefined, cldnn::format::bfyx, cldnn::padding()));
    dst->copy_from(stream, tmp_tensor.data(), blocking);
}
117+
97118
void convert_and_copy(const cldnn::memory::ptr src, ov::ITensor const* dst, const cldnn::stream& stream) {
98119
auto src_et = src->get_layout().data_type;
99120
auto dst_et = dst->get_element_type();

src/plugins/intel_gpu/src/plugin/variable_state.cpp

+1-9
Original file line numberDiff line numberDiff line change
@@ -54,17 +54,9 @@ void VariableState::set_layout(const cldnn::layout& new_layout) {
5454
}
5555

5656
// Diff hunk for VariableState::set_state. A line following an "NN-" marker is removed by this
// commit, a line following an "NN+" marker is added, and the remaining lines are unchanged context.
// After the change the function sets the layout's partial shape from the state tensor, calls
// update_device_buffer(), delegates the copy into m_memory to convert_and_copy() on the engine's
// service stream (replacing the inline remote/host copy_from branches), and then calls set().
void VariableState::set_state(const ov::SoPtr<ov::ITensor>& state) {
57-
const bool blocking = true;
58-
auto remote_ptr = std::dynamic_pointer_cast<RemoteTensorImpl>(state._ptr);
5957
m_layout.set_partial_shape(state->get_shape());
6058
update_device_buffer();
61-
if (remote_ptr != nullptr) {
62-
auto user_memory = remote_ptr->get_memory();
63-
m_memory->copy_from(m_context->get_engine().get_service_stream(), *user_memory, blocking);
64-
} else {
65-
auto data = state->data();
66-
m_memory->copy_from(m_context->get_engine().get_service_stream(), data, blocking);
67-
}
59+
convert_and_copy(state._ptr.get(), m_memory, m_context->get_engine().get_service_stream());
6860
set();
6961
}
7062

src/plugins/intel_gpu/tests/functional/behavior/infer_request.cpp

+37-1
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
// SPDX-License-Identifier: Apache-2.0
33
//
44

5+
#include "common_test_utils/ov_tensor_utils.hpp"
56
#include "common_test_utils/test_common.hpp"
67
#include "common_test_utils/common_utils.hpp"
78
#include "common_test_utils/node_builders/activation.hpp"
@@ -238,4 +239,39 @@ TEST(VariablesTest, smoke_canSetStateTensor) {
238239

239240
ASSERT_NO_THROW(request.infer());
240241
}
241-
} // namespace
242+
243+
// Sets a variable state from an f32 tensor on a model compiled with f16 inference precision and
// checks that the state read back matches the tensor that was set (1e-5 thresholds).
TEST(VariablesTest, smoke_set_get_state_with_convert) {
    const ov::element::Type precision = ov::element::f32;
    const ov::Shape input_shape = {1, 3, 2, 4};
    const ov::Shape variable_shape = {1, 3, 2, 4};

    // Model: Parameter -> ReadValue("v0") -> Add(ReadValue, Parameter) -> Assign("v0"),
    // with the Add output also exposed as the only Result.
    auto make_state_model = [](ov::element::Type type, const ov::PartialShape& shape) {
        auto input = std::make_shared<ov::op::v0::Parameter>(type, shape);
        auto state_var = std::make_shared<ov::op::util::Variable>(ov::op::util::VariableInfo{shape, type, "v0"});
        auto read = std::make_shared<ov::op::v6::ReadValue>(input, state_var);
        auto sum = std::make_shared<ov::op::v1::Add>(read, input);
        auto write = std::make_shared<ov::op::v6::Assign>(sum, state_var);
        auto result = std::make_shared<ov::op::v0::Result>(sum);
        return std::make_shared<ov::Model>(ov::ResultVector{result},
                                           ov::SinkVector{write},
                                           ov::ParameterVector{input},
                                           "StateTestModel");
    };

    ov::Core core;
    auto compiled_model = core.compile_model(make_state_model(precision, input_shape),
                                             ov::test::utils::DEVICE_GPU,
                                             ov::hint::inference_precision(ov::element::f16));
    auto request = compiled_model.create_infer_request();

    // The request must expose exactly one state, named "v0", reported with the model's shape and type.
    auto states = request.query_state();
    ASSERT_EQ(states.size(), 1);
    auto state = states.front();
    ASSERT_EQ(state.get_name(), "v0");
    auto current = state.get_state();
    ASSERT_EQ(current.get_shape(), variable_shape);
    ASSERT_EQ(current.get_element_type(), precision);

    // Round-trip: set a freshly filled tensor, read the state back, and compare the two.
    auto expected = ov::test::utils::create_and_fill_tensor(precision, current.get_shape());
    state.set_state(expected);
    current = state.get_state();

    ov::test::utils::compare(expected, current, 1e-5f, 1e-5f);
}
277+
} // namespace

0 commit comments

Comments
 (0)