Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 1bd0290

Browse files
sunjiweiswiftDDEle
authored andcommitted
opt load_xe
1 parent 667af4b commit 1bd0290

File tree

3 files changed

+43
-41
lines changed

3 files changed

+43
-41
lines changed

include/subgroup/tile/impl/load_xe.hpp

+25-12
Original file line numberDiff line numberDiff line change
@@ -106,19 +106,31 @@ tile_load(tile_t& tile, payload_t& payload) {
106106
static constexpr bool mem_transform = payload_t::mem_transform;
107107

108108
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
109+
110+
// static constexpr uint32_t max_load_width_in_elem = trans
111+
// ? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
112+
// : load_store_attr::max_load_width_in_bytes / sizeof(dtype);
113+
// static constexpr uint32_t max_load_height_in_elem = trans
114+
// ? load_store_attr::max_trans_load_height_in_elem
115+
// : load_store_attr::max_load_height_in_elem;
116+
static constexpr uint32_t max_trans_load_width_in_elem =
117+
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
118+
static constexpr uint32_t max_load_width_in_elem =
119+
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
120+
121+
// static constexpr uint32_t max_trans_load_height_in_elem =
122+
// load_store_attr::max_trans_load_height_in_elem;
123+
static constexpr uint32_t max_load_height_in_elem =
124+
load_store_attr::max_load_height_in_elem;
125+
109126
static constexpr uint32_t elems_per_CL =
110127
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
128+
111129
static constexpr uint32_t elems_per_reg =
112130
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
113-
static constexpr int32_t max_load_block_height =
114-
load_store_attr::max_load_height_in_elem;
115-
static constexpr int32_t max_block_width =
116-
load_store_attr::max_load_width_in_bytes / sizeof(dtype);
117-
static constexpr int32_t max_trans_block_width =
118-
load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype);
119131

120132
static constexpr uint32_t ld_blk_size_y_limit =
121-
mem_transpose ? max_trans_block_width : max_load_block_height;
133+
mem_transpose ? max_trans_load_width_in_elem : max_load_height_in_elem;
122134
static constexpr uint32_t ld_blk_size_y = reg_transpose
123135
? block_size_y
124136
: std::min(ld_blk_size_y_limit, block_size_y);
@@ -150,20 +162,21 @@ tile_load(tile_t& tile, payload_t& payload) {
150162

151163
static_assert(
152164
reg_transpose || mem_transpose ||
153-
(!mem_transpose && (block_size_x * arr_len) <= max_block_width),
165+
(!mem_transpose &&
166+
(block_size_x * arr_len) <= max_load_width_in_elem),
154167
"When reg_transpose was disabled, check 2d block width "
155168
"restriction");
156169
static_assert(
157170
!reg_transpose ||
158171
(!mem_transpose &&
159-
(block_size_x * arr_len) <= max_trans_block_width) ||
160-
(mem_transpose && (block_size_y * arr_len) <= max_block_width),
172+
(block_size_x * arr_len) <= max_trans_load_width_in_elem) ||
173+
(mem_transpose && (block_size_y * arr_len) <= max_load_width_in_elem),
161174
"When reg_transpose was enabled, check 2d block width "
162175
"restriction");
163176
static_assert(
164177
!reg_transpose ||
165-
(!mem_transpose && (block_size_y <= max_load_block_height)) ||
166-
(mem_transpose && (block_size_x) <= max_load_block_height),
178+
(!mem_transpose && (block_size_y <= max_load_height_in_elem)) ||
179+
(mem_transpose && (block_size_x) <= max_load_height_in_elem),
167180
"When reg_transpose was enabled, check 2d block height "
168181
"restriction");
169182
static_assert(

include/subgroup/tile/impl/payload_xe.hpp

-20
Original file line numberDiff line numberDiff line change
@@ -86,26 +86,6 @@ struct mem_payload_t<
8686
conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
8787
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
8888

89-
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
90-
91-
static constexpr uint32_t max_load_width_in_elem = trans
92-
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
93-
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
94-
static constexpr uint32_t max_load_height_in_elem = trans
95-
? load_store_attr::max_trans_load_height_in_elem
96-
: load_store_attr::max_load_height_in_elem;
97-
98-
static constexpr uint32_t max_store_width_in_elem =
99-
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
100-
static constexpr uint32_t max_store_height_in_elem =
101-
load_store_attr::max_store_height_in_elem;
102-
103-
static constexpr uint32_t elems_per_CL =
104-
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
105-
106-
static constexpr uint32_t elems_per_reg =
107-
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
108-
10989
dtype* base_ptr;
11090
uint32_t surface_width;
11191
uint32_t surface_height;

include/subgroup/tile/impl/store_xe.hpp

+18-9
Original file line numberDiff line numberDiff line change
@@ -99,19 +99,28 @@ tile_store(tile_t& tile, payload_t& payload) {
9999
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
100100
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
101101

102+
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
103+
104+
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
105+
static constexpr uint32_t max_store_width_in_elem =
106+
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
107+
static constexpr uint32_t max_store_height_in_elem =
108+
load_store_attr::max_store_height_in_elem;
109+
110+
static constexpr uint32_t elems_per_CL =
111+
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
112+
102113
static_assert(
103-
(payload_t::max_store_width_in_elem % block_size_x) == 0,
114+
(max_store_width_in_elem % block_size_x) == 0,
104115
"max_store_width_in_elem should be a multiply of block_size_x.");
105116

106117
static constexpr uint32_t st_blk_size_y =
107-
std::min(block_size_y, payload_t::max_store_height_in_elem);
118+
std::min(block_size_y, max_store_height_in_elem);
108119

109120
// to make sure full CL store
110-
static constexpr uint32_t st_blk_size_x =
111-
((tile_size_x % payload_t::elems_per_CL) == 0)
112-
? payload_t::elems_per_CL
113-
: (((payload_t::elems_per_CL % tile_size_x) == 0) ? tile_size_x
114-
: block_size_x);
121+
static constexpr uint32_t st_blk_size_x = ((tile_size_x % elems_per_CL) == 0)
122+
? elems_per_CL
123+
: (((elems_per_CL % tile_size_x) == 0) ? tile_size_x : block_size_x);
115124

116125
static constexpr uint8_t arr_len_candidate = st_blk_size_x / block_size_x;
117126
static constexpr bool is_valid_arr_len_candidate = (arr_len_candidate == 1) ||
@@ -120,14 +129,13 @@ tile_store(tile_t& tile, payload_t& payload) {
120129
static constexpr uint8_t arr_len =
121130
is_valid_arr_len_candidate ? arr_len_candidate : 1;
122131

123-
constexpr uint32_t store_block_elems = block_elems * arr_len;
124-
constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
125132
#pragma unroll
126133
for (uint32_t i = 0; i < num_block_y; ++i) {
127134
int32_t offset_y = i * block_size_y;
128135
#pragma unroll
129136
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
130137
int32_t offset_x = j * block_size_x;
138+
constexpr uint32_t store_block_elems = block_elems * arr_len;
131139
auto reg_blk = tile.reg.xetla_select<store_block_elems, 1>(
132140
(i * num_block_x + j) * block_elems);
133141
xetla_vector<dtype, store_block_elems> combine_blk;
@@ -150,6 +158,7 @@ tile_store(tile_t& tile, payload_t& payload) {
150158
}
151159
#pragma unroll
152160
for (uint32_t ii = 0; ii < block_size_y; ii += st_blk_size_y) {
161+
constexpr uint32_t store_elems = st_blk_size_y * st_blk_size_x;
153162
auto st_blk =
154163
combine_blk.xetla_select<store_elems, 1>(ii * st_blk_size_x);
155164
xetla_store_global<dtype, st_blk_size_x, st_blk_size_y, L1, L2>(

0 commit comments

Comments
 (0)