
Commit 459e8f0

[PT FE] Support converting f8 compressed models (#29432)
### Details:
- *Support converting f8 compressed models*

### Tickets:
- *CVS-164161*

---------

Signed-off-by: Maxim Vafin <maxim.vafin@intel.com>
1 parent 4ecaa5d commit 459e8f0
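
The change teaches the PyTorch frontend to handle models whose weights are stored in float8 (f8e4m3/f8e5m2). As a rough illustration of the user-facing workflow, here is a minimal sketch mirroring the new `test_patched_8bit_model_converts` test added below; the toy module, shapes, and the `"CPU"` target are illustrative assumptions, not part of the commit:

```python
# Minimal sketch (toy model, assumed shapes): convert an f8-compressed module.
import torch
import openvino as ov
from openvino.frontend.pytorch import patch_model

model = torch.nn.Sequential(torch.nn.Linear(64, 32), torch.nn.ReLU())
model = model.to(torch.float8_e4m3fn)        # weights compressed to f8 (e4m3)
patch_model.__make_16bit_traceable(model)    # patching now also accepts f8 weights

with torch.no_grad():                        # patching only works without grad
    converted = ov.convert_model(model, example_input=torch.randn(1, 64))

compiled = ov.compile_model(converted, "CPU")
result = compiled(torch.randn(1, 64).numpy())
```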

File tree

7 files changed: +73 -14 lines


src/bindings/python/src/openvino/frontend/pytorch/patch_model.py

+5 -5
@@ -84,11 +84,12 @@ def __make_16bit_traceable(model: torch.nn.Module,
     - Replace known list of modules with ModuleExtension.
     - Convert other modules with weights to FP32.
     """
+    supported = {torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2}
     if patch_condition is None:
         def patch_condition(module):
-            supported = {torch.float32, torch.float16, torch.bfloat16}
+            dtype_to_patch = {torch.float32, *supported}
             weight = getattr(module, "weight", None)
-            return weight is not None and weight.dtype in supported
+            return weight is not None and weight.dtype in dtype_to_patch

     def fp32_tensor(*shape):
         return torch.full(shape, 0.5, dtype=torch.float32)
@@ -123,10 +124,9 @@ def fp32_tensor(*shape):
     except ImportError:
         pass
     patch_model(model, extensions, orig_forward_name)
-    dtype_to_patch = {torch.float16, torch.bfloat16}
     for _, module in model.named_modules():
         if (module.__class__ not in extensions and
-                (any(p.dtype in dtype_to_patch for p in module.parameters(False))
-                 or any(b.dtype in dtype_to_patch for b in module.buffers(False)))):
+                (any(p.dtype in supported for p in module.parameters(False))
+                 or any(b.dtype in supported for b in module.buffers(False)))):
             log.debug("Casting module %s to float32", module)
             module.float()
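
In short, the default `patch_condition` now also matches modules whose `weight` is stored in `torch.float8_e4m3fn` or `torch.float8_e5m2`, and the final cleanup loop casts any remaining f16/bf16/f8 module to float32. A standalone restatement for illustration (`default_patch_condition` is a hypothetical name; the real logic lives inside `__make_16bit_traceable`):

```python
import torch

# The dtypes handled after this change: f8 weights now qualify for patching.
supported = {torch.float16, torch.bfloat16, torch.float8_e4m3fn, torch.float8_e5m2}
dtype_to_patch = {torch.float32, *supported}

def default_patch_condition(module):
    weight = getattr(module, "weight", None)
    return weight is not None and weight.dtype in dtype_to_patch

assert default_patch_condition(torch.nn.Linear(4, 4).to(torch.float8_e4m3fn))
assert not default_patch_condition(torch.nn.ReLU())  # no weight, not patched
```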

src/common/transformations/include/transformations/fp16_compression/mark_decompression_convert_constant_folding.hpp

+2 -2
@@ -62,10 +62,10 @@ class ov::pass::KeepConstantsPrecisionAndAddConverts : public MatcherPass {

 /**
  * @ingroup ov_transformation_common_api
- * @brief Prevents ConstantFolding for f16/bf16 Const + Convert_To_FP32 to keep original FW float Constants.
+ * @brief Prevents ConstantFolding for low precision Const + Convert_To_FP32 to keep original FW float Constants.
  * Original precision should be kept as long as possible, this prevents redundant conversions and saves memory.
  * E.g. if original FW model was already compressed no need to upcast during CF, store intermediate f32 consts and
- * then again compress them to f16 during save_model.
+ * then again compress them to low precision during save_model.
  */
 class ov::pass::MarkCompressedFloatConstants : public MatcherPass {
 public:

src/common/transformations/src/transformations/fp16_compression/mark_decompression_convert_constant_folding.cpp

+4 -1
@@ -135,7 +135,10 @@ pass::MarkCompressedFloatConstants::MarkCompressedFloatConstants() {
         if (convert_node->get_destination_type() != element::f32)
             return false;
         if (const_node->get_output_element_type(0) != element::f16 &&
-            const_node->get_output_element_type(0) != element::bf16)
+            const_node->get_output_element_type(0) != element::bf16 &&
+            const_node->get_output_element_type(0) != element::f8e4m3 &&
+            const_node->get_output_element_type(0) != element::f8e5m2 &&
+            const_node->get_output_element_type(0) != element::f8e8m0)
             return false;

         mark_as_decompression(convert_node);
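
With the f8 element types accepted here, constants that were f8 in the original framework model should keep that precision through conversion instead of being folded into f32. A hypothetical spot check, assuming the `converted` model from the sketch above and that the f8 types are exposed on `ov.Type` in the Python API:

```python
import openvino as ov

# The f8 weights should survive as f8 Constants feeding a decompression Convert.
f8_consts = [
    op for op in converted.get_ops()
    if op.get_type_name() == "Constant"
    and op.get_output_element_type(0) == ov.Type.f8e4m3
]
assert f8_consts, "expected f8 weights to keep their original precision"
```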

src/frontends/pytorch/src/frontend.cpp

+1 -1
@@ -273,7 +273,7 @@ void FrontEnd::normalize(const std::shared_ptr<ov::Model>& model) const {
     manager.register_pass<ov::pass::ConvertConvertLike>();
     manager.register_pass<ov::frontend::pytorch::pass::AtenIndexToSelect>();

-    // Mark quantized and f16/bf16 compressed constants to prevent CF for them,
+    // Mark low precision compressed constants to prevent CF for them,
     // so that not extra memory is used for intermediate decompressed constants.
     manager.register_pass<ov::pass::MarkCompressedFloatConstants>();

src/frontends/pytorch/src/op/embedding.cpp

+3 -1
@@ -33,7 +33,9 @@ OutputVector translate_embedding_ext(const NodeContext& context) {
     // used in 16-bit patching
     num_inputs_check(context, 2, 5);
     auto data = context.get_input(0);
-    data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::f32));
+    if (data.get_element_type() != element::f32) {
+        data = context.mark_node(std::make_shared<ov::op::v0::Convert>(data, element::f32));
+    }
     auto indices = context.get_input(1);
     indices = context.mark_node(std::make_shared<ov::op::v0::Convert>(indices, element::i32));
     auto axis_0 = context.mark_node(ov::op::v0::Constant::create(element::i32, Shape{}, {0}));

src/frontends/pytorch/src/op/linear.cpp

+2 -3
@@ -36,9 +36,8 @@ OutputVector translate_linear_ext(const NodeContext& context) {
     auto x = context.get_input(0);
     auto initial_x = x;
     auto weight = context.get_input(1);
-    bool is_compressed = weight.get_element_type() == element::f16 || weight.get_element_type() == element::bf16;
     bool convert_back = false;
-    if (is_compressed) {
+    if (weight.get_element_type() != element::f32) {
         // In case of patched linear it can have mixed fp16/bf16 and fp32 input type.
         // In other cases these conversion is not required.
         weight = context.mark_node(std::make_shared<v0::Convert>(weight, element::f32));
@@ -52,7 +51,7 @@ OutputVector translate_linear_ext(const NodeContext& context) {
     if (!context.input_is_none(2)) {
         auto bias = context.get_input(2);

-        if (bias.get_element_type() == element::f16 || bias.get_element_type() == element::bf16) {
+        if (bias.get_element_type() != element::f32) {
             // Same reason as for weight.
             bias = context.mark_node(std::make_shared<v0::Convert>(bias, element::f32));
         }

tests/layer_tests/py_frontend_tests/test_torch_frontend.py

+56 -1
@@ -428,7 +428,6 @@ def sin_op(context):
         "Parameter", "Sin", "Result"]


-
 def test_multiple_module_extension():
     from openvino.frontend.pytorch.ts_decoder import TorchScriptPythonDecoder
     from openvino.frontend.pytorch import ModuleExtension
@@ -764,6 +763,7 @@ def forward(self, x1, x2):
     np.testing.assert_allclose(res_bf16[0], res_ref[0].numpy(), atol=1e-2)
     np.testing.assert_allclose(res_bf16[1], res_ref[1].numpy(), atol=1e-2)

+
 def test_patched_16bit_model_with_convert():
     from openvino.frontend.pytorch import patch_model
     from openvino import convert_model, Type
@@ -797,6 +797,61 @@ def forward(self, x):
     assert mm_num == 2


+def test_patched_8bit_model_converts():
+    from openvino.frontend.pytorch import patch_model
+    from openvino import convert_model, compile_model
+    from transformers.pytorch_utils import Conv1D
+
+    class ModelWithLinear(torch.nn.Module):
+        def __init__(self):
+            super().__init__()
+
+            self.branch1 = torch.nn.Sequential(
+                torch.nn.Embedding(10, 64),
+                torch.nn.Linear(64, 32),
+                torch.nn.ReLU()
+            )
+            self.branch2 = torch.nn.Sequential(
+                Conv1D(256, 128),
+                torch.nn.Linear(256, 64), torch.nn.ReLU()
+            )
+            self.buffer = torch.ones(32)
+
+        def forward(self, x1, x2):
+            out1 = self.branch1(x1)
+            out2 = self.branch2(x2)
+            return (out1 + self.buffer, out2)
+
+    example = (torch.randint(0, 10, [32, 64]), torch.randn(32, 128))
+
+    model_ref = ModelWithLinear().to(torch.float8_e4m3fn).float()
+    with torch.no_grad():
+        res_ref = model_ref(*example)
+    model_f8_e4m3 = model_ref.to(torch.float8_e4m3fn)
+    patch_model.__make_16bit_traceable(model_f8_e4m3)
+    # the approach with patching only works for node with no grad
+    with torch.no_grad():
+        converted_model = convert_model(model_f8_e4m3, example_input=example)
+    assert converted_model
+    cm_f8_e4m3 = compile_model(converted_model, "CPU")
+    res_f8_e4m3 = cm_f8_e4m3([x.numpy() for x in example])
+    np.testing.assert_allclose(res_f8_e4m3[0], res_ref[0].numpy(), atol=1e-2)
+    np.testing.assert_allclose(res_f8_e4m3[1], res_ref[1].numpy(), atol=1e-2)
+
+    model_ref = ModelWithLinear().to(torch.float8_e5m2).float()
+    with torch.no_grad():
+        res_ref = model_ref(*example)
+    model_f8_e5m2 = model_ref.to(torch.float8_e5m2)
+    patch_model.__make_16bit_traceable(model_f8_e5m2)
+    # the approach with patching only works for node with no grad
+    with torch.no_grad():
+        converted_model = convert_model(model_f8_e5m2, example_input=example)
+    assert converted_model
+    cm_f8_e5m2 = compile_model(converted_model, "CPU")
+    res_f8_e5m2 = cm_f8_e5m2([x.numpy() for x in example])
+    np.testing.assert_allclose(res_f8_e5m2[0], res_ref[0].numpy(), atol=1e-2)
+    np.testing.assert_allclose(res_f8_e5m2[1], res_ref[1].numpy(), atol=1e-2)
+

 class InlinedInputsModel(torch.nn.Module):
     def __init__(self):
