@@ -45,23 +45,10 @@ size_t get_benchdnn_device_limit() {
45
45
// Constructs memories for all inputs and outputs needed for comparison.
46
46
dnn_graph_mem_t::dnn_graph_mem_t (const dnn_mem_t &mem,
47
47
const deserialized_lt <, const bool is_op_input,
48
- const bool is_fake_output )
48
+ const bool use_graph_layout )
49
49
: graph_dims_(lt.shape_), graph_strides_(lt.stride_) {
50
- const auto &prim_dt = mem.dt ();
51
- // Conversion from graph types to dnnl types + boolean to u8.
52
- const auto &graph_dt = convert_dt (lt.get_data_type ());
53
-
54
- // Get memory tag of primitive memory
55
- int ndims = mem.ndims ();
56
- dims_t strides (mem.strides (), mem.strides () + ndims);
57
- std::string mtag = strides2memory_tag (ndims, strides);
58
-
59
50
const auto &g_eng = get_graph_engine ().operator const dnnl::engine &();
60
51
61
- // We create memory for graph path in two steps:
62
- // 1. Create memory objects.
63
- // 2. Do memory copy if needed.
64
- //
65
52
// For inputs, graph path needs data from reference path,
66
53
// and the data movement requires both memories have the same
67
54
// shape, so the tag of graph path is used to create the memory.
@@ -70,42 +57,77 @@ dnn_graph_mem_t::dnn_graph_mem_t(const dnn_mem_t &mem,
70
57
// otherwise use shape & tag from ref path side
71
58
72
59
// Create memory for graph path
60
+ const auto &graph_dt = convert_dt (lt.get_data_type ());
73
61
const auto data_type = static_cast <dnnl::memory::data_type>(graph_dt);
74
- if (is_op_input) {
75
- if (graph_dims_.empty ()) graph_dims_.push_back (1 );
76
- if (graph_strides_.empty ()) graph_strides_.push_back (1 );
77
62
78
- // create graph memory
63
+ if (graph_dims_.empty ()) {
64
+ // As graph strides are deduced from graph dims, they should be in
65
+ // compliance with each other.
66
+ assert (graph_strides_.empty ());
67
+
68
+ graph_dims_.push_back (1 );
69
+ graph_strides_.push_back (1 );
70
+ }
71
+
72
+ if (is_op_input) {
73
+ // Create graph memory with memory description from graph path.
79
74
dnnl::memory::desc md (graph_dims_, data_type, graph_strides_);
80
75
mem_ = dnn_mem_t (md.get (), g_eng.get ());
81
-
82
- const auto prim_to_graph_memcpy = [](dnn_mem_t &graph_mem,
83
- const dnn_mem_t &prim_mem) {
84
- const void *prim_data_handle = static_cast <const void *>(prim_mem);
85
- void *graph_data_handle = graph_mem.get_mapped_pointer <void >();
86
- std::memcpy (graph_data_handle, prim_data_handle, graph_mem.size ());
87
- };
88
-
89
- if (prim_dt != graph_dt) {
90
- // Call a reorder (for data conversion) when reference memory
91
- // doesn't coincide with the graph memory...
92
- dnn_mem_t c_mem (ndims, mem.dims (), graph_dt, mtag, g_eng.get ());
93
- SAFE_V (c_mem.reorder (mem));
94
- prim_to_graph_memcpy (mem_, c_mem);
95
- } else {
96
- // ... otherwise, perform a plain memcpy.
97
- prim_to_graph_memcpy (mem_, mem);
98
- }
99
76
} else {
100
- if (is_fake_output) {
77
+ if (use_graph_layout) {
78
+ // For some cases such as fake outputs and no reference memory
79
+ // mode, which means the output does not have correctponding
80
+ // argument in primitives, we need to create them with memory
81
+ // description from graph path.
101
82
dnnl::memory::desc md (graph_dims_, data_type, graph_strides_);
102
83
mem_ = dnn_mem_t (md.get (), g_eng.get ());
84
+
103
85
} else {
86
+ // Use information from the reference memory descriptor to create
87
+ // memories. As we need to reorder output from both paths to abx
88
+ // for comparison, the memory tag of graph path output should align
89
+ // the reference path.
90
+
91
+ // Get memory tag of primitive memory
92
+ int ndims = mem.ndims ();
93
+ dims_t strides (mem.strides (), mem.strides () + ndims);
94
+ std::string mtag = strides2memory_tag (ndims, strides);
95
+
104
96
mem_ = dnn_mem_t (mem.md_ , graph_dt, mtag, g_eng.get ());
105
97
}
106
98
}
107
99
}
108
100
101
+ int dnn_graph_mem_t::fill_mem_with_data (const dnn_mem_t &mem) {
102
+
103
+ if (mem.size () != mem_.size ()) return FAILED;
104
+
105
+ const auto &src_dt = mem.dt ();
106
+ const auto &dst_dt = mem_.dt ();
107
+
108
+ int ndims = mem.ndims ();
109
+ dims_t strides (mem.strides (), mem.strides () + ndims);
110
+ std::string mtag = strides2memory_tag (ndims, strides);
111
+ const auto &g_eng = get_graph_engine ().operator const dnnl::engine &();
112
+
113
+ const auto prim_to_graph_memcpy = [](dnn_mem_t &graph_mem,
114
+ const dnn_mem_t &prim_mem) {
115
+ const void *prim_data_handle = static_cast <const void *>(prim_mem);
116
+ void *graph_data_handle = graph_mem.get_mapped_pointer <void >();
117
+ std::memcpy (graph_data_handle, prim_data_handle, graph_mem.size ());
118
+ };
119
+
120
+ if (src_dt != dst_dt) {
121
+ dnn_mem_t c_mem (ndims, mem.dims (), dst_dt, mtag, g_eng.get ());
122
+ SAFE_V (c_mem.reorder (mem));
123
+ prim_to_graph_memcpy (mem_, c_mem);
124
+ } else {
125
+ prim_to_graph_memcpy (mem_, mem);
126
+ }
127
+
128
+ return OK;
129
+ }
130
+
109
131
dnnl::graph::tensor dnn_graph_mem_t::make_graph_tensor (
110
132
const deserialized_lt <) const {
111
133
void *data_handle;
0 commit comments