
Commit 9587391

bit exact extension

1 parent b1d6550 commit 9587391

4 files changed, +63 -13 lines changed

hls4ml/converters/__init__.py

Lines changed: 2 additions & 0 deletions
@@ -165,6 +165,7 @@ def convert_from_keras_model(
     output_data_tb=None,
     backend='Vivado',
     hls_config=None,
+    bit_exact=None,
     **kwargs,
 ):
     """Convert Keras model to hls4ml model based on the provided configuration.
@@ -214,6 +215,7 @@ def convert_from_keras_model(
 
     model_config = hls_config.get('Model', None)
     config['HLSConfig']['Model'] = _check_model_config(model_config)
+    config['HLSConfig']['Model']['BitExact'] = bit_exact
 
     _check_hls_config(config, hls_config)
     if 'KerasModel' in config:
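
Note (not part of the commit): a minimal usage sketch of the new keyword. The surrounding calls are assumed to follow the usual hls4ml Keras conversion flow; only bit_exact itself comes from this change.

import hls4ml

# `model` is assumed to be an existing (Q)Keras model.
config = hls4ml.utils.config_from_keras_model(model, granularity='name')

hls_model = hls4ml.converters.convert_from_keras_model(
    model,
    hls_config=config,
    backend='Vivado',
    # True/False force the bit-exact pass on/off; None (the default) defers to the
    # optimizer, which enables it only when FixedPointQuantizers are present.
    bit_exact=True,
)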

hls4ml/converters/keras/qkeras.py

Lines changed: 1 addition & 0 deletions
@@ -174,6 +174,7 @@ def get_activation_quantizer(keras_layer, input_names, activation_name='activation'):
     layer[activation_name] = activation_config['class_name'].replace('quantized_', '')
 
     layer[f'{activation_name}_quantizer'] = activation_config
+    layer['trusted'] = True
 
     return layer
 
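
Note (not part of the commit): a sketch of how the flag travels. Only 'trusted' comes from this change; that converter dict entries end up as node attributes is the usual hls4ml plumbing, and keras_layer, input_names and node are placeholders.

# In the QKeras converter, the parsed activation layer is now marked as trusted:
layer = get_activation_quantizer(keras_layer, input_names)
assert layer['trusted'] is True

# After the layer dict is turned into a graph node, the optimizer passes in
# bit_exact.py and hgq_proxy_model.py read the flag from the node's attributes:
is_trusted = node.attributes.get('trusted', False)
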
hls4ml/model/optimizer/passes/bit_exact.py

Lines changed: 43 additions & 1 deletion
@@ -133,6 +133,30 @@ def _(layer: Reshape):
 @_request_kif.register
 def _(layer: Activation):
     fn_name = layer.attributes.get('activation')
+
+    if layer.attributes.get('trusted', False):
+        result_t = layer.get_output_variable().type.precision
+        if fn_name in ('linear', 'relu'):
+            output_shape = get_output_shape(layer)
+            k, w, f = result_t.signed, result_t.width, result_t.fractional
+            i = w - k - f
+            k = np.full(output_shape, k, dtype=np.int8)
+            i = np.full(output_shape, i, dtype=np.int8)
+            f = np.full(output_shape, f, dtype=np.int8)
+            if result_t.rounding_mode == RoundingMode.RND:
+                f += 1
+            elif result_t.rounding_mode != RoundingMode.TRN:
+                f = np.full(output_shape, 126, dtype=np.int8)
+            if result_t.saturation_mode != SaturationMode.WRAP:
+                k = np.ones(output_shape, dtype=np.int8)
+                i = np.full(output_shape, 126, dtype=np.int8)
+            if fn_name == 'linear':
+                return ((k, i, f),)
+            else:
+                k = np.ones(output_shape, dtype=np.int8)
+                i = np.full(output_shape, 126, dtype=np.int8)
+                return ((k, i, f),)
+
     if fn_name == 'linear':
         return (requested_kif(layer),)
     if fn_name == 'relu':
@@ -531,6 +555,16 @@ def _(layer: Concatenate):
 @_produce_kif.register
 def _(layer: Activation):
     fn_name = layer.attributes['activation'].lower()
+    if layer.attributes.get('trusted', False):
+        output_shape = get_output_shape(layer)
+        result_t = layer.get_output_variable().type.precision
+        k, w, f = result_t.signed, result_t.width, result_t.fractional
+        i = w - k - f
+        k = np.full(output_shape, k, dtype=np.int8)
+        i = np.full(output_shape, i, dtype=np.int8)
+        f = np.full(output_shape, f, dtype=np.int8)
+        return k, i, f
+
     k, i, f = get_input_kifs(layer)[0]
 
     match fn_name:
@@ -603,6 +637,10 @@ def requested_by_non_saturating_quantizer(layer: Layer) -> bool:
 
 
 def default_register_precision(layer: Layer):
+    if layer.attributes.get('trusted', False):
+        # Trusted layers have their precision already set
+        return
+
     _pk, _pi, _pf = produce_kif(layer)  # Maximum possible k,i,f output from this layer
     _rk, _ri, _rf = requested_kif(layer)  # Maximum possible k,i,f may be utilized by the next layer
     _oi, _of = np.minimum(_pi, _ri), np.minimum(_pf, _rf)
@@ -791,7 +829,11 @@ def has_fixed_quantizer(self, model: 'ModelGraph'):
         return True
 
     def _match(self, model: 'ModelGraph'):
-        return self.has_fixed_quantizer(model)
+        enabled = model.config.config['HLSConfig']['Model'].get('BitExact', None)
+        if enabled is None:
+            # Enable by default if any FixedPointQuantizer is present
+            enabled = self.has_fixed_quantizer(model)
+        return enabled
 
     def transform(self, model: 'ModelGraph'):
         if not self._match(model):
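
Note (not part of the commit): a worked example of the trusted-layer k/i/f decomposition above, as standalone Python. The reading of the RND/SAT adjustments is inferred from the diff.

# Trusted activation whose result_t is a signed fixed<16,6> with RND rounding and a
# saturating mode. k = sign bit, w = total width, f = fractional bits, and the integer
# bits excluding the sign are i = w - k - f.
signed, width, integer = True, 16, 6
k = int(signed)         # 1
f = width - integer     # 10 fractional bits
i = width - k - f       # 5

# RND asks the producer for one extra fractional bit so rounding stays bit exact:
f += 1                  # 11
# Any saturation mode other than WRAP must not have its integer range clipped, so the
# request is effectively unbounded (126 serves as the "no constraint" sentinel):
k, i = 1, 126

print(k, i, f)          # -> 1 126 11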

hls4ml/model/optimizer/passes/hgq_proxy_model.py

Lines changed: 17 additions & 12 deletions
@@ -6,7 +6,7 @@
 import numpy as np
 
 from hls4ml.model.attributes import Attribute, TypeAttribute, WeightAttribute
-from hls4ml.model.layers import Layer, Reshape, register_layer
+from hls4ml.model.layers import Activation, Layer, Reshape, register_layer
 from hls4ml.model.optimizer import OptimizerPass, register_pass
 from hls4ml.model.types import FixedPrecisionType, UnspecifiedPrecisionType
 
@@ -79,11 +79,13 @@ def userconf_ifdef(key: str, layer_name: str, model):
 
 class FuseFixedPointQuantizer(OptimizerPass):
     def match(self, node: Layer):
-        if not isinstance(node, FixedPointQuantizer):
-            return False
-        if any(np.unique(x).size > 1 for x in node.mask_kbi):
-            return False
-        return True
+        if isinstance(node, FixedPointQuantizer):
+            return all(np.unique(x).size == 1 for x in node.mask_kbi)
+
+        if isinstance(node, Activation):
+            return node.get_attr('activation') == 'linear' and node.get_attr('trusted', False)
+
+        return False
 
     def propagate(self, node: Layer, precision: FixedPrecisionType):
         from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers
@@ -115,13 +117,16 @@ def propagate(self, node: Layer, precision: FixedPrecisionType):
     def transform(self, model: 'ModelGraph', node: FixedPointQuantizer):
         from hls4ml.model.optimizer.passes.bit_exact import get_input_layers, get_output_layers
 
-        # Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t.
-        if node.RND == 'TRN' and node.SAT == 'WRAP':
-            precision: FixedPrecisionType = copy(node.get_output_variable().type.precision)
+        if isinstance(node, FixedPointQuantizer):
+            # Rounding and saturation for FixedPointQuantizer are applied in generated code, thus not reflected in result_t.
+            if node.RND == 'TRN' and node.SAT == 'WRAP':
+                precision: FixedPrecisionType = copy(node.get_output_variable().type.precision)
+            else:
+                k, b, i = node.mask_kbi
+                k, b, i = bool(k.ravel()[0]), max(int(b.ravel()[0]), 1), int(i.ravel()[0])
+                precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)
         else:
-            k, b, i = node.mask_kbi
-            k, b, i = bool(k.ravel()[0]), max(int(b.ravel()[0]), 1), int(i.ravel()[0])
-            precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)
+            precision = copy(node.get_output_variable().type.precision)
 
         inp_layer = get_input_layers(node)[0]
         can_fuse = len(get_output_layers(inp_layer)) == 1
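
Note (not part of the commit): the two precision paths in the extended transform, sketched with illustrative mask values; only the shape of the computation follows the diff.

import numpy as np

# FixedPointQuantizer branch: a per-tensor (scalar) mask_kbi collapses to one fixed type.
# Illustrative masks for an unsigned 6-bit quantizer with no integer bits:
k_mask, b_mask, i_mask = np.zeros((1, 4)), np.full((1, 4), 6), np.zeros((1, 4))
k = bool(k_mask.ravel()[0])         # False -> unsigned
b = max(int(b_mask.ravel()[0]), 1)  # 6     -> width, clamped to at least 1 bit
i = int(i_mask.ravel()[0])          # 0     -> no integer bits
# precision = FixedPrecisionType(b, i, k, node.RND, node.SAT)

# Trusted linear Activation branch (new in this commit): rounding and saturation are
# already reflected in the layer's result_t, so the pass copies
# node.get_output_variable().type.precision unchanged.
print(k, b, i)                      # -> False 6 0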
