Skip to content

Commit 53555b3

Browse files
authored
Inline requantize kernels
Differential Revision: D80772740
Pull Request resolved: #13592
1 parent b7886d0 commit 53555b3

File tree

2 files changed

+18
-106
lines changed

2 files changed

+18
-106
lines changed

backends/cadence/reference/kernels/kernels.cpp

Lines changed: 0 additions & 84 deletions
Original file line number | Diff line number | Diff line change
@@ -11,8 +11,6 @@
1111
#include <algorithm>
1212
#include <cstring>
1313
#include <limits>
14-
#include <numeric>
15-
1614
namespace impl {
1715
namespace reference {
1816
namespace kernels {
@@ -58,36 +56,6 @@ void dequantize(
5856
}
5957
}
6058

61-
// Requantize the int8_t/uint8_t in value to a uint8_t/int8_t out value.
62-
// The scale and zero_point for requantization are in the args.
63-
template <typename IT, typename OT>
64-
OT requantize(
65-
const IT in,
66-
float in_scale,
67-
int32_t in_zero_point,
68-
float inv_out_scale,
69-
int32_t out_zero_point) {
70-
float dequant = dequantize<IT>(in, in_scale, in_zero_point);
71-
return quantize<OT>(dequant, inv_out_scale, out_zero_point);
72-
}
73-
74-
// Requantize the int8_t/uint8_t in array to a uint8_t/int8_t out array.
75-
// The scale and zero_point for requantization are in the args.
76-
template <typename IT, typename OT>
77-
void requantize(
78-
OT* __restrict__ out,
79-
const IT* __restrict__ in,
80-
float in_scale,
81-
int32_t in_zero_point,
82-
float inv_out_scale,
83-
int32_t out_zero_point,
84-
size_t size) {
85-
for (size_t i = 0; i < size; ++i) {
86-
out[i] = requantize<IT, OT>(
87-
in[i], in_scale, in_zero_point, inv_out_scale, out_zero_point);
88-
}
89-
}
90-
9159
// explicit template instantiation
9260

9361
#define typed_quantize_val(dtype) \
@@ -136,58 +104,6 @@ typed_dequantize_vec(uint16_t);
136104
typed_dequantize_vec(int32_t);
137105
#undef typed_dequantize_vec
138106

139-
#define typed_requantize_val(itype, otype) \
140-
template otype requantize( \
141-
const itype in, \
142-
float in_scale, \
143-
int32_t in_zero_point, \
144-
float inv_out_scale, \
145-
int32_t out_zero_point);
146-
typed_requantize_val(int8_t, int8_t);
147-
typed_requantize_val(int8_t, uint8_t);
148-
typed_requantize_val(int8_t, int16_t);
149-
typed_requantize_val(int8_t, uint16_t);
150-
typed_requantize_val(uint8_t, int8_t);
151-
typed_requantize_val(uint8_t, uint8_t);
152-
typed_requantize_val(uint8_t, int16_t);
153-
typed_requantize_val(uint8_t, uint16_t);
154-
typed_requantize_val(int16_t, int8_t);
155-
typed_requantize_val(int16_t, uint8_t);
156-
typed_requantize_val(int16_t, int16_t);
157-
typed_requantize_val(int16_t, uint16_t);
158-
typed_requantize_val(uint16_t, int8_t);
159-
typed_requantize_val(uint16_t, uint8_t);
160-
typed_requantize_val(uint16_t, int16_t);
161-
typed_requantize_val(uint16_t, uint16_t);
162-
#undef typed_requantize_val
163-
164-
#define typed_requantize_vec(itype, otype) \
165-
template void requantize( \
166-
otype* __restrict__ out, \
167-
const itype* __restrict__ in, \
168-
float in_scale, \
169-
int32_t in_zero_point, \
170-
float inv_out_scale, \
171-
int32_t out_zero_point, \
172-
size_t size);
173-
typed_requantize_vec(int8_t, int8_t);
174-
typed_requantize_vec(int8_t, uint8_t);
175-
typed_requantize_vec(int8_t, int16_t);
176-
typed_requantize_vec(int8_t, uint16_t);
177-
typed_requantize_vec(uint8_t, int8_t);
178-
typed_requantize_vec(uint8_t, uint8_t);
179-
typed_requantize_vec(uint8_t, int16_t);
180-
typed_requantize_vec(uint8_t, uint16_t);
181-
typed_requantize_vec(int16_t, int8_t);
182-
typed_requantize_vec(int16_t, uint8_t);
183-
typed_requantize_vec(int16_t, int16_t);
184-
typed_requantize_vec(int16_t, uint16_t);
185-
typed_requantize_vec(uint16_t, int8_t);
186-
typed_requantize_vec(uint16_t, uint8_t);
187-
typed_requantize_vec(uint16_t, int16_t);
188-
typed_requantize_vec(uint16_t, uint16_t);
189-
#undef typed_requantize_vec
190-
191107
}; // namespace kernels
192108
}; // namespace reference
193109
}; // namespace impl

backends/cadence/reference/operators/requantize_out.cpp renamed to backends/cadence/reference/operators/op_requantize_out.cpp

Lines changed: 18 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -86,17 +86,15 @@ Tensor& requantize_out(
8686
torch::executor::toString(out.scalar_type()),
8787
torch::executor::toString(out_dtype));
8888

89-
#define typed_requantize(ctype, dtype) \
90-
const ctype* input_data = input.const_data_ptr<ctype>(); \
91-
dtype* out_data = out.mutable_data_ptr<dtype>(); \
92-
kernels::requantize<ctype, dtype>( \
93-
out_data, \
94-
input_data, \
95-
in_scale, \
96-
in_zero_point, \
97-
1.0 / out_scale, \
98-
out_zero_point, \
99-
numel);
89+
#define typed_requantize(ctype, dtype) \
90+
const ctype* input_data = input.const_data_ptr<ctype>(); \
91+
dtype* out_data = out.mutable_data_ptr<dtype>(); \
92+
for (size_t i = 0; i < numel; ++i) { \
93+
float dequant = \
94+
kernels::dequantize<ctype>(input_data[i], in_scale, in_zero_point); \
95+
out_data[i] = \
96+
kernels::quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \
97+
};
10098

10199
#define typed_requantize_in(ctype) \
102100
switch (out_dtype) { \
@@ -190,17 +188,15 @@ Tensor& requantize_per_tensor_out(
190188
torch::executor::toString(out.scalar_type()),
191189
torch::executor::toString(out_dtype));
192190

193-
#define typed_requantize(ctype, dtype) \
194-
const ctype* input_data = input.const_data_ptr<ctype>(); \
195-
dtype* out_data = out.mutable_data_ptr<dtype>(); \
196-
kernels::requantize<ctype, dtype>( \
197-
out_data, \
198-
input_data, \
199-
static_cast<float>(in_scale), \
200-
static_cast<int32_t>(in_zero_point), \
201-
1.0 / static_cast<float>(out_scale), \
202-
static_cast<int32_t>(out_zero_point), \
203-
numel);
191+
#define typed_requantize(ctype, dtype) \
192+
const ctype* input_data = input.const_data_ptr<ctype>(); \
193+
dtype* out_data = out.mutable_data_ptr<dtype>(); \
194+
for (size_t i = 0; i < numel; ++i) { \
195+
float dequant = \
196+
kernels::dequantize<ctype>(input_data[i], in_scale, in_zero_point); \
197+
out_data[i] = \
198+
kernels::quantize<dtype>(dequant, 1 / out_scale, out_zero_point); \
199+
};
204200

205201
#define typed_requantize_in(ctype) \
206202
switch (out_dtype) { \

0 commit comments

Comments
 (0)