15
15
#include " intel_gpu/runtime/compilation_context.hpp"
16
16
#include " gemm_inst.h"
17
17
#include " permute_inst.h"
18
+ #include " layout_optimizer.h"
18
19
19
20
#include < cstddef>
20
21
#include < vector>
@@ -625,7 +626,7 @@ class gemm_gpu_tests: public ::testing::Test {
625
626
topology topology;
626
627
topology.add (input_layout (" input1" , in1_layout),
627
628
input_layout (" input2" , in2_layout),
628
- gemm (" gemm_ref" , { input_info (" input1" ), input_info (" input2" ) }, data_types::f16,
629
+ gemm (" gemm_ref" , { input_info (" input1" ), input_info (" input2" ) }, data_types::f16,
629
630
{0 , 2 , 1 , 3 }, {0 , 2 , 3 , 1 }, {0 , 1 , 2 , 3 })
630
631
);
631
632
@@ -652,7 +653,7 @@ class gemm_gpu_tests: public ::testing::Test {
652
653
topology topology;
653
654
topology.add (input_layout (" input1" , in1_layout),
654
655
input_layout (" input2" , in2_layout),
655
- gemm (" gemm" , { input_info (" input1" ), input_info (" input2" ) }, data_types::f16,
656
+ gemm (" gemm" , { input_info (" input1" ), input_info (" input2" ) }, data_types::f16,
656
657
{0 , 2 , 1 , 3 }, {0 , 2 , 3 , 1 }, {0 , 1 , 2 , 3 })
657
658
);
658
659
@@ -2789,7 +2790,7 @@ INSTANTIATE_TEST_SUITE_P(gemm_gpu, gemm_onednn_ndims, ::testing::ValuesIn(std::v
2789
2790
2790
2791
class gemm_onednn : public ::testing::Test {
2791
2792
public:
2792
- void test_impl_replacement_with_cldnn () {
2793
+ void test_impl_replacement_with_cldnn (bool is_caching_test ) {
2793
2794
auto & engine = get_test_engine ();
2794
2795
2795
2796
if (!engine.get_device_info ().supports_immad )
@@ -2828,16 +2829,34 @@ class gemm_onednn: public ::testing::Test {
2828
2829
ov::intel_gpu::optimize_data (true ),
2829
2830
ov::intel_gpu::allow_new_shape_infer (true ) };
2830
2831
2831
- network network (engine, topology, cfg);
2832
- network.set_input_data (" input1" , input1);
2833
- network.set_input_data (" input2" , input2);
2832
+ cldnn::network::ptr network;
2833
+ if (is_caching_test) {
2834
+ membuf mem_buf;
2835
+ {
2836
+ std::ostream out_mem (&mem_buf);
2837
+ BinaryOutputBuffer ob = BinaryOutputBuffer (out_mem);
2838
+ ob.set_stream (get_test_stream_ptr ().get ());
2839
+ program::build_program (engine, topology, cfg)->save (ob);
2840
+ }
2841
+ {
2842
+ std::istream in_mem (&mem_buf);
2843
+ BinaryInputBuffer ib = BinaryInputBuffer (in_mem, engine);
2844
+ auto imported_prog = std::make_shared<cldnn::program>(engine, cfg);
2845
+ imported_prog->load (ib);
2846
+ network = std::make_shared<cldnn::network>(imported_prog);
2847
+ }
2848
+ } else {
2849
+ network = std::make_shared<cldnn::network>(engine, topology, cfg);
2850
+ }
2851
+ network->set_input_data (" input1" , input1);
2852
+ network->set_input_data (" input2" , input2);
2834
2853
2835
- auto inst = network. get_primitive (" gemm" );
2854
+ auto inst = network-> get_primitive (" gemm" );
2836
2855
auto impl = inst->get_impl ();
2837
2856
ASSERT_TRUE (impl != nullptr );
2838
2857
ASSERT_TRUE (impl->is_dynamic ());
2839
2858
2840
- auto outputs = network. execute ();
2859
+ auto outputs = network-> execute ();
2841
2860
2842
2861
auto output = outputs.at (" gemm" ).get_memory ();
2843
2862
cldnn::mem_lock<ov::float16> output_ptr (output, get_test_stream ());
@@ -2847,12 +2866,15 @@ class gemm_onednn: public ::testing::Test {
2847
2866
ASSERT_FLOAT_EQ (output_ptr[i], out_data[i]);
2848
2867
}
2849
2868
2850
- // WA: Call wait_all() to wait for all queued kernels compilation finish
2851
- network.get_program ()->get_compilation_context ().wait_all ();
2869
+ // Call wait_all() to wait for all queued kernel compilations to finish
2870
+ network->get_program ()->get_compilation_context ().wait_all ();
2871
+
2872
+ auto & lo = network->get_program ()->get_layout_optimizer ();
2873
+ ASSERT_TRUE (lo.get_optimization_attributes ().use_onednn_impls );
2852
2874
2853
2875
// Check if OneDNN's impl is used for the next execute() call
2854
- network. execute ();
2855
- inst = network. get_primitive (" gemm" );
2876
+ network-> execute ();
2877
+ inst = network-> get_primitive (" gemm" );
2856
2878
impl = inst->get_impl ();
2857
2879
ASSERT_TRUE (impl != nullptr );
2858
2880
ASSERT_FALSE (impl->is_dynamic ());
@@ -3214,7 +3236,10 @@ class gemm_onednn: public ::testing::Test {
3214
3236
};
3215
3237
3216
3238
TEST_F (gemm_onednn, impl_replacement_with_cldnn) {
3217
- this ->test_impl_replacement_with_cldnn ();
3239
+ this ->test_impl_replacement_with_cldnn (false );
3240
+ }
3241
+ TEST_F (gemm_onednn, impl_replacement_with_cldnn_cached) {
3242
+ this ->test_impl_replacement_with_cldnn (true );
3218
3243
}
3219
3244
3220
3245
// Check that gemm_onednn transpose_format() accepts the transpose whitelist formats (byfx/bxfy)
0 commit comments