Skip to content

Commit fabfdee

Browse files
echeresh authored and dzarukin committed
xe: conv: improve persistent cache creation time
1 parent 7cdf520 commit fabfdee

File tree

4 files changed

+33
-3
lines changed

4 files changed

+33
-3
lines changed

src/gpu/intel/jit/conv/gen_convolution.cpp

+6-3
Original file line number | Diff line number | Diff line change
@@ -179,6 +179,11 @@ class gen_convolution_t {
179179
conv_config_t cfg;
180180
layout_t zp_dst;
181181
if (data.zp_pd) zp_dst = layout_t(data.zp_pd->impl()->dst_md(), false);
182+
183+
if (primitive->cache_blob()) {
184+
tiler->set_cur_index(primitive->version() - 1);
185+
}
186+
182187
for (int try_iter = 0; try_iter < max_tries; try_iter++) {
183188
try {
184189
cfg = data.pd_cfg;
@@ -187,8 +192,6 @@ class gen_convolution_t {
187192
cfg.set_tiler(tiler);
188193
CHECK(init_cfg(cfg, primitive));
189194

190-
if (primitive->cache_blob() && try_iter != primitive->version())
191-
continue;
192195
if (!tiler->is_grf_limit_ok(cfg)) continue;
193196

194197
ir_info() << "Configuration:" << std::endl;
@@ -256,7 +259,7 @@ class gen_convolution_t {
256259
if (!tmp_kernels[i]) return status::runtime_error;
257260
}
258261
ok = true;
259-
primitive->set_version(try_iter);
262+
primitive->set_version(tiler->cur_index());
260263
kernels_ = std::move(tmp_kernels);
261264
break;
262265
} catch (ngen::out_of_registers_exception &err) {

src/gpu/intel/jit/conv/tiler.cpp

+20
Original file line number | Diff line number | Diff line change
@@ -1262,6 +1262,8 @@ class conv_tuner_t {
12621262
params_gen_.move_next();
12631263
}
12641264

1265+
// Returns the current index reported by the tuner's parameter generator.
int cur_index() const {
    return params_gen_.cur_index();
}
1266+
12651267
// Forwards to the parameter generator's print_all().
void print_all() const {
    params_gen_.print_all();
}
12661268

12671269
static const primitive_info_t &get_primitive_info(
@@ -1448,6 +1450,16 @@ class conv_tiler_impl_t {
14481450
return params_gen_.can_move_next();
14491451
}
14501452

1453+
int cur_index() const {
1454+
if (is_tuning_mode()) return tuner_->cur_index();
1455+
return params_gen_.cur_index();
1456+
}
1457+
1458+
void set_cur_index(int idx) {
1459+
ir_assert(!is_tuning_mode());
1460+
return params_gen_.set_cur_index(idx);
1461+
}
1462+
14511463
void set_params(conv_config_t &cfg) {
14521464
init_regs(cfg);
14531465
if (is_tuning_mode()) {
@@ -1595,6 +1607,14 @@ bool conv_tiler_t::can_move_next() const {
15951607
return impl_->can_move_next();
15961608
}
15971609

1610+
int conv_tiler_t::cur_index() const {
1611+
return impl_->cur_index();
1612+
}
1613+
1614+
// Forwards `idx` to the implementation object's set_cur_index().
void conv_tiler_t::set_cur_index(int idx) {
    impl_->set_cur_index(idx);
}
1617+
15981618
// Forwards `cfg` to the implementation object's set_params(); `cfg` is passed
// by non-const reference.
void conv_tiler_t::set_params(conv_config_t &cfg) {
    impl_->set_params(cfg);
}

src/gpu/intel/jit/conv/tiler.hpp

+2
Original file line number | Diff line number | Diff line change
@@ -42,6 +42,8 @@ class conv_tiler_t {
4242
int configs() const;
4343
bool is_tuning_mode() const;
4444
bool can_move_next() const;
45+
int cur_index() const;
46+
void set_cur_index(int idx);
4547
void set_params(conv_config_t &cfg);
4648
void notify_out_of_registers(const conv_config_t &cfg);
4749
bool is_grf_limit_ok(const conv_config_t &cfg) const;

src/gpu/intel/jit/ir/blocking.hpp

+5
Original file line number | Diff line number | Diff line change
@@ -557,6 +557,11 @@ class params_generator_t {
557557

558558
// Returns the stored current index (cur_idx_).
int cur_index() const {
    return cur_idx_;
}
559559

560+
// Overwrites the stored current index with `idx`.
//
// Only the upper bound is asserted (idx < configs()).
// NOTE(review): negative values are not rejected here - confirm whether a
// negative index (e.g. a "before the first config" position) is intended to
// be accepted before tightening this check.
void set_cur_index(int idx) {
    const int total = configs();
    ir_assert(idx < total);
    cur_idx_ = idx;
}
564+
560565
// Returns the blocking parameters at the stored current index, looked up via
// at().
const blocking_params_t &cur_params() const {
    const blocking_params_t &params = at(cur_idx_);
    return params;
}
561566

562567
const blocking_params_t &at(int idx) const {

0 commit comments

Comments
 (0)