Skip to content

Commit ac7cb8b

Browse files
[GPU] Save use_onednn attribute in the blob (#27097)
### Details: - This is needed for correct runtime implementation selection of an imported model. ### Tickets: - *CVS-154891*
1 parent 212be8e commit ac7cb8b

File tree

2 files changed

+42
-13
lines changed

2 files changed

+42
-13
lines changed

src/plugins/intel_gpu/src/graph/program.cpp

+4
Original file line numberDiff line numberDiff line change
@@ -1776,6 +1776,7 @@ void program::save(cldnn::BinaryOutputBuffer& ob) const {
17761776

17771777
ob << _is_body_program;
17781778
ob << _can_be_optimized;
1779+
ob << get_layout_optimizer().get_optimization_attributes().use_onednn_impls;
17791780
processing_order.save(ob);
17801781

17811782
{
@@ -1895,6 +1896,9 @@ void program::load(cldnn::BinaryInputBuffer& ib) {
18951896

18961897
ib >> _is_body_program;
18971898
ib >> _can_be_optimized;
1899+
int32_t use_onednn_attr = 0;
1900+
ib >> use_onednn_attr;
1901+
get_layout_optimizer().set_optimization_attribute(layout_optimizer::optimization_attributes_type::use_onednn_impls, use_onednn_attr);
18981902
_loaded_from_cache = true;
18991903

19001904
processing_order.load(ib, *this);

src/plugins/intel_gpu/tests/unit/test_cases/gemm_gpu_test.cpp

+38-13
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
#include "intel_gpu/runtime/compilation_context.hpp"
1616
#include "gemm_inst.h"
1717
#include "permute_inst.h"
18+
#include "layout_optimizer.h"
1819

1920
#include <cstddef>
2021
#include <vector>
@@ -625,7 +626,7 @@ class gemm_gpu_tests: public ::testing::Test {
625626
topology topology;
626627
topology.add(input_layout("input1", in1_layout),
627628
input_layout("input2", in2_layout),
628-
gemm("gemm_ref", { input_info("input1"), input_info("input2") }, data_types::f16,
629+
gemm("gemm_ref", { input_info("input1"), input_info("input2") }, data_types::f16,
629630
{0, 2, 1, 3}, {0, 2, 3, 1}, {0, 1, 2, 3})
630631
);
631632

@@ -652,7 +653,7 @@ class gemm_gpu_tests: public ::testing::Test {
652653
topology topology;
653654
topology.add(input_layout("input1", in1_layout),
654655
input_layout("input2", in2_layout),
655-
gemm("gemm", { input_info("input1"), input_info("input2") }, data_types::f16,
656+
gemm("gemm", { input_info("input1"), input_info("input2") }, data_types::f16,
656657
{0, 2, 1, 3}, {0, 2, 3, 1}, {0, 1, 2, 3})
657658
);
658659

@@ -2789,7 +2790,7 @@ INSTANTIATE_TEST_SUITE_P(gemm_gpu, gemm_onednn_ndims, ::testing::ValuesIn(std::v
27892790

27902791
class gemm_onednn: public ::testing::Test {
27912792
public:
2792-
void test_impl_replacement_with_cldnn() {
2793+
void test_impl_replacement_with_cldnn(bool is_caching_test) {
27932794
auto& engine = get_test_engine();
27942795

27952796
if (!engine.get_device_info().supports_immad)
@@ -2828,16 +2829,34 @@ class gemm_onednn: public ::testing::Test {
28282829
ov::intel_gpu::optimize_data(true),
28292830
ov::intel_gpu::allow_new_shape_infer(true) };
28302831

2831-
network network(engine, topology, cfg);
2832-
network.set_input_data("input1", input1);
2833-
network.set_input_data("input2", input2);
2832+
cldnn::network::ptr network;
2833+
if (is_caching_test) {
2834+
membuf mem_buf;
2835+
{
2836+
std::ostream out_mem(&mem_buf);
2837+
BinaryOutputBuffer ob = BinaryOutputBuffer(out_mem);
2838+
ob.set_stream(get_test_stream_ptr().get());
2839+
program::build_program(engine, topology, cfg)->save(ob);
2840+
}
2841+
{
2842+
std::istream in_mem(&mem_buf);
2843+
BinaryInputBuffer ib = BinaryInputBuffer(in_mem, engine);
2844+
auto imported_prog = std::make_shared<cldnn::program>(engine, cfg);
2845+
imported_prog->load(ib);
2846+
network = std::make_shared<cldnn::network>(imported_prog);
2847+
}
2848+
} else {
2849+
network = std::make_shared<cldnn::network>(engine, topology, cfg);
2850+
}
2851+
network->set_input_data("input1", input1);
2852+
network->set_input_data("input2", input2);
28342853

2835-
auto inst = network.get_primitive("gemm");
2854+
auto inst = network->get_primitive("gemm");
28362855
auto impl = inst->get_impl();
28372856
ASSERT_TRUE(impl != nullptr);
28382857
ASSERT_TRUE(impl->is_dynamic());
28392858

2840-
auto outputs = network.execute();
2859+
auto outputs = network->execute();
28412860

28422861
auto output = outputs.at("gemm").get_memory();
28432862
cldnn::mem_lock<ov::float16> output_ptr(output, get_test_stream());
@@ -2847,12 +2866,15 @@ class gemm_onednn: public ::testing::Test {
28472866
ASSERT_FLOAT_EQ(output_ptr[i], out_data[i]);
28482867
}
28492868

2850-
// WA: Call wait_all() to wait for all queued kernels compilation finish
2851-
network.get_program()->get_compilation_context().wait_all();
2869+
// Call wait_all() to wait for all queued kernels compilation finish
2870+
network->get_program()->get_compilation_context().wait_all();
2871+
2872+
auto& lo = network->get_program()->get_layout_optimizer();
2873+
ASSERT_TRUE(lo.get_optimization_attributes().use_onednn_impls);
28522874

28532875
// Check if OneDNN's impl is used for the next execute() call
2854-
network.execute();
2855-
inst = network.get_primitive("gemm");
2876+
network->execute();
2877+
inst = network->get_primitive("gemm");
28562878
impl = inst->get_impl();
28572879
ASSERT_TRUE(impl != nullptr);
28582880
ASSERT_FALSE(impl->is_dynamic());
@@ -3214,7 +3236,10 @@ class gemm_onednn: public ::testing::Test {
32143236
};
32153237

32163238
TEST_F(gemm_onednn, impl_replacement_with_cldnn) {
3217-
this->test_impl_replacement_with_cldnn();
3239+
this->test_impl_replacement_with_cldnn(false);
3240+
}
3241+
TEST_F(gemm_onednn, impl_replacement_with_cldnn_cached) {
3242+
this->test_impl_replacement_with_cldnn(true);
32183243
}
32193244

32203245
// Check gemm_onednn transpose_format() can accept transpose white list format (byfx/bxfy)

0 commit comments

Comments
 (0)