Skip to content

Commit 0f00606

Browse files
Add cpu_impl and gpu_impl selection condition and pass func test
1 parent 6f4b86e commit 0f00606

File tree

8 files changed

+176
-50
lines changed

8 files changed

+176
-50
lines changed

src/plugins/intel_gpu/src/graph/impls/cpu/resample.cpp

+36-4
Original file line numberDiff line numberDiff line change
@@ -189,19 +189,51 @@ struct resample_impl : public typed_primitive_impl<resample> {
189189
namespace detail {
190190

191191
attach_resample_impl::attach_resample_impl() {
192-
auto formats = {
192+
// auto formats = {
193+
// format::bfyx,
194+
// };
195+
196+
// auto types = {
197+
// data_types::f32,
198+
// };
199+
200+
// implementation_map<resample>::add(impl_types::cpu, shape_types::static_shape, resample_impl::create, types, formats);
201+
// implementation_map<resample>::add(impl_types::cpu, shape_types::dynamic_shape, resample_impl::create, types, formats);
202+
203+
//std::set<implementation_map<resample>::key_type> keys;
204+
205+
const auto types = {data_types::f32, data_types::i32};
206+
const auto formats = {
193207
format::bfyx,
208+
format::b_fs_yx_fsv16,
209+
format::b_fs_yx_fsv32,
210+
format::bs_fs_yx_bsv16_fsv16,
211+
format::bs_fs_yx_bsv32_fsv16,
212+
format::bs_fs_yx_bsv32_fsv32,
213+
214+
format::bfzyx,
215+
format::b_fs_zyx_fsv16,
216+
format::b_fs_zyx_fsv32,
217+
format::bs_fs_zyx_bsv16_fsv32,
218+
format::bs_fs_zyx_bsv16_fsv16,
219+
format::bs_fs_zyx_bsv32_fsv32,
220+
format::bs_fs_zyx_bsv32_fsv16,
194221
};
222+
// for (const auto type : types) {
223+
// for (const auto format : formats) {
224+
// keys.emplace(type, format);
225+
// }
226+
// }
195227

196-
auto types = {
197-
data_types::f32,
198-
};
228+
// keys.emplace(data_types::f32, format::yxfb);
199229

200230
implementation_map<resample>::add(impl_types::cpu, shape_types::static_shape, resample_impl::create, types, formats);
201231
implementation_map<resample>::add(impl_types::cpu, shape_types::dynamic_shape, resample_impl::create, types, formats);
202232
}
203233

204234
} // namespace detail
235+
236+
205237
} // namespace cpu
206238
} // namespace cldnn
207239

src/plugins/intel_gpu/src/graph/impls/ocl/register.cpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -72,7 +72,7 @@ void register_implementations() {
7272
REGISTER_OCL(swiglu);
7373
REGISTER_OCL(tile);
7474
REGISTER_OCL(gather_tree);
75-
REGISTER_OCL(resample);
75+
//REGISTER_OCL(resample);
7676
REGISTER_OCL(grn);
7777
REGISTER_OCL(ctc_greedy_decoder);
7878
REGISTER_OCL(ctc_loss);

src/plugins/intel_gpu/src/graph/impls/ocl/register.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@
4747
#include "intel_gpu/primitives/reduce.hpp"
4848
#include "intel_gpu/primitives/region_yolo.hpp"
4949
#include "intel_gpu/primitives/reorg_yolo.hpp"
50-
#include "intel_gpu/primitives/resample.hpp"
50+
//#include "intel_gpu/primitives/resample.hpp"
5151
#include "intel_gpu/primitives/reshape.hpp"
5252
#include "intel_gpu/primitives/reverse_sequence.hpp"
5353
#include "intel_gpu/primitives/rms.hpp"

src/plugins/intel_gpu/src/graph/impls/ocl/resample.cpp

+40-34
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
#include "primitive_base.hpp"
66

7+
#include "resample.hpp"
78
#include "resample_inst.h"
89
#include "kernel_selector/kernels/resample/resample_kernel_selector.h"
910
#include "kernel_selector/kernels/resample/resample_kernel_base.h"
@@ -174,42 +175,47 @@ struct resample_impl : typed_primitive_impl_ocl<resample> {
174175
}
175176
};
176177

177-
namespace detail {
178-
179-
attach_resample_impl::attach_resample_impl() {
180-
std::set<implementation_map<resample>::key_type> keys;
181-
182-
const auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::u8, data_types::i32};
183-
const auto formats = {
184-
format::bfyx,
185-
format::b_fs_yx_fsv16,
186-
format::b_fs_yx_fsv32,
187-
format::bs_fs_yx_bsv16_fsv16,
188-
format::bs_fs_yx_bsv32_fsv16,
189-
format::bs_fs_yx_bsv32_fsv32,
190-
191-
format::bfzyx,
192-
format::b_fs_zyx_fsv16,
193-
format::b_fs_zyx_fsv32,
194-
format::bs_fs_zyx_bsv16_fsv32,
195-
format::bs_fs_zyx_bsv16_fsv16,
196-
format::bs_fs_zyx_bsv32_fsv32,
197-
format::bs_fs_zyx_bsv32_fsv16,
198-
};
199-
for (const auto type : types) {
200-
for (const auto format : formats) {
201-
keys.emplace(type, format);
202-
}
203-
}
204-
205-
keys.emplace(data_types::f32, format::yxfb);
206-
keys.emplace(data_types::f16, format::yxfb);
207-
keys.emplace(data_types::f16, format::fs_b_yx_fsv32);
208-
209-
implementation_map<resample>::add(impl_types::ocl, typed_primitive_impl_ocl<resample>::create<resample_impl>, keys);
178+
// namespace detail {
179+
180+
// attach_resample_impl::attach_resample_impl() {
181+
// std::set<implementation_map<resample>::key_type> keys;
182+
183+
// const auto types = {data_types::f16, data_types::f32, data_types::i8, data_types::u8, data_types::i32};
184+
// const auto formats = {
185+
// format::bfyx,
186+
// format::b_fs_yx_fsv16,
187+
// format::b_fs_yx_fsv32,
188+
// format::bs_fs_yx_bsv16_fsv16,
189+
// format::bs_fs_yx_bsv32_fsv16,
190+
// format::bs_fs_yx_bsv32_fsv32,
191+
192+
// format::bfzyx,
193+
// format::b_fs_zyx_fsv16,
194+
// format::b_fs_zyx_fsv32,
195+
// format::bs_fs_zyx_bsv16_fsv32,
196+
// format::bs_fs_zyx_bsv16_fsv16,
197+
// format::bs_fs_zyx_bsv32_fsv32,
198+
// format::bs_fs_zyx_bsv32_fsv16,
199+
// };
200+
// for (const auto type : types) {
201+
// for (const auto format : formats) {
202+
// keys.emplace(type, format);
203+
// }
204+
// }
205+
206+
// keys.emplace(data_types::f32, format::yxfb);
207+
// keys.emplace(data_types::f16, format::yxfb);
208+
// keys.emplace(data_types::f16, format::fs_b_yx_fsv32);
209+
210+
// implementation_map<resample>::add(impl_types::ocl, typed_primitive_impl_ocl<resample>::create<resample_impl>, keys);
211+
// }
212+
213+
// } // namespace detail
214+
std::unique_ptr<primitive_impl> ResampleImplementationManager::create_impl(const program_node& node, const kernel_impl_params& params) const {
215+
assert(node.is_type<resample>());
216+
return typed_primitive_impl_ocl<resample>::create<resample_impl>(static_cast<const resample_node&>(node), params);
210217
}
211218

212-
} // namespace detail
213219
} // namespace ocl
214220
} // namespace cldnn
215221

Original file line numberDiff line numberDiff line change
@@ -0,0 +1,33 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "impls/registry/implementation_manager.hpp"
6+
#include "program_node.h"
7+
//#include "intel_gpu/primitives/resample.hpp"
8+
9+
#include <memory>
10+
namespace cldnn {
11+
namespace ocl {
12+
13+
struct ResampleImplementationManager : public ImplementationManager {
14+
OV_GPU_PRIMITIVE_IMPL("ocl::resample")
15+
ResampleImplementationManager(shape_types shape_type, ValidateFunc vf = nullptr) : ImplementationManager(impl_types::ocl, shape_type, vf) {}
16+
std::unique_ptr<primitive_impl> create_impl(const program_node& node, const kernel_impl_params& params) const override;
17+
bool validate_impl(const program_node& node) const override {
18+
// auto prim = node.as<resample>().get_primitive();
19+
// const auto& in0_layout = node.get_input_layout(0);
20+
21+
// if (in0_layout.data_type == ov::element::f32 &&
22+
// prim->operation_type == ov::op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX &&
23+
// prim->coord_trans_mode == ov::op::util::InterpolateBase::CoordinateTransformMode::ALIGN_CORNERS &&
24+
// prim->shape_calc_mode == ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) {
25+
// return false;
26+
// }
27+
28+
return true;
29+
}
30+
};
31+
32+
} // namespace ocl
33+
} // namespace cldnn

src/plugins/intel_gpu/src/graph/impls/registry/registry.hpp

+2-2
Original file line numberDiff line numberDiff line change
@@ -139,7 +139,7 @@ REGISTER_IMPLS(reshape);
139139
REGISTER_IMPLS(non_max_suppression);
140140
REGISTER_IMPLS(softmax);
141141
REGISTER_IMPLS(range);
142-
//REGISTER_IMPLS(resample);
142+
REGISTER_IMPLS(resample);
143143
REGISTER_IMPLS(select);
144144
REGISTER_IMPLS(scatter_update);
145145
REGISTER_IMPLS(scatter_elements_update);
@@ -201,7 +201,7 @@ REGISTER_DEFAULT_IMPLS(space_to_batch, OCL_S);
201201
REGISTER_DEFAULT_IMPLS(space_to_depth, OCL_S);
202202
REGISTER_DEFAULT_IMPLS(swiglu, OCL_S, OCL_D);
203203
REGISTER_DEFAULT_IMPLS(gather_tree, OCL_S);
204-
REGISTER_DEFAULT_IMPLS(resample, CPU_S, OCL_S);
204+
//REGISTER_DEFAULT_IMPLS(resample, CPU_S, OCL_S);
205205
REGISTER_DEFAULT_IMPLS(grn, OCL_S);
206206
REGISTER_DEFAULT_IMPLS(ctc_greedy_decoder, OCL_S);
207207
REGISTER_DEFAULT_IMPLS(ctc_loss, OCL_S);
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,55 @@
1+
// Copyright (C) 2024 Intel Corporation
2+
// SPDX-License-Identifier: Apache-2.0
3+
//
4+
5+
#include "predicates.hpp"
6+
#include "registry.hpp"
7+
#include "intel_gpu/primitives/resample.hpp"
8+
#include "primitive_inst.h"
9+
10+
#if OV_GPU_WITH_OCL
11+
#include "impls/ocl/resample.hpp"
12+
#endif
13+
14+
15+
namespace ov {
16+
namespace intel_gpu {
17+
18+
using namespace cldnn;
19+
20+
const std::vector<std::shared_ptr<cldnn::ImplementationManager>>& Registry<resample>::get_implementations() {
21+
static const std::vector<std::shared_ptr<ImplementationManager>> impls = {
22+
OV_GPU_CREATE_INSTANCE_OCL(ocl::ResampleImplementationManager, shape_types::static_shape,
23+
[](const cldnn::program_node& node){
24+
auto prim = node.as<resample>().get_primitive();
25+
const auto& in0_layout = node.get_input_layout(0);
26+
27+
if (in0_layout.data_type == ov::element::f32 &&
28+
prim->operation_type == ov::op::util::InterpolateBase::InterpolateMode::LINEAR_ONNX &&
29+
prim->coord_trans_mode == ov::op::util::InterpolateBase::CoordinateTransformMode::ALIGN_CORNERS &&
30+
prim->shape_calc_mode == ov::op::util::InterpolateBase::ShapeCalcMode::SCALES) {
31+
return false;
32+
}
33+
34+
return true;
35+
})
36+
// OV_GPU_CREATE_INSTANCE_OCL(ocl::ResampleImplementationManager, shape_types::dynamic_shape,
37+
// [](const cldnn::program_node& node){
38+
// return false;
39+
// })
40+
41+
OV_GPU_GET_INSTANCE_CPU(resample, shape_types::static_shape,
42+
[](const cldnn::program_node& node){
43+
return true;
44+
})
45+
// OV_GPU_GET_INSTANCE_CPU(resample, shape_types::dynamic_shape,
46+
// [](const cldnn::program_node& node){
47+
// return false;
48+
// })
49+
};
50+
51+
return impls;
52+
}
53+
54+
} // namespace intel_gpu
55+
} // namespace ov

src/plugins/intel_gpu/tests/functional/single_layer_tests/dynamic/interpolate.cpp

+8-8
Original file line numberDiff line numberDiff line change
@@ -556,14 +556,14 @@ const std::vector<ShapeParams> shapeParams4D_LargeShape = {
556556
{{1.f, 1.f, 2.f, 2.f}},
557557
defaultAxes4D.front()
558558
},
559-
// ShapeParams{
560-
// ov::op::v4::Interpolate::ShapeCalcMode::SIZES,
561-
// InputShape{{-1, -1, -1, -1}, {{1, 3, 48, 48}}},
562-
// ov::test::utils::InputLayerType::CONSTANT,
563-
// ov::test::utils::InputLayerType::CONSTANT,
564-
// {{1, 3, 144, 144}},
565-
// defaultAxes4D.front()
566-
// },
559+
ShapeParams{
560+
ov::op::v4::Interpolate::ShapeCalcMode::SIZES,
561+
InputShape{{-1, -1, -1, -1}, {{1, 3, 48, 48}}},
562+
ov::test::utils::InputLayerType::CONSTANT,
563+
ov::test::utils::InputLayerType::CONSTANT,
564+
{{1, 3, 144, 144}},
565+
defaultAxes4D.front()
566+
},
567567
};
568568

569569
const auto interpolateCasesLinearOnnx_AlignCorners_Floor = ::testing::Combine(

0 commit comments

Comments
 (0)