@@ -41,10 +41,9 @@ jit_memory_emitter::jit_memory_emitter(jit_generator* h, cpu_isa_t isa, const Ex
41
41
42
42
jit_load_memory_emitter::jit_load_memory_emitter (jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
43
43
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::gpr_to_vec) {
44
- OV_CPU_JIT_EMITTER_ASSERT (one_of (src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
45
- " Unsupported input type: " , src_prc.get_type_name ());
46
- OV_CPU_JIT_EMITTER_ASSERT (one_of (dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
47
- " Unsupported output type: " , dst_prc.get_type_name ());
44
+ bool is_supported_precision = one_of (src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
45
+ (src_prc == dst_prc || one_of (dst_prc, ov::element::f32, ov::element::i32));
46
+ OV_CPU_JIT_EMITTER_ASSERT (is_supported_precision, " Unsupported precision pair." );
48
47
49
48
const auto load = std::dynamic_pointer_cast<snippets::op::Load>(expr->get_node ());
50
49
OV_CPU_JIT_EMITTER_ASSERT (load != nullptr , " Expects Load expression" );
@@ -103,10 +102,9 @@ void jit_load_broadcast_emitter::emit_isa(const std::vector<size_t> &in, const s
103
102
104
103
jit_store_memory_emitter::jit_store_memory_emitter (jit_generator* h, cpu_isa_t isa, const ExpressionPtr& expr)
105
104
: jit_memory_emitter(h, isa, expr, emitter_in_out_map::vec_to_gpr) {
106
- OV_CPU_JIT_EMITTER_ASSERT (one_of (src_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
107
- " Unsupported input type: " , src_prc.get_type_name ());
108
- OV_CPU_JIT_EMITTER_ASSERT (one_of (dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8),
109
- " Unsupported output type: " , dst_prc.get_type_name ());
105
+ bool is_supported_precision = one_of (dst_prc, ov::element::f32, ov::element::i32, ov::element::f16, ov::element::i8, ov::element::u8) &&
106
+ (src_prc == dst_prc || one_of (src_prc, ov::element::f32, ov::element::i32));
107
+ OV_CPU_JIT_EMITTER_ASSERT (is_supported_precision, " Unsupported precision pair." );
110
108
111
109
if (ov::is_type<ov::intel_cpu::StoreConvertTruncation>(expr->get_node ())) {
112
110
store_emitter.reset (new jit_store_emitter (h, isa, src_prc, dst_prc, count, byte_offset, arithmetic_mode::truncation));
0 commit comments