Commit 71123c8

bit exact extension (#1338)
* bit exact extension
* add qkeras test, (maybe) support qonnx
* revert onnx changes
* use int16
* disable quant fuse if bit_exact not enabled
Parent: 3f7ee79

File tree: 5 files changed, +99 −31 lines

  hls4ml/converters/__init__.py
  hls4ml/converters/keras/qkeras.py
  hls4ml/model/optimizer/passes/bit_exact.py
  hls4ml/model/optimizer/passes/hgq_proxy_model.py
  test/pytest/test_qkeras.py


hls4ml/converters/__init__.py

Lines changed: 11 additions & 0 deletions
@@ -165,6 +165,7 @@ def convert_from_keras_model(
     output_data_tb=None,
     backend='Vivado',
     hls_config=None,
+    bit_exact=None,
     **kwargs,
 ):
     """Convert Keras model to hls4ml model based on the provided configuration.
@@ -194,6 +195,10 @@ def convert_from_keras_model(
             'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
         hls_config (dict, optional): The HLS config.
         kwargs** (dict, optional): Additional parameters that will be used to create the config of the specified backend
+        bit_exact (bool, optional): If True, enable model-wise precision propagation
+            with **only fixed-point data types**. If None, enable if there is at least one
+            FixedPointQuantizer layer in the model (only resulting from converting HGQ1/2
+            models for now). By default, None.

     Raises:
         Exception: If precision and reuse factor are not present in 'hls_config'.
@@ -214,6 +219,7 @@ def convert_from_keras_model(

     model_config = hls_config.get('Model', None)
     config['HLSConfig']['Model'] = _check_model_config(model_config)
+    config['HLSConfig']['Model']['BitExact'] = bit_exact

     _check_hls_config(config, hls_config)
     if 'KerasModel' in config:
@@ -306,6 +312,7 @@ def convert_from_onnx_model(
     output_data_tb=None,
     backend='Vivado',
     hls_config=None,
+    bit_exact=None,
     **kwargs,
 ):
     """Convert Keras model to hls4ml model based on the provided configuration.
@@ -335,6 +342,10 @@ def convert_from_onnx_model(
             'io_parallel' or 'io_stream'. Defaults to 'io_parallel'.
         hls_config (dict, optional): The HLS config.
         kwargs** (dict, optional): Additional parameters that will be used to create the config of the specified backend
+        bit_exact (bool, optional): If True, enable model-wise precision propagation
+            with **only fixed-point data types**. If None, enable if there is at least one
+            FixedPointQuantizer layer in the model (only resulting from converting HGQ1/2
+            models for now). By default, None.

     Raises:
         Exception: If precision and reuse factor are not present in 'hls_config'.
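
A minimal usage sketch of the new flag (model, bit widths, and output directory below are illustrative, adapted from the updated test in test/pytest/test_qkeras.py; they are not part of this commit):

# Sketch: convert a small QKeras model with model-wise precision propagation enabled.
# Passing bit_exact=None (the default) would enable it only if the model contains a
# FixedPointQuantizer layer (e.g. resulting from HGQ1/2 conversion).
import hls4ml
from keras.models import Sequential
from qkeras import QActivation, QDense, quantized_bits, quantized_relu

model = Sequential(
    [
        QActivation(activation=quantized_bits(8, 0, alpha=1), input_shape=(16,), name='inp_quant'),
        QDense(16, name='fc1', kernel_quantizer=quantized_bits(8, 0, alpha=1), bias_quantizer=quantized_bits(8, 0, alpha=1)),
        QActivation(activation=quantized_relu(8, 0), name='relu1'),
    ]
)
model.compile()

config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend='Vivado')
hls_model = hls4ml.converters.convert_from_keras_model(
    model, hls_config=config, output_dir='hls4mlprj_bit_exact_example', backend='Vivado', bit_exact=True
)
hls_model.compile()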

hls4ml/converters/keras/qkeras.py

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ def get_activation_quantizer(keras_layer, input_names, activation_name='activati
     layer[activation_name] = activation_config['class_name'].replace('quantized_', '')

     layer[f'{activation_name}_quantizer'] = activation_config
+    layer['trusted'] = True

     return layer
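
The 'trusted' flag added here is what the precision-propagation passes key on. Roughly, the layer config returned by get_activation_quantizer now carries the marker alongside the quantizer config; the dict below is only illustrative (field values are made up):

# Illustrative shape of a converted QKeras activation layer config after this change.
layer = {
    'activation': 'relu',  # quantizer class name with 'quantized_' stripped
    'activation_quantizer': {'class_name': 'quantized_relu', 'config': {'bits': 8, 'integer': 0}},
    'trusted': True,  # tells the bit_exact passes that the output precision is already exact
}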

hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 43 additions & 1 deletion
@@ -133,6 +133,30 @@ def _(layer: Reshape):
 @_request_kif.register
 def _(layer: Activation):
     fn_name = layer.attributes.get('activation')
+
+    if layer.attributes.get('trusted', False):
+        result_t = layer.get_output_variable().type.precision
+        if fn_name in ('linear', 'relu'):
+            output_shape = get_output_shape(layer)
+            k, w, f = result_t.signed, result_t.width, result_t.fractional
+            i = w - k - f
+            k = np.full(output_shape, k, dtype=np.int16)
+            i = np.full(output_shape, i, dtype=np.int16)
+            f = np.full(output_shape, f, dtype=np.int16)
+            if result_t.rounding_mode == RoundingMode.RND:
+                f += 1
+            elif result_t.rounding_mode != RoundingMode.TRN:
+                f = np.full(output_shape, 126, dtype=np.int16)
+            if result_t.saturation_mode != SaturationMode.WRAP:
+                k = np.ones(output_shape, dtype=np.int16)
+                i = np.full(output_shape, 126, dtype=np.int16)
+            if fn_name == 'linear':
+                return ((k, i, f),)
+            else:
+                k = np.ones(output_shape, dtype=np.int16)
+                i = np.full(output_shape, 126, dtype=np.int16)
+                return ((k, i, f),)
+
     if fn_name == 'linear':
         return (requested_kif(layer),)
     if fn_name == 'relu':
@@ -533,6 +557,16 @@ def _(layer: Concatenate):
 @_produce_kif.register
 def _(layer: Activation):
     fn_name = layer.attributes['activation'].lower()
+    if layer.attributes.get('trusted', False):
+        output_shape = get_output_shape(layer)
+        result_t = layer.get_output_variable().type.precision
+        k, w, f = result_t.signed, result_t.width, result_t.fractional
+        i = w - k - f
+        k = np.full(output_shape, k, dtype=np.int16)
+        i = np.full(output_shape, i, dtype=np.int16)
+        f = np.full(output_shape, f, dtype=np.int16)
+        return k, i, f
+
     k, i, f = get_input_kifs(layer)[0]

     match fn_name:
@@ -605,6 +639,10 @@ def requested_by_non_saturating_quantizer(layer: Layer) -> bool:


 def default_register_precision(layer: Layer):
+    if layer.attributes.get('trusted', False):
+        # Trusted layers have their precision already set
+        return
+
     _pk, _pi, _pf = produce_kif(layer)  # Maximum possible k,i,f output from this layer
     _rk, _ri, _rf = requested_kif(layer)  # Maximum possible k,i,f may be utilized by the next layer
     _oi, _of = np.minimum(_pi, _ri), np.minimum(_pf, _rf)
@@ -793,7 +831,11 @@ def has_fixed_quantizer(self, model: 'ModelGraph'):
         return True

     def _match(self, model: 'ModelGraph'):
-        return self.has_fixed_quantizer(model)
+        enabled = model.config.config['HLSConfig']['Model'].get('BitExact', None)
+        if enabled is None:
+            # Enable by default if any FixedPointQuantizer is present
+            enabled = self.has_fixed_quantizer(model)
+        return enabled

     def transform(self, model: 'ModelGraph'):
         if not self._match(model):
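
The trusted branches above all derive per-element (k, i, f) arrays, i.e. (keep_negative, integer bits, fractional bits), from the layer's already-registered fixed-point result type. A standalone sketch of that arithmetic (the FixedPoint dataclass is a hypothetical stand-in; the real code reads result_t from the layer's output variable):

import numpy as np
from dataclasses import dataclass


@dataclass
class FixedPoint:
    # Hypothetical stand-in for hls4ml's fixed-point precision type.
    signed: bool
    width: int       # total bits (sign + integer + fraction)
    fractional: int  # fraction bits


def trusted_kif(result_t, output_shape):
    """Broadcast a scalar fixed-point spec to per-element (k, i, f) arrays, as the
    trusted branches do; int16 keeps the bookkeeping arrays compact."""
    k = int(result_t.signed)
    f = result_t.fractional
    i = result_t.width - k - f  # integer bits excluding the sign bit
    return (
        np.full(output_shape, k, dtype=np.int16),
        np.full(output_shape, i, dtype=np.int16),
        np.full(output_shape, f, dtype=np.int16),
    )


# e.g. a signed 16-bit type with 10 fraction bits -> k=1, i=5, f=10 for every element
k, i, f = trusted_kif(FixedPoint(True, 16, 10), (4, 4))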

hls4ml/model/optimizer/passes/hgq_proxy_model.py

Lines changed: 19 additions & 11 deletions
@@ -6,7 +6,7 @@
 import numpy as np

 from hls4ml.model.attributes import Attribute, TypeAttribute, WeightAttribute
-from hls4ml.model.layers import Layer, Reshape, register_layer
+from hls4ml.model.layers import Activation, Layer, Reshape, register_layer
 from hls4ml.model.optimizer import OptimizerPass, register_pass
 from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType

@@ -77,11 +77,16 @@ def userconf_ifdef(key: str, layer_name: str, model):

 class FuseFixedPointQuantizer(OptimizerPass):
     def match(self, node: Layer):
-        if not isinstance(node, FixedPointQuantizer):
-            return False
-        if any(np.unique(x).size > 1 for x in node.mask_kbi):
+        if not node.attributes.get('bit_exact_transformed', False):
             return False
-        return True
+
+        if isinstance(node, FixedPointQuantizer):
+            return all(np.unique(x).size == 1 for x in node.mask_kbi)
+
+        if isinstance(node, Activation):
+            return node.get_attr('activation') == 'linear' and node.get_attr('trusted', False)
+
+        return False

     def propagate(self, node: Layer, precision: FixedPrecisionType):
         from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers
@@ -113,13 +118,16 @@ def propagate(self, node: Layer, precision: FixedPrecisionType):
     def transform(self, model: 'ModelGraph', node: FixedPointQuantizer):
         from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers

-        # Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t.
-        if node.RND == 'TRN' and node.SAT == 'WRAP':
-            precision: FixedPrecisionType = copy(node.get_output_variable().type.precision)
+        if isinstance(node, FixedPointQuantizer):
+            # Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t.
+            if node.RND == 'TRN' and node.SAT == 'WRAP':
+                precision: FixedPrecisionType = copy(node.get_output_variable().type.precision)
+            else:
+                k, b, i = node.mask_kbi
+                k, b, i = bool(k.ravel()[0]), max(int(b.ravel()[0]), 1), int(i.ravel()[0])
+                precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)
         else:
-            k, b, i = node.mask_kbi
-            k, b, i = bool(k.ravel()[0]), max(int(b.ravel()[0]), 1), int(i.ravel()[0])
-            precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)
+            precision = copy(node.get_output_variable().type.precision)

         inp_layer = get_input_layers(node)[0]
         can_fuse = len(get_output_layers(inp_layer)) == 1
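
The reworked match/transform logic only fuses a FixedPointQuantizer whose per-element masks are uniform, in which case they collapse to a single scalar precision. A sketch of that check with made-up mask arrays (the real pass reads them from node.mask_kbi and builds an hls4ml FixedPrecisionType from the result):

import numpy as np

# Hypothetical per-element masks: keep_negative, total bits, integer bits.
mask_k = np.ones((1, 16), dtype=np.int16)
mask_b = np.full((1, 16), 8, dtype=np.int16)
mask_i = np.full((1, 16), 1, dtype=np.int16)

# Fusing is only valid when every element shares one precision...
uniform = all(np.unique(x).size == 1 for x in (mask_k, mask_b, mask_i))

if uniform:
    # ...in which case the arrays collapse to scalars, mirroring the transform above
    # (the width is clamped to at least 1 bit).
    k = bool(mask_k.ravel()[0])
    b = max(int(mask_b.ravel()[0]), 1)
    i = int(mask_i.ravel()[0])
    # The real code then constructs FixedPrecisionType(b, i, k, node.RND, node.SAT).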

test/pytest/test_qkeras.py

Lines changed: 25 additions & 19 deletions
@@ -3,6 +3,9 @@

 import numpy as np
 import pytest
+from keras.layers import BatchNormalization, Input
+from keras.models import Model, Sequential, model_from_json
+from keras.utils import to_categorical
 from qkeras import QGRU, QLSTM, QSimpleRNN
 from qkeras.qconv2d_batchnorm import QConv2DBatchnorm
 from qkeras.qconvolutional import QDepthwiseConv2D, QSeparableConv1D, QSeparableConv2D
@@ -20,9 +23,6 @@
 from sklearn.datasets import fetch_openml
 from sklearn.model_selection import train_test_split
 from sklearn.preprocessing import LabelEncoder, StandardScaler
-from tensorflow.keras.layers import BatchNormalization, Input
-from tensorflow.keras.models import Model, Sequential, model_from_json
-from tensorflow.keras.utils import to_categorical

 import hls4ml

@@ -142,33 +142,39 @@ def test_single_dense_activation_exact(randX_100_16, bits, alpha, backend, io_ty
     bit exactness with number of bits parameter
     '''
     X = randX_100_16
-    model = Sequential()
-    model.add(
-        QDense(
-            16,
-            input_shape=(16,),
-            name='fc1',
-            kernel_quantizer=quantized_bits(bits, 0, alpha=alpha),
-            bias_quantizer=quantized_bits(bits, 0, alpha=1),
-            kernel_initializer='lecun_uniform',
-        )
+    model = Sequential(
+        [
+            QActivation(activation=quantized_bits(bits, 0, alpha=1), input_shape=(16,), name='inp_quant'),
+            QDense(
+                16,
+                name='fc1',
+                kernel_quantizer=quantized_bits(bits, 0, alpha=alpha),
+                bias_quantizer=quantized_bits(bits, 0, alpha=1),
+                kernel_initializer='lecun_uniform',
+            ),
+            QActivation(activation=quantized_relu(bits, 0), name='relu1'),
+        ]
     )
-    model.add(QActivation(activation=quantized_relu(bits, 0), name='relu1'))
     model.compile()

     config = hls4ml.utils.config_from_keras_model(model, granularity='name', backend=backend)
     output_dir = str(test_root_path / f'hls4mlprj_qkeras_single_dense_activation_exact_{bits}_{alpha}_{backend}_{io_type}')
+
+    bit_exact = alpha == 1
+    # alpha!=po2 case uses non-fixed-point data types, unsupported by the precision propagation flow
     hls_model = hls4ml.converters.convert_from_keras_model(
-        model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type
+        model, hls_config=config, output_dir=output_dir, backend=backend, io_type=io_type, bit_exact=bit_exact
     )
     hls_model.compile()

     y_qkeras = model.predict(X)
     y_hls4ml = hls_model.predict(X)
-    # Goal is to get it passing with all equal
-    # np.testing.assert_array_equal(y_qkeras, y_hls4ml)
-    # For now allow matching within 1 bit
-    np.testing.assert_allclose(y_qkeras.ravel(), y_hls4ml.ravel(), atol=2**-bits, rtol=1.0)
+
+    # alpha!=1 case for weights can be supported if weight conversion is done before writing
+    if bit_exact:
+        np.testing.assert_array_equal(y_qkeras, y_hls4ml)
+    else:
+        np.testing.assert_allclose(y_qkeras.ravel(), y_hls4ml.ravel(), atol=2**-bits, rtol=1.0)


 @pytest.fixture
