@@ -74,52 +74,95 @@ status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) {
    const memory_desc_wrapper wei_d(weights_md_);
    const memory_desc_wrapper bia_d(bias_md_);
    const memory_desc_wrapper dst_d(dst_md_);
+
+    cpu::matmul::matmul_helper_t helper(src_d, wei_d, dst_d);
+    const dim_t M = helper.M();
+    const dim_t N = helper.N();
+    const dim_t K = helper.K();
+    const dim_t dst_batch = helper.batch();
+    const dim_t src_batch = helper.src_batch();
+    const dim_t wei_batch = helper.wei_batch();
+
    using namespace data_type;
    VDISPATCH_MATMUL(utils::one_of(src_d.data_type(), s8, u8)
-                    && wei_d.data_type() == s8
-                    && src_d.data_type() == s8
-                    ? dst_d.data_type() == s8
-                    : dst_d.data_type() == u8,
+                    && wei_d.data_type() == s8
+                    && (src_d.data_type() == s8 ? dst_d.data_type() == s8
+                                                : dst_d.data_type() == u8),
            VERBOSE_UNSUPPORTED_DT_CFG);
    VDISPATCH_MATMUL(utils::one_of(bia_d.data_type(), f32, undef),
            VERBOSE_UNSUPPORTED_DT_CFG);
-    // reject in case the op is running in a Neoverse-N1.
+
+    // Reject if the op is running on a CPU that does not have the i8mm
+    // instruction set. This is a temporary fix until the issue is resolved.
    VDISPATCH_MATMUL(arm_compute::CPUInfo::get().has_i8mm(),
-            "Neoverse-N1 not supported");
-    VDISPATCH_MATMUL(src_d.matches_tag(format_tag::ab)
-                    && wei_d.matches_tag(format_tag::ab)
-                    && dst_d.matches_tag(format_tag::ab),
-            VERBOSE_UNSUPPORTED_TAG);
-    VDISPATCH_MATMUL_SC(
-            memory_desc_init_by_tag(bias_md_, bias_md_.ndims, bias_md_.dims,
-                    bias_md_.data_type, format_tag::ab),
+            "Op not supported on CPUs without i8mm instructions");
+
+    // ACL only supports a weights batch size of 1 for 3D and 4D matmul.
+    VDISPATCH_MATMUL(
+            wei_batch == 1, "Batch dimension must be 1 for the weights");
+
+    using namespace format_tag;
+    auto src_tag = memory_desc_matches_one_of_tag(src_md_, abcd, abc, ab);
+    auto wei_tag = memory_desc_matches_one_of_tag(weights_md_, abcd, abc, ab);
+    auto dst_tag = memory_desc_matches_one_of_tag(dst_md_, abcd, abc, ab);
+
+    ACL_CHECK_SUPPORT(
+            utils::one_of(format_tag::undef, src_tag, wei_tag, dst_tag),
+            "Format tag is undefined");
+
+    VDISPATCH_MATMUL_SC(memory_desc_init_by_tag(bias_md_, bias_md_.ndims,
+                                bias_md_.dims, bias_md_.data_type, dst_tag),
            VERBOSE_UNSUPPORTED_BIAS_CFG);
-    // We set the QuantizationInfo to be dynamic because it is re-set in run()
-    almc_.src_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(K(), M()), 1,
-                    acl_utils::get_acl_data_t(src_d.data_type(), true),
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+
+    almc_.bia_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(), 1, arm_compute::DataType::S32);
+    almc_.with_bias = bia_d.format_kind() != format_kind::undef;
+
+    almc_.src_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(K, M, 1, src_batch), 1,
+            acl_utils::get_acl_data_t(src_d.data_type(), true),
+            arm_compute::QuantizationInfo(1.0, 0, true));
    almc_.src_tensor_info.set_are_values_constant(false);
-    almc_.wei_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), K()), 1,
-                    acl_utils::get_acl_data_t(wei_d.data_type(), true),
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+
+    almc_.wei_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(N, K, 1, wei_batch), 1,
+            acl_utils::get_acl_data_t(wei_d.data_type(), true),
+            arm_compute::QuantizationInfo(1.0, 0, true));
    almc_.wei_tensor_info.set_are_values_constant(false);
-    almc_.dst_tensor_info
-            = arm_compute::TensorInfo(arm_compute::TensorShape(N(), M()), 1,
-                    acl_utils::get_acl_data_t(dst_d.data_type(), true),
-                    arm_compute::QuantizationInfo(1.0, 0, true));
+    almc_.dst_tensor_info = arm_compute::TensorInfo(
+            arm_compute::TensorShape(N, M, 1, dst_batch), 1,
+            acl_utils::get_acl_data_t(dst_d.data_type(), true),
+            arm_compute::QuantizationInfo(1.0, 0, true));
+
    almc_.bia_tensor_info = arm_compute::TensorInfo(
            arm_compute::TensorShape(), 1, arm_compute::DataType::S32);
    almc_.with_bias = bia_d.format_kind() != format_kind::undef;
+
    if (almc_.with_bias) {
-        // This is not currently guarded in ACL
-        VDISPATCH_MATMUL(bia_d.ndims() == 2 && bia_d.dims()[0] == 1
-                        && bia_d.dims()[1] == N(),
-                "Only 1xN bias is supported");
-        almc_.bia_tensor_info.set_tensor_shape(
-                arm_compute::TensorShape(bia_d.dims()[1], bia_d.dims()[0]));
+        switch (bia_d.ndims()) {
+            case 2:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == N,
+                        "Only 1xN bias is supported for 2D input");
+                almc_.bia_tensor_info.set_tensor_shape(arm_compute::TensorShape(
+                        bia_d.dims()[1], bia_d.dims()[0]));
+                break;
+            case 3:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1
+                                && bia_d.dims()[2] == N,
+                        "Only 1x1xN bias is supported for 3D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[2], 1, 1));
+                break;
+            case 4:
+                VDISPATCH_MATMUL(bia_d.dims()[0] == 1 && bia_d.dims()[1] == 1
+                                && bia_d.dims()[2] == 1 && bia_d.dims()[3] == N,
+                        "Only 1x1x1xN bias is supported for 4D input");
+                almc_.bia_tensor_info.set_tensor_shape(
+                        arm_compute::TensorShape(bia_d.dims()[3], 1, 1, 1));
+                break;
+        }
    }
+
    arm_compute::GEMMLowpOutputStageInfo info;
    info.type = arm_compute::GEMMLowpOutputStageType::QUANTIZE_DOWN_FIXEDPOINT;
    info.gemmlowp_multiplier = 1073741824;
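
Note for reviewers: `info.gemmlowp_multiplier = 1073741824` is 2^30, the Q0.31 fixed-point encoding of the real scale 0.5 (gemmlowp multipliers lie in [2^30, 2^31), covering real scales [0.5, 1.0)). It acts as a placeholder, since the actual quantization scales are injected at run time; the accompanying shift is outside this hunk, so the effective scale is an assumption here. A minimal standalone sketch of the gemmlowp-style doubling high multiply that consumes such a multiplier (`srdhm` is a local illustration, not part of the patch):

```cpp
#include <cassert>
#include <cstdint>
#include <limits>

// Sketch of gemmlowp's SaturatingRoundingDoublingHighMul: applies a Q0.31
// fixed-point multiplier b to a, approximating a * (b / 2^31) * 2.
int32_t srdhm(int32_t a, int32_t b) {
    const bool overflow = a == b && a == std::numeric_limits<int32_t>::min();
    const int64_t ab = static_cast<int64_t>(a) * static_cast<int64_t>(b);
    const int32_t nudge = ab >= 0 ? (1 << 30) : (1 - (1 << 30));
    const int32_t high = static_cast<int32_t>((ab + nudge) / (1ll << 31));
    return overflow ? std::numeric_limits<int32_t>::max() : high;
}

int main() {
    const int32_t multiplier = 1073741824; // 2^30 <-> real scale 0.5
    assert(srdhm(42, multiplier) == 21);   // 42 * 0.5
    assert(srdhm(7, multiplier) == 4);     // 3.5, rounded away from zero
    return 0;
}
```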
@@ -132,15 +175,18 @@ status_t acl_lowp_matmul_sq_t::pd_t::init(engine_t *engine) {
    auto scratchpad = scratchpad_registry().registrar();
    const dnnl::impl::memory_desc_t dst_md_ {desc_.dst_desc};
    arm_compute::ActivationLayerInfo act_info;
+
    CHECK(init_scratchpad(engine, scratchpad, acl_post_ops, attr_.post_ops_,
            act_info, dst_md_));
    almc_.gemm_info.set_activation_info(act_info);
+
    ACL_CHECK_VALID(arm_compute::NEGEMMLowpMatrixMultiplyCore::validate(
            &almc_.src_tensor_info, &almc_.wei_tensor_info,
            almc_.with_bias ? &almc_.bia_tensor_info : nullptr,
            &almc_.dst_tensor_info, almc_.gemm_info));
    return status::success;
}
+
status_t acl_lowp_matmul_sq_t::pd_t::init_scratchpad(engine_t *engine,
        memory_tracking::registrar_t &scratchpad, acl_post_ops_t &post_ops,
        dnnl::impl::post_ops_t &attr_post_ops,
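
Note on the tensor shapes above: ACL's `TensorShape` lists dimensions innermost-first, the reverse of oneDNN's logical dims, which is why a src of shape (batch, M, K) is constructed as `TensorShape(K, M, 1, src_batch)`. Pinning the third dimension to 1 and carrying the batch in the fourth lets the weights, whose batch must be 1 per the new `wei_batch == 1` check, broadcast across the batch. A minimal sketch of that mapping under these assumptions (`to_acl_shape` is a hypothetical helper, not part of the patch):

```cpp
#include <array>
#include <cstdint>
#include <iostream>

using dim_t = std::int64_t;

// Map a row-major (batch, rows, cols) matrix onto the 4D, innermost-first
// dimension order used by the TensorInfo constructions in this patch.
std::array<dim_t, 4> to_acl_shape(dim_t rows, dim_t cols, dim_t batch) {
    return {cols, rows, 1, batch};
}

int main() {
    const dim_t M = 8, N = 16, K = 32, batch = 4;
    const auto src = to_acl_shape(M, K, batch); // (K, M, 1, batch)
    const auto wei = to_acl_shape(K, N, 1);     // (N, K, 1, 1): broadcast batch
    const auto dst = to_acl_shape(M, N, batch); // (N, M, 1, batch)
    std::cout << src[3] << "x" << src[1] << "x" << src[0] << " * " << wei[1]
              << "x" << wei[0] << " -> " << dst[3] << "x" << dst[1] << "x"
              << dst[0] << "\n"; // 4x8x32 * 32x16 -> 4x8x16
    return 0;
}
```

The same convention explains the bias handling: a 2D (1, N), 3D (1, 1, N), or 4D (1, 1, 1, N) oneDNN bias all collapse to an ACL shape whose leading (innermost) dimension is N.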