Skip to content

Commit 0bbc7a7

Browse files
densamoilovazhai219
authored andcommitted
cpu: rnn: use correct format for wei_proj
1 parent 6a7f6db commit 0bbc7a7

8 files changed

+16
-4
lines changed

include/oneapi/dnnl/dnnl.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -1495,6 +1495,7 @@ struct memory : public handle<dnnl_memory_t> {
14951495
AB8a2b = dnnl_AB8a2b,
14961496
abDc16d = dnnl_abDc16d,
14971497
abDc32d = dnnl_abDc32d,
1498+
abDC16d4c = dnnl_abDC16d4c,
14981499
abDC32d4c = dnnl_abDC32d4c,
14991500
abCd32c = dnnl_abCd32c,
15001501
abdEc16e = dnnl_abdEc16e,
@@ -1995,6 +1996,7 @@ struct memory : public handle<dnnl_memory_t> {
19951996

19961997
ldOi16o = abDc16d,
19971998
ldOi32o = abDc32d,
1999+
ldOI16o4i = abDC16d4c,
19982000
ldOI32o4i = abDC32d4c,
19992001
ldgOi16o = abdEc16e,
20002002
ldgOI16o4i = abdEC16e4c,

include/oneapi/dnnl/dnnl_types.h

+2
Original file line numberDiff line numberDiff line change
@@ -1044,6 +1044,7 @@ typedef enum {
10441044
dnnl_dabc,
10451045
dnnl_Ab32a,
10461046
dnnl_abdEC16e4c,
1047+
dnnl_abDC16d4c,
10471048

10481049
/// Just a sentinel, not real memory format tag. Must be changed after new
10491050
/// format tag is added.
@@ -1179,6 +1180,7 @@ typedef enum {
11791180
/// 5D LSTM projection tensor
11801181
dnnl_ldOi16o = dnnl_abDc16d,
11811182
dnnl_ldOi32o = dnnl_abDc32d,
1183+
dnnl_ldOI16o4i = dnnl_abDC16d4c,
11821184
dnnl_ldOI32o4i = dnnl_abDC32d4c,
11831185
dnnl_ldIo32i = dnnl_abCd32c,
11841186
/// 6D RNN weights tensor

src/common/c_types_map.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -710,6 +710,7 @@ const format_tag_t AB32a32b8a2b = dnnl_AB32a32b8a2b;
710710
const format_tag_t AB8a2b = dnnl_AB8a2b;
711711
const format_tag_t abDc16d = dnnl_abDc16d;
712712
const format_tag_t abDc32d = dnnl_abDc32d;
713+
const format_tag_t abDC16d4c = dnnl_abDC16d4c;
713714
const format_tag_t abDC32d4c = dnnl_abDC32d4c;
714715
const format_tag_t abCd4c = dnnl_abCd4c;
715716
const format_tag_t abCde4c = dnnl_abCde4c;
@@ -1480,6 +1481,7 @@ const format_tag_t gOIhw4o8i2o = dnnl_gOIhw4o8i2o;
14801481
const format_tag_t gOIdhw4o8i2o = dnnl_gOIdhw4o8i2o;
14811482
const format_tag_t ldOi16o = dnnl_ldOi16o;
14821483
const format_tag_t ldOi32o = dnnl_ldOi32o;
1484+
const format_tag_t ldOI16o4i = dnnl_ldOI16o4i;
14831485
const format_tag_t ldOI32o4i = dnnl_ldOI32o4i;
14841486
const format_tag_t ldIo32i = dnnl_ldIo32i;
14851487
const format_tag_t ldgOi16o = dnnl_ldgOi16o;

src/common/dnnl_debug_autogenerated.cpp

+2
Original file line numberDiff line numberDiff line change
@@ -381,6 +381,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
381381
if (v == dnnl_AB32a32b8a2b) return "AB32a32b8a2b";
382382
if (v == dnnl_AB8a2b) return "AB8a2b";
383383
if (v == dnnl_abDc32d) return "abDc32d";
384+
if (v == dnnl_abDC16d4c) return "abDC16d4c";
384385
if (v == dnnl_abDC32d4c) return "abDC32d4c";
385386
if (v == dnnl_abdEc32e) return "abdEc32e";
386387
if (v == dnnl_abdEC16e4c) return "abdEC16e4c";
@@ -1006,6 +1007,7 @@ const char *dnnl_fmt_tag2str(dnnl_format_tag_t v) {
10061007
if (v == dnnl_ldgo) return "ldgo";
10071008
if (v == dnnl_ldOi16o) return "ldOi16o";
10081009
if (v == dnnl_ldOi32o) return "ldOi32o";
1010+
if (v == dnnl_ldOI16o4i) return "ldOI16o4i";
10091011
if (v == dnnl_ldOI32o4i) return "ldOI32o4i";
10101012
if (v == dnnl_ldIo32i) return "ldIo32i";
10111013
if (v == dnnl_ldgOi16o) return "ldgOi16o";

src/common/memory_desc_wrapper.cpp

+1
Original file line numberDiff line numberDiff line change
@@ -635,6 +635,7 @@ status_t process_tag(F f, format_tag_t tag, Args&&... args) {
635635
C(AB8a2b, {0, 1}, {8, 2}, {0, 1});
636636
C(abDc16d, {0, 1, 3, 2}, {16}, {3});
637637
C(abDc32d, {0, 1, 3, 2}, {32}, {3});
638+
C(abDC16d4c, {0, 1, 3, 2}, {16, 4}, {3, 2});
638639
C(abDC32d4c, {0, 1, 3, 2}, {32, 4}, {3, 2});
639640
C(abCd4c, {0, 1, 2, 3}, {4}, {2});
640641
C(abCde4c, {0, 1, 2, 3, 4}, {4}, {2});

src/common/tag_traits.hpp

+2
Original file line numberDiff line numberDiff line change
@@ -122,6 +122,7 @@ enum class inner_blk_t {
122122
_24b4c,
123123
_24c2b,
124124
_24c4b,
125+
_16d4c,
125126
_32d4c,
126127
_32e2c,
127128
_32e4c,
@@ -822,6 +823,7 @@ DECL_TRAITS(aBCde4c8b2c, _BC, _4c8b2c, 5);
822823
DECL_TRAITS(aBCdef4c8b2c, _BC, _4c8b2c, 6);
823824
DECL_TRAITS(abDc16d, _D, _16d, 4);
824825
DECL_TRAITS(abDc32d, _D, _32d, 4);
826+
DECL_TRAITS(abDC16d4c, _CD, _16d4c, 4);
825827
DECL_TRAITS(abDC32d4c, _CD, _32d4c, 4);
826828
DECL_TRAITS(abCd32c, _C, _32c, 4);
827829
DECL_TRAITS(abCde32c, _C, _32c, 5);

src/cpu/rnn/rnn_reorders.hpp

+1-1
Original file line numberDiff line numberDiff line change
@@ -803,7 +803,7 @@ struct rnn_brgemm_weights_reorder_s8_t : public primitive_t {
803803

804804
itag = id.matches_one_of_tag(ldigo, ldio);
805805
otag = od.matches_one_of_tag(
806-
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i);
806+
ldgOI64o4i, ldgOI32o4i, ldgOI16o4i, ldOI32o4i, ldOI16o4i);
807807
if (itag != format_tag::undef && otag != format_tag::undef) {
808808
_pd->itag_ = itag;
809809
_pd->otag_ = otag;

src/cpu/rnn/rnn_utils.cpp

+4-3
Original file line numberDiff line numberDiff line change
@@ -88,8 +88,8 @@ bool rnn_utils::is_ldgoi_blocked(const memory_desc_wrapper &mdw) {
8888
}
8989

9090
bool rnn_utils::is_ldio_blocked(const memory_desc_wrapper &mdw) {
91-
format_tag_t md_format_tag = mdw.matches_one_of_tag(
92-
format_tag::ldOi32o, format_tag::ldOI32o4i, format_tag::ldOi16o);
91+
format_tag_t md_format_tag = mdw.matches_one_of_tag(format_tag::ldOi32o,
92+
format_tag::ldOI32o4i, ldOI16o4i, format_tag::ldOi16o);
9393
return md_format_tag != format_tag::undef;
9494
}
9595

@@ -286,7 +286,8 @@ status_t rnn_utils::set_expected_desc(rnn_conf_t &rnn,
286286

287287
if (weights_type == weights_type_t::projection) {
288288
if (rnn.is_int8_conf())
289-
tag = format_tag::ldOI32o4i;
289+
tag = utils::map(n_block, format_tag::undef, 32,
290+
format_tag::ldOI32o4i, 16, format_tag::ldOI16o4i);
290291
else
291292
tag = utils::map(n_block, format_tag::undef, 32,
292293
format_tag::ldOi32o, 16, format_tag::ldOi16o);

0 commit comments

Comments
 (0)