Skip to content

Commit 605bdb5

Browse files
committed
Update precision assertion
1 parent ad2fea7 commit 605bdb5

File tree

1 file changed

+6
-8
lines changed

1 file changed

+6
-8
lines changed

src/plugins/intel_cpu/src/emitters/snippets/aarch64/jit_memory_emitters.cpp

+6-8
Original file line numberDiff line numberDiff line change
@@ -41,10 +41,9 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex
4141

4242
jit_load_memory_emitter::jit_load_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
4343
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) {
44-
OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
45-
"Unsupported input type: ", src_prc.get_type_name());
46-
OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
47-
"Unsupported output type: ", dst_prc.get_type_name());
44+
bool is_supported_precision = one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
45+
(src_prc == dst_prc || one_of(dst_prc, ov::element::f32, ov::element::i32));
46+
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");
4847

4948
const auto load = std::dynamic_pointer_cast<snippets::op::Load>(expr->get_node());
5049
OV_CPU_JIT_EMITTER_ASSERT(load != nullptr, "Expects Load expression");
@@ -103,10 +102,9 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector<size_t> &in, const s
103102

104103
jit_store_memory_emitter::jit_store_memory_emitter(jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
105104
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) {
106-
OV_CPU_JIT_EMITTER_ASSERT(one_of(src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
107-
"Unsupported input type: ", src_prc.get_type_name());
108-
OV_CPU_JIT_EMITTER_ASSERT(one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
109-
"Unsupported output type: ", dst_prc.get_type_name());
105+
bool is_supported_precision = one_of(dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
106+
(src_prc == dst_prc || one_of(src_prc, ov::element::f32, ov::element::i32));
107+
OV_CPU_JIT_EMITTER_ASSERT(is_supported_precision, "Unsupported precision pair.");
110108

111109
if (ov::is_type<ov::intel_cpu::StoreConvertTruncation>(expr->get_node())) {
112110
store_emitter.reset(new jit_store_emitter(h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation));

0 commit comments

Comments
 (0)