Skip to content

Commit d94f87c

Browse files
wzt1997thuang6
authored andcommitted
tests: benchdnn: graph: remove useless data type alignment
1 parent 35ed364 commit d94f87c

File tree

1 file changed

+4
-16
lines changed

1 file changed

+4
-16
lines changed

tests/benchdnn/graph/graph_memory.cpp

+4-16
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,6 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
2626
const auto &prim_dt = mem.dt();
2727
const auto &graph_dt = static_cast<dnnl_data_type_t>(lt.get_data_type());
2828

29-
// For int8 cases, as graph driver will modify the data type of leading
30-
// ops to u8/s8 in the reference path and use corresponding drivers to
31-
// generate data, special handling is needed. If it's found that data
32-
// type in ref path is u8/s8, it will be used.
33-
//
34-
// The reason why not always using primitive data type is that the driver
35-
// rewrites data type in graph path for bf16 case handling. So we prefer
36-
// data type in graph, and for int8 cases, that from ref path will be used.
37-
//
38-
dnnl_data_type_t c_data_type
39-
= prim_dt == dnnl_s8 || prim_dt == dnnl_u8 ? prim_dt : graph_dt;
40-
4129
// Get memory tag of primitive memory
4230
int ndims = mem.ndims();
4331
dims_t strides(mem.strides(), mem.strides() + ndims);
@@ -58,7 +46,7 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
5846
// otherwise use shape & tag from ref path side
5947

6048
// Create memory for graph path
61-
const auto data_type = static_cast<dnnl::memory::data_type>(c_data_type);
49+
const auto data_type = static_cast<dnnl::memory::data_type>(graph_dt);
6250
if (is_op_input) {
6351
if (graph_dims_.empty()) graph_dims_.push_back(1);
6452
if (graph_strides_.empty()) graph_strides_.push_back(1);
@@ -74,9 +62,9 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
7462
std::memcpy(graph_data_handle, prim_data_handle, graph_mem.size());
7563
};
7664

77-
if (prim_dt != c_data_type) {
65+
if (prim_dt != graph_dt) {
7866
dnn_mem_t c_mem(
79-
ndims, mem.dims(), c_data_type, mtag, ::get_test_engine());
67+
ndims, mem.dims(), graph_dt, mtag, ::get_test_engine());
8068
c_mem.reorder(mem);
8169
prim_to_graph_memcpy(mem_, c_mem);
8270
} else {
@@ -87,7 +75,7 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
8775
dnnl::memory::desc md(graph_dims_, data_type, graph_strides_);
8876
mem_ = dnn_mem_t(md.get(), ::get_test_engine());
8977
} else {
90-
mem_ = dnn_mem_t(mem.md_, c_data_type, mtag, ::get_test_engine());
78+
mem_ = dnn_mem_t(mem.md_, graph_dt, mtag, ::get_test_engine());
9179
}
9280
}
9381
}

0 commit comments

Comments
 (0)