@@ -26,18 +26,6 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
26
26
const auto &prim_dt = mem.dt ();
27
27
const auto &graph_dt = static_cast <dnnl_data_type_t >(lt.get_data_type ());
28
28
29
- // For int8 cases, as graph driver will modify the data type of leading
30
- // ops to u8/s8 in the reference path and use corresponding drivers to
31
- // generate data, special handling is needed. If it's found that data
32
- // type in ref path is u8/s8, it will be used.
33
- //
34
- // The reason why not always using primitive data type is that the driver
35
- // rewrites data type in graph path for bf16 case handling. So we prefer
36
- // data type in graph, and for int8 cases, that from ref path will be used.
37
- //
38
- dnnl_data_type_t c_data_type
39
- = prim_dt == dnnl_s8 || prim_dt == dnnl_u8 ? prim_dt : graph_dt;
40
-
41
29
// Get memory tag of primitive memory
42
30
int ndims = mem.ndims ();
43
31
dims_t strides (mem.strides (), mem.strides () + ndims);
@@ -58,7 +46,7 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
58
46
// otherwise use shape & tag from ref path side
59
47
60
48
// Create memory for graph path
61
- const auto data_type = static_cast <dnnl::memory::data_type>(c_data_type );
49
+ const auto data_type = static_cast <dnnl::memory::data_type>(graph_dt );
62
50
if (is_op_input) {
63
51
if (graph_dims_.empty ()) graph_dims_.push_back (1 );
64
52
if (graph_strides_.empty ()) graph_strides_.push_back (1 );
@@ -74,9 +62,9 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
74
62
std::memcpy (graph_data_handle, prim_data_handle, graph_mem.size ());
75
63
};
76
64
77
- if (prim_dt != c_data_type ) {
65
+ if (prim_dt != graph_dt ) {
78
66
dnn_mem_t c_mem (
79
- ndims, mem.dims (), c_data_type , mtag, ::get_test_engine ());
67
+ ndims, mem.dims (), graph_dt , mtag, ::get_test_engine ());
80
68
c_mem.reorder (mem);
81
69
prim_to_graph_memcpy (mem_, c_mem);
82
70
} else {
@@ -87,7 +75,7 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
87
75
dnnl::memory::desc md (graph_dims_, data_type, graph_strides_);
88
76
mem_ = dnn_mem_t (md.get (), ::get_test_engine ());
89
77
} else {
90
- mem_ = dnn_mem_t (mem.md_ , c_data_type , mtag, ::get_test_engine ());
78
+ mem_ = dnn_mem_t (mem.md_ , graph_dt , mtag, ::get_test_engine ());
91
79
}
92
80
}
93
81
}
0 commit comments