@@ -61,7 +61,8 @@ class ExecutorFactory {
61
61
m_postOps (postOps),
62
62
m_context(context),
63
63
m_suitableImplementations(filter(m_attrs, m_postOps, descriptors, implementationPriority)),
64
- m_implementationRequiresFallback(m_suitableImplementations.size(), true) {}
64
+ m_implementationRequiresFallback(m_suitableImplementations.size(), true),
65
+ m_executors(m_suitableImplementations.size()) {}
65
66
66
67
/* *
67
68
* @brief Retrieves the proper memory descriptors based on the provided memory descriptors.
@@ -106,12 +107,8 @@ class ExecutorFactory {
106
107
*/
107
108
void preconfigure (const MemoryArgs& memory) {
108
109
executor::Config<Attrs> config{memoryDescsFromMemory (memory), m_attrs, m_postOps};
109
- std::transform (m_suitableImplementations.begin (),
110
- m_suitableImplementations.end (),
111
- m_implementationRequiresFallback.begin (),
112
- [&config](const std::reference_wrapper<const ExecutorImplementation<Attrs>>& impl) {
113
- return impl.get ().requiresFallback (config);
114
- });
110
+
111
+ cacheFallbackStatus (config);
115
112
116
113
const size_t implId = select (memory, 0 );
117
114
const auto & impl = m_suitableImplementations[implId].get ();
@@ -123,7 +120,7 @@ class ExecutorFactory {
123
120
}
124
121
}
125
122
126
- (void )create (impl , memory, m_context);
123
+ (void )create (implId , memory, m_context);
127
124
}
128
125
129
126
/* *
@@ -154,7 +151,7 @@ class ExecutorFactory {
154
151
return fallback<Attrs, NodeT>(config, *fallbackConfig, memory, m_context, impl.name ());
155
152
}
156
153
}
157
- const auto executor = create (impl , memory, m_context);
154
+ const auto executor = create (implId , memory, m_context);
158
155
if (!executor->update (memory)) {
159
156
return nullptr ;
160
157
}
@@ -181,6 +178,19 @@ class ExecutorFactory {
181
178
182
179
return memoryDescs;
183
180
}
181
+
182
+ /* *
183
+ * @brief Caches the fallback status for each suitable implementation.
184
+ */
185
+ void cacheFallbackStatus (const executor::Config<Attrs>& config) {
186
+ std::transform (m_suitableImplementations.begin (),
187
+ m_suitableImplementations.end (),
188
+ m_implementationRequiresFallback.begin (),
189
+ [&config](const std::reference_wrapper<const ExecutorImplementation<Attrs>>& impl) {
190
+ return impl.get ().requiresFallback (config);
191
+ });
192
+ }
193
+
184
194
/* *
185
195
* @brief Filters and retrieves suitable implementations based on the provided executor configuration.
186
196
*
@@ -249,18 +259,19 @@ class ExecutorFactory {
249
259
return std::distance (m_suitableImplementations.begin (), selectedImplementation);
250
260
}
251
261
252
- ExecutorPtr create (const ExecutorImplementation<Attrs>& impl ,
262
+ ExecutorPtr create (const size_t implId ,
253
263
const MemoryArgs& memory,
254
264
const ExecutorContext::CPtr context) {
265
+ assert (implId < m_executors.size ());
266
+ auto executor = m_executors[implId];
267
+ if (executor)
268
+ return executor;
269
+
270
+ assert (implId < m_suitableImplementations.size ());
271
+ const auto & impl = m_suitableImplementations[implId].get ();
255
272
DEBUG_LOG (" Creating executor using implementation: " , impl.name ());
256
- const auto & executorId = std::make_pair (impl.type (), impl.operationType ());
257
- auto factoryIt = m_executors.find (executorId);
258
- if (factoryIt == m_executors.end ()) {
259
- factoryIt =
260
- m_executors.insert (std::make_pair (executorId, impl.create (m_attrs, m_postOps, memory, context))).first ;
261
- }
262
273
263
- return factoryIt-> second ;
274
+ return impl. create (m_attrs, m_postOps, memory, context) ;
264
275
}
265
276
266
277
const Attrs& m_attrs;
@@ -269,7 +280,8 @@ class ExecutorFactory {
269
280
std::vector<std::reference_wrapper<const ExecutorImplementation<Attrs>>> m_suitableImplementations;
270
281
// stores fallback status to avoid performing the check for every make() call
271
282
std::vector<bool > m_implementationRequiresFallback;
272
- std::map<std::pair<ExecutorType, OperationType>, ExecutorPtr> m_executors;
283
+ // executors cache: one slot per entry of m_suitableImplementations, indexed by implementation id
284
+ std::vector<ExecutorPtr> m_executors;
273
285
};
274
286
275
287
template <typename Attrs, typename NodeT>
0 commit comments