Skip to content

Commit 3bcd2f4

Browse files
committed
benchdnn: graph: separate mem filling and create mem from graph path
1 parent 8c811df commit 3bcd2f4

File tree

2 files changed

+66
-41
lines changed

2 files changed

+66
-41
lines changed

tests/benchdnn/graph/graph_memory.cpp

+59-37
Original file line numberDiff line numberDiff line change
@@ -45,23 +45,10 @@ size_t get_benchdnn_device_limit() {
4545
// Constructs memories for all inputs and outputs needed for comparison.
4646
dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
4747
const deserialized_lt &lt, const bool is_op_input,
48-
const bool is_fake_output)
48+
const bool use_graph_layout)
4949
: graph_dims_(lt.shape_), graph_strides_(lt.stride_) {
50-
const auto &prim_dt = mem.dt();
51-
// Conversion from graph types to dnnl types + boolean to u8.
52-
const auto &graph_dt = convert_dt(lt.get_data_type());
53-
54-
// Get memory tag of primitive memory
55-
int ndims = mem.ndims();
56-
dims_t strides(mem.strides(), mem.strides() + ndims);
57-
std::string mtag = strides2memory_tag(ndims, strides);
58-
5950
const auto &g_eng = get_graph_engine().operator const dnnl::engine &();
6051

61-
// We create memory for graph path in two steps:
62-
// 1. Create memory objects.
63-
// 2. Do memory copy if needed.
64-
//
6552
// For inputs, graph path needs data from reference path,
6653
// and the data movement requires both memories to have the same
6754
// shape, so the tag of graph path is used to create the memory.
@@ -70,42 +57,77 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
7057
// otherwise use shape & tag from ref path side
7158

7259
// Create memory for graph path
60+
const auto &graph_dt = convert_dt(lt.get_data_type());
7361
const auto data_type = static_cast<dnnl::memory::data_type>(graph_dt);
74-
if (is_op_input) {
75-
if (graph_dims_.empty()) graph_dims_.push_back(1);
76-
if (graph_strides_.empty()) graph_strides_.push_back(1);
7762

78-
// create graph memory
63+
if (graph_dims_.empty()) {
64+
// As graph strides are deduced from graph dims, they should be in
65+
// compliance with each other.
66+
assert(graph_strides_.empty());
67+
68+
graph_dims_.push_back(1);
69+
graph_strides_.push_back(1);
70+
}
71+
72+
if (is_op_input) {
73+
// Create graph memory with memory description from graph path.
7974
dnnl::memory::desc md(graph_dims_, data_type, graph_strides_);
8075
mem_ = dnn_mem_t(md.get(), g_eng.get());
81-
82-
const auto prim_to_graph_memcpy = [](dnn_mem_t &graph_mem,
83-
const dnn_mem_t &prim_mem) {
84-
const void *prim_data_handle = static_cast<const void *>(prim_mem);
85-
void *graph_data_handle = graph_mem.get_mapped_pointer<void>();
86-
std::memcpy(graph_data_handle, prim_data_handle, graph_mem.size());
87-
};
88-
89-
if (prim_dt != graph_dt) {
90-
// Call a reorder (for data conversion) when reference memory
91-
// doesn't coincide with the graph memory...
92-
dnn_mem_t c_mem(ndims, mem.dims(), graph_dt, mtag, g_eng.get());
93-
SAFE_V(c_mem.reorder(mem));
94-
prim_to_graph_memcpy(mem_, c_mem);
95-
} else {
96-
// ... otherwise, perform a plain memcpy.
97-
prim_to_graph_memcpy(mem_, mem);
98-
}
9976
} else {
100-
if (is_fake_output) {
77+
if (use_graph_layout) {
78+
// For some cases such as fake outputs and no reference memory
79+
// mode, which means the output does not have a corresponding
80+
// argument in primitives, we need to create them with memory
81+
// description from graph path.
10182
dnnl::memory::desc md(graph_dims_, data_type, graph_strides_);
10283
mem_ = dnn_mem_t(md.get(), g_eng.get());
84+
10385
} else {
86+
// Use information from the reference memory descriptor to create
87+
// memories. As we need to reorder output from both paths to abx
88+
// for comparison, the memory tag of graph path output should align with
89+
// the reference path.
90+
91+
// Get memory tag of primitive memory
92+
int ndims = mem.ndims();
93+
dims_t strides(mem.strides(), mem.strides() + ndims);
94+
std::string mtag = strides2memory_tag(ndims, strides);
95+
10496
mem_ = dnn_mem_t(mem.md_, graph_dt, mtag, g_eng.get());
10597
}
10698
}
10799
}
108100

101+
int dnn_graph_mem_t::fill_mem_with_data(const dnn_mem_t &mem) {
102+
103+
if (mem.size() != mem_.size()) return FAILED;
104+
105+
const auto &src_dt = mem.dt();
106+
const auto &dst_dt = mem_.dt();
107+
108+
int ndims = mem.ndims();
109+
dims_t strides(mem.strides(), mem.strides() + ndims);
110+
std::string mtag = strides2memory_tag(ndims, strides);
111+
const auto &g_eng = get_graph_engine().operator const dnnl::engine &();
112+
113+
const auto prim_to_graph_memcpy = [](dnn_mem_t &graph_mem,
114+
const dnn_mem_t &prim_mem) {
115+
const void *prim_data_handle = static_cast<const void *>(prim_mem);
116+
void *graph_data_handle = graph_mem.get_mapped_pointer<void>();
117+
std::memcpy(graph_data_handle, prim_data_handle, graph_mem.size());
118+
};
119+
120+
if (src_dt != dst_dt) {
121+
dnn_mem_t c_mem(ndims, mem.dims(), dst_dt, mtag, g_eng.get());
122+
SAFE_V(c_mem.reorder(mem));
123+
prim_to_graph_memcpy(mem_, c_mem);
124+
} else {
125+
prim_to_graph_memcpy(mem_, mem);
126+
}
127+
128+
return OK;
129+
}
130+
109131
dnnl::graph::tensor dnn_graph_mem_t::make_graph_tensor(
110132
const deserialized_lt &lt) const {
111133
void *data_handle;

tests/benchdnn/graph/graph_memory.hpp

+7-4
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,15 @@ struct dnn_graph_mem_t {
156156
//
157157
// The constructor accepts two boolean parameters:
158158
// 1. is_op_input: whether the logical tensor is an input of an op
159-
// 2. is_fake_output: for fake outputs, the driver cannot create memory
160-
// objects based on primitive memory for them, but construct memory
161-
// from graph shape. The default value is false.
159+
// 2. use_graph_layout: for fake outputs and mode without reference
160+
// memories, the driver cannot create memory objects based on primitive
161+
// memory for them, and instead constructs memory from graph shape. The default
162+
// value is false.
162163
//
163164
dnn_graph_mem_t(const dnn_mem_t &mem, const deserialized_lt &lt,
164-
const bool is_op_input, const bool is_fake_output = false);
165+
const bool is_op_input, const bool use_graph_layout = false);
166+
167+
int fill_mem_with_data(const dnn_mem_t &mem);
165168

166169
dnnl::graph::tensor make_graph_tensor(const deserialized_lt &lt) const;
167170

0 commit comments

Comments
 (0)