Skip to content

Commit ca501ca

Browse files
authored
[GPU] Use onednn gemm instead of inner product (#27628)
SD3 with dynamic shapes hits an FC unfusion pattern that requires sub-graphs, which caused poor performance. ### Tickets: - *152851, 157100* --------- Signed-off-by: hyunback <hyunback.kim@intel.com>
1 parent 517ad68 commit ca501ca

File tree

6 files changed

+91
-158
lines changed

6 files changed

+91
-158
lines changed

src/plugins/intel_gpu/src/graph/graph_optimizer/reorder_inputs.cpp

+12-2
Original file line numberDiff line numberDiff line change
@@ -1028,6 +1028,15 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
10281028
if (fc_layout.is_dynamic() || data_layout.is_dynamic())
10291029
continue;
10301030

1031+
auto same_spatial = [](layout a, layout b) {
1032+
if (a.get_spatial_rank() != b.get_spatial_rank())
1033+
return false;
1034+
for (size_t i = 0; i < a.get_spatial_rank(); i++) {
1035+
if (a.spatial(i) != b.spatial(i))
1036+
return false;
1037+
}
1038+
return true;
1039+
};
10311040
// fc_b | fc_f | data_b | data_f | broadcast condition
10321041
// ---------+-----------+-----------+-----------+--------------------
10331042
// 1 | 1 | 1 | 1 | no broadcast
@@ -1043,11 +1052,12 @@ void reorder_inputs::run(program& p, reorder_factory& rf) {
10431052
// N | 1 | N | 1 | no broadcast
10441053
// N | 1 | N | N | N/A
10451054
// N | N | 1 | 1 | implicit broadcast
1046-
// N | N | 1 | N | explicit broadcast
1047-
// N | N | N | 1 | explicit broadcast
1055+
// N | N | 1 | N | explicit broadcast when spatial different
1056+
// N | N | N | 1 | explicit broadcast when spatial different
10481057
// N | N | N | N | no broadcast
10491058
if ((fc_layout.batch() == 1 || fc_layout.feature() == 1) ||
10501059
(data_layout.batch() == 1 && data_layout.feature() == 1) ||
1060+
((data_layout.batch() == 1 || data_layout.feature() == 1) && same_spatial(fc_layout, data_layout)) ||
10511061
(fc_layout.count() == data_layout.count())) {
10521062
continue;
10531063
}

src/plugins/intel_gpu/src/graph/impls/onednn/fully_connected_onednn.cpp

+10-86
Original file line numberDiff line numberDiff line change
@@ -98,44 +98,6 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
9898
return args;
9999
}
100100

101-
static std::shared_ptr<WeightsReorderParams> get_weights_reorder(const kernel_impl_params& impl_params, const dnnl::primitive_desc& pd) {
102-
auto input_layout = impl_params.get_input_layout(0);
103-
auto source_weights_layout = impl_params.get_input_layout(1);
104-
auto cldnn_prim = impl_params.typed_desc<fully_connected>();
105-
106-
auto input_pshape = input_layout.get_partial_shape();
107-
auto weights_pshape = source_weights_layout.get_partial_shape();
108-
109-
int64_t feature = input_pshape[std::min(cldnn_prim->input_size, static_cast<size_t>(4)) - 1].get_length();
110-
if (cldnn_prim->input_size == 3) {
111-
feature = std::max({input_layout.spatial(0), input_layout.spatial(1), input_layout.spatial(2)});
112-
}
113-
auto target_weights_layout = source_weights_layout;
114-
if (weights_pshape.size() != 2) {
115-
target_weights_layout.set_partial_shape(reshape_to_2d(weights_pshape, feature));
116-
}
117-
118-
auto target_weights_desc = pd.weights_desc(0);
119-
120-
auto shape_consistent = onednn::keep_weights_reorder_shape_consistent(source_weights_layout, target_weights_desc);
121-
OPENVINO_ASSERT(shape_consistent, "[GPU] Input shape and output shape of weight reorder should be same.");
122-
123-
auto source_weights_desc = onednn::layout_to_memory_desc(source_weights_layout);
124-
125-
const bool weights_format = true;
126-
const bool grouped = false;
127-
128-
auto traits = convert_memory_desc_to_traits(target_weights_desc, weights_format, grouped);
129-
130-
target_weights_layout.format = format(traits);
131-
132-
return std::make_shared<WeightsReorderParamsOneDNN>(source_weights_layout,
133-
target_weights_layout,
134-
source_weights_desc,
135-
target_weights_desc,
136-
false);
137-
}
138-
139101
static void transform_layouts(layout& input_layout, layout& weights_layout, layout& output_layout, size_t prim_input_size) {
140102
auto input_pshape = input_layout.get_partial_shape();
141103
auto weights_pshape = weights_layout.get_partial_shape();
@@ -164,43 +126,6 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
164126
}
165127
}
166128

167-
static std::shared_ptr<dnnl::inner_product_forward::primitive_desc>
168-
get_inner_product_primitive_descriptor(const kernel_impl_params& impl_params,
169-
cldnn::engine& engine,
170-
size_t prim_input_size,
171-
bool has_bias,
172-
const dnnl::primitive_attr& attr = dnnl::primitive_attr()) {
173-
auto input_layout = impl_params.get_input_layout(0);
174-
auto weights_layout = impl_params.get_input_layout(1);
175-
auto output_layout = impl_params.get_output_layout();
176-
177-
transform_layouts(input_layout, weights_layout, output_layout, prim_input_size);
178-
179-
auto input_md = onednn::layout_to_memory_desc(input_layout, dnnl::memory::format_tag::undef, false);
180-
auto weights_md = onednn::layout_to_memory_desc(weights_layout, dnnl::memory::format_tag::any);
181-
auto output_md = onednn::layout_to_memory_desc(output_layout, dnnl::memory::format_tag::ab, false);
182-
183-
if (has_bias) {
184-
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), dnnl::memory::format_tag::any, true);
185-
return std::make_shared<dnnl::inner_product_forward::primitive_desc>(
186-
engine.get_onednn_engine(),
187-
dnnl::prop_kind::forward_inference,
188-
input_md,
189-
weights_md,
190-
bias_md,
191-
output_md,
192-
attr);
193-
} else {
194-
return std::make_shared<dnnl::inner_product_forward::primitive_desc>(
195-
engine.get_onednn_engine(),
196-
dnnl::prop_kind::forward_inference,
197-
input_md,
198-
weights_md,
199-
output_md,
200-
attr);
201-
}
202-
}
203-
204129
static std::shared_ptr<dnnl::matmul::primitive_desc>
205130
get_matmul_primitive_descriptor(const kernel_impl_params& impl_params,
206131
cldnn::engine& engine,
@@ -219,7 +144,11 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
219144
auto output_md = onednn::layout_to_memory_desc(output_layout, dnnl::memory::format_tag::ab, false);
220145

221146
if (has_bias) {
222-
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), dnnl::memory::format_tag::ab, false);
147+
dnnl::memory::format_tag target_fmt = dnnl::memory::format_tag::ab;
148+
auto bias_l = impl_params.get_input_layout(2);
149+
if (bias_l.get_shape().size() == 1)
150+
target_fmt = dnnl::memory::format_tag::ba;
151+
auto bias_md = onednn::layout_to_memory_desc(impl_params.get_input_layout(2), target_fmt, false);
223152
return std::make_shared<dnnl::matmul::primitive_desc>(
224153
engine.get_onednn_engine(),
225154
input_md,
@@ -335,13 +264,8 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
335264
_attrs->set_zero_points(DNNL_ARG_SRC, GROUPED, dnnl::memory::dims{1, src_group_size}, dnnl::memory::data_type::u8);
336265
}
337266

338-
if (is_compressed) {
339-
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
340-
_pd = *prim_desc;
341-
} else {
342-
auto prim_desc = get_inner_product_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
343-
_pd = *prim_desc;
344-
}
267+
auto prim_desc = get_matmul_primitive_descriptor(*impl_params, ib.get_engine(), input_size, has_bias, *_attrs);
268+
_pd = *prim_desc;
345269

346270
std::vector<uint8_t> prim_cache;
347271
ib >> prim_cache;
@@ -426,10 +350,10 @@ struct fully_connected_onednn : typed_primitive_onednn_impl<fully_connected> {
426350
prim_onednn->_dzp_data_type = dzp_data_type;
427351
return prim_onednn;
428352
} else {
429-
auto prim_desc = get_inner_product_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
430-
prim->input_size, !prim->bias.empty(), *attr);
353+
auto prim_desc = get_matmul_primitive_descriptor(impl_params, impl_params.prog->get_engine(),
354+
prim->input_size, !prim->bias.empty(), *attr);
431355

432-
return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc, get_weights_reorder(impl_params, *prim_desc));
356+
return cldnn::make_unique<fully_connected_onednn>(engine, config, attr, *prim_desc);
433357
}
434358
}
435359
};

src/plugins/intel_gpu/src/graph/primitive_inst.cpp

+11-1
Original file line numberDiff line numberDiff line change
@@ -2671,7 +2671,6 @@ bool primitive_inst::is_valid_fusion() const {
26712671
auto gemm_dims = onednn::convert_gemm_tensor(gemm_layout.get_tensor(),
26722672
cldnn::format::dimension(gemm_layout.format),
26732673
false);
2674-
26752674
auto data_dims = onednn::convert_gemm_tensor(data_layout.get_tensor(),
26762675
cldnn::format::dimension(data_layout.format),
26772676
false);
@@ -2685,8 +2684,19 @@ bool primitive_inst::is_valid_fusion() const {
26852684
const auto fc_dims = fc_layout.get_dims();
26862685
const auto data_dims = data_layout.get_dims();
26872686

2687+
auto same_spatial = [](layout a, layout b) {
2688+
if (a.get_spatial_rank() != b.get_spatial_rank())
2689+
return false;
2690+
for (size_t i = 0; i < a.get_spatial_rank(); i++) {
2691+
if (a.spatial(i) != b.spatial(i))
2692+
return false;
2693+
}
2694+
return true;
2695+
};
2696+
26882697
if (!(fc_dims[0] == 1 || fc_dims[1] == 1) &&
26892698
!(data_dims[0] == 1 && data_dims[1] == 1) &&
2699+
!((data_dims[0] == 1 || data_dims[1] == 1) && same_spatial(fc_layout, data_layout)) &&
26902700
!(fc_layout.count() == data_layout.count())) {
26912701
return false;
26922702
}

src/plugins/intel_gpu/src/graph/program_node.cpp

+52-20
Original file line numberDiff line numberDiff line change
@@ -1558,30 +1558,62 @@ void program_node::create_onednn_primitive_attributes(
15581558
} else if (desc.is_type<eltwise>()) {
15591559
auto dep_idx = desc.outer_dep_start_idx;
15601560
auto in = get_input_layout(dep_idx);
1561-
auto in_origin = in;
1562-
auto set_binary_op = [&](dnnl::algorithm alg, onednn_post_op_type op_type) {
1563-
if (is_type<fully_connected>()) {
1564-
auto prim = this->as<fully_connected>().get_primitive();
1565-
if (prim->input_size == 3) {
1566-
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
1561+
auto fc_needs_full_tensor = [&]() {
1562+
for (size_t i = 0; i < cldnn_post_ops.size(); i++) {
1563+
auto& desc = cldnn_post_ops[i];
1564+
if (desc.is_type<eltwise>()) {
1565+
auto prim = this->as<fully_connected>().get_primitive();
1566+
auto dep_idx = desc.outer_dep_start_idx;
1567+
auto in = get_input_layout(dep_idx);
1568+
if (prim->input_size == 3 && in.batch() > 1 && in.feature() > 1)
1569+
return true;
15671570
}
1568-
auto mem_desc = onednn::layout_to_memory_desc(in, dnnl::memory::format_tag::ab);
1569-
post_ops.append_binary(alg, mem_desc);
1570-
update_onednn_post_op_list(op_type, dep_idx, dnnl::memory::format_tag::ab, false,
1571-
mem_desc.get_dims(), mem_desc.get_data_type());
1572-
} else if (is_type<gemm>()) {
1571+
}
1572+
return false;
1573+
};
1574+
auto set_binary_op = [&](dnnl::algorithm alg, onednn_post_op_type op_type) {
1575+
if (is_type<fully_connected>() || is_type<gemm>()) {
15731576
size_t rank = cldnn::format::dimension(in.format);
15741577
auto in_pshape = in.get_partial_shape();
15751578
auto out_pshape = get_output_layout().get_partial_shape();
1576-
size_t ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
1577-
if (ones_to_add > 0) {
1578-
layout new_layout = in;
1579-
ov::PartialShape new_input_pshape;
1580-
std::vector<ov::Dimension> dims(in_pshape.begin(), in_pshape.begin() + in_pshape.size());
1581-
new_input_pshape = ov::PartialShape(dims);
1582-
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
1583-
new_layout.set_partial_shape(new_input_pshape);
1584-
in = new_layout;
1579+
size_t ones_to_add = 0;
1580+
1581+
if (is_type<fully_connected>()) {
1582+
auto prim = this->as<fully_connected>().get_primitive();
1583+
if (prim->input_size == in_pshape.size()) {
1584+
if (prim->input_size == 3 && !fc_needs_full_tensor()) {
1585+
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
1586+
in_pshape = in.get_partial_shape();
1587+
}
1588+
ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
1589+
} else {
1590+
if (prim->input_size == 3)
1591+
cldnn::onednn::combine_bf_with_first_spatial_dim(in);
1592+
ones_to_add = std::max(in_pshape.size(), prim->input_size) - std::min(in_pshape.size(), prim->input_size);
1593+
}
1594+
if (ones_to_add > 0) {
1595+
layout new_layout = in;
1596+
ov::PartialShape new_input_pshape;
1597+
auto last = in_pshape.begin() + in_pshape.size();
1598+
if (in_pshape.size() > prim->input_size)
1599+
last -= ones_to_add;
1600+
std::vector<ov::Dimension> dims(in_pshape.begin(), last);
1601+
new_input_pshape = ov::PartialShape(dims);
1602+
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
1603+
new_layout.set_partial_shape(new_input_pshape);
1604+
in = new_layout;
1605+
}
1606+
} else {
1607+
ones_to_add = std::max(out_pshape.size(), static_cast<size_t>(rank)) - in_pshape.size();
1608+
if (ones_to_add > 0) {
1609+
layout new_layout = in;
1610+
ov::PartialShape new_input_pshape;
1611+
std::vector<ov::Dimension> dims(in_pshape.begin(), in_pshape.begin() + in_pshape.size());
1612+
new_input_pshape = ov::PartialShape(dims);
1613+
new_input_pshape.insert(new_input_pshape.begin(), ones_to_add, 1ul);
1614+
new_layout.set_partial_shape(new_input_pshape);
1615+
in = new_layout;
1616+
}
15851617
}
15861618
size_t in_batched_size = in.count() / (in.spatial(0) * in.spatial(1));
15871619
dnnl::memory::dims dims = onednn::convert_gemm_tensor(in.get_tensor(), rank, in_batched_size == 1);

src/plugins/intel_gpu/tests/unit/fusions/fully_connected_fusion_test.cpp

+6-6
Original file line numberDiff line numberDiff line change
@@ -162,13 +162,13 @@ class FullyConnectedFusingTestOneDNN : public BaseFusingTest<fully_connected_tes
162162
#define CASE_FC_FP32_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
163163
#define CASE_FC_FP32_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::f32, format::yxfb, data_types::f32, format::oiyx, data_types::f32, format::bfyx
164164
#define CASE_FC_FP32_3 { 2, 32 }, { 2, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
165-
#define CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
166-
#define CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
167-
#define CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
165+
#define CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
166+
#define CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
167+
#define CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
168168

169-
#define DYN_CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
170-
#define DYN_CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
171-
#define DYN_CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::f32, format::os_iyx_osv16, data_types::f32, format::bfyx
169+
#define DYN_CASE_FC_FP32_3D_1 { 5, 3, 3 }, { 5, 3, 5 }, { 5, 3 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
170+
#define DYN_CASE_FC_FP32_3D_2 { 2, 1, 1 }, { 2, 1, 32 }, { 32, 1 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
171+
#define DYN_CASE_FC_FP32_3D_3 { 2, 32, 32 }, { 2, 32, 16 }, { 16, 32 }, data_types::f32, format::bfyx, data_types::f32, format::oiyx, data_types::f32, format::bfyx
172172

173173
#define CASE_FC_U8S8_1 { 1, 3 }, { 1, 4 }, { 4, 3 }, data_types::u8, format::bfyx, data_types::i8, format::oiyx, data_types::f32, format::bfyx
174174
#define CASE_FC_U8S8_2 { 2, 3 }, { 2, 4 }, { 4, 3 }, data_types::u8, format::b_fs_yx_fsv4, data_types::i8, format::oiyx, data_types::f32, format::bfyx

0 commit comments

Comments
 (0)