Skip to content
This repository was archived by the owner on Aug 30, 2024. It is now read-only.

Commit 667af4b

Browse files
sunjiweiswift authored and DDEle committed
opt store_xe.hpp
1 parent 99ba82e commit 667af4b

File tree

4 files changed

+90
-145
lines changed

4 files changed

+90
-145
lines changed

include/common/core/arch_config.hpp

+3
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ struct load_store_attr_t<msg_type::block_2d, gpu_arch::XeHpc> {
4747
// BlockWidth must be 1,2,4 for qwords and be in range [1..8] for dwords.
4848
static constexpr uint32_t max_trans_load_width_in_bytes = 32;
4949

50+
// BlockHeight must be 8 for qwords and be in range [1..32] for dwords.
51+
static constexpr uint32_t max_trans_load_height_in_elem = 32;
52+
5053
// If Transformed is true
5154
// BlockWidth must be in range [4..16] for bytes and [2..16] for words.
5255
static constexpr uint32_t max_vnni_load_width_in_elems = 16;

include/subgroup/tile/impl/load_xe.hpp

+8 -15
Original file line numberDiff line numberDiff line change
@@ -89,7 +89,6 @@ tile_load(tile_t& tile, payload_t& payload) {
8989

9090
static constexpr uint32_t num_block_x = tile_desc::num_block_x;
9191
static constexpr uint32_t num_block_y = tile_desc::num_block_y;
92-
// static constexpr uint32_t num_block = tile_desc::num_block;
9392

9493
static constexpr gpu_arch arch_tag = payload_t::arch_tag;
9594

@@ -181,19 +180,9 @@ tile_load(tile_t& tile, payload_t& payload) {
181180
for (uint32_t i = 0; i < num_block_y; ++i) {
182181
constexpr uint32_t load_block_elems = block_elems * arr_len;
183182
int offset_y = i * block_size_y;
184-
// auto payload_row =
185-
// payload_2d.xetla_select<num_block_x, 1, 16, 1>(i * num_block_x, 0);
186-
// detail::reset_tile_desc_core<
187-
// num_block_x,
188-
// block_size_x,
189-
// ld_blk_size_y,
190-
// scale_factor,
191-
// arr_len,
192-
// mem_transpose>(payload_row);
193183
#pragma unroll
194184
for (uint32_t j = 0; j < num_block_x; j += arr_len) {
195185
int32_t offset_x = j * block_size_x;
196-
// xetla_tdescriptor tdesc = payload_row.row(j);
197186
auto reg_blk = tile.reg.xetla_select<load_block_elems, 1>(
198187
(i * num_block_x + j) * block_elems);
199188
constexpr uint32_t ld_blk_height = (reg_transpose && trans)
@@ -215,7 +204,8 @@ tile_load(tile_t& tile, payload_t& payload) {
215204
mem_transform,
216205
L1,
217206
L2>(
218-
payload.base_ptr,
207+
reinterpret_cast<const native_type_t<load_dtype*>>(
208+
payload.base_ptr),
219209
payload.surface_width,
220210
payload.surface_height,
221211
payload.surface_pitch,
@@ -273,7 +263,8 @@ tile_load(tile_t& tile, payload_t& payload) {
273263
mem_transform,
274264
L1,
275265
L2>(
276-
payload.base_ptr,
266+
reinterpret_cast<const native_type_t<load_dtype*>>(
267+
payload.base_ptr),
277268
payload.surface_width,
278269
payload.surface_height,
279270
payload.surface_pitch,
@@ -335,7 +326,8 @@ tile_load(tile_t& tile, payload_t& payload) {
335326
mem_transform,
336327
L1,
337328
L2>(
338-
payload.base_ptr,
329+
reinterpret_cast<const native_type_t<load_dtype*>>(
330+
payload.base_ptr),
339331
payload.surface_width,
340332
payload.surface_height,
341333
payload.surface_pitch,
@@ -402,7 +394,8 @@ tile_load(tile_t& tile, payload_t& payload) {
402394
mem_transform,
403395
L1,
404396
L2>(
405-
payload.base_ptr,
397+
reinterpret_cast<const native_type_t<load_dtype*>>(
398+
payload.base_ptr),
406399
payload.surface_width,
407400
payload.surface_height,
408401
payload.surface_pitch,

include/subgroup/tile/impl/payload_xe.hpp

+26 -5
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,28 @@ struct mem_payload_t<
8585
using mem_dtype = typename std::
8686
conditional_t<mem_transpose_dtype_less4bytes, uint32_t, dtype>;
8787
static constexpr uint32_t scale_factor = sizeof(mem_dtype) / sizeof(dtype);
88-
mem_dtype* base_ptr;
88+
89+
using load_store_attr = load_store_attr_t<msg_type::block_2d, arch_tag>;
90+
91+
static constexpr uint32_t max_load_width_in_elem = trans
92+
? load_store_attr::max_trans_load_width_in_bytes / sizeof(dtype)
93+
: load_store_attr::max_load_width_in_bytes / sizeof(dtype);
94+
static constexpr uint32_t max_load_height_in_elem = trans
95+
? load_store_attr::max_trans_load_height_in_elem
96+
: load_store_attr::max_load_height_in_elem;
97+
98+
static constexpr uint32_t max_store_width_in_elem =
99+
load_store_attr::max_store_width_in_bytes / sizeof(dtype);
100+
static constexpr uint32_t max_store_height_in_elem =
101+
load_store_attr::max_store_height_in_elem;
102+
103+
static constexpr uint32_t elems_per_CL =
104+
load_store_attr::cache_line_size_in_bytes / sizeof(dtype);
105+
106+
static constexpr uint32_t elems_per_reg =
107+
register_bytes_t<arch_tag>::reg_in_bytes / sizeof(dtype);
108+
109+
dtype* base_ptr;
89110
uint32_t surface_width;
90111
uint32_t surface_height;
91112
uint32_t surface_pitch;
@@ -106,7 +127,7 @@ struct mem_payload_t<
106127
}
107128

108129
inline mem_payload_t(mem_desc_t& mem_desc) {
109-
this->base_ptr = (mem_dtype*)mem_desc.base.base;
130+
this->base_ptr = (dtype*)mem_desc.base.base;
110131
this->surface_width =
111132
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
112133
this->surface_height =
@@ -131,7 +152,7 @@ struct mem_payload_t<
131152
uint32_t surface_pitch,
132153
int32_t surface_offset_x = 0,
133154
int32_t surface_offset_y = 0) {
134-
this->base_ptr = (mem_dtype*)p;
155+
this->base_ptr = p;
135156
this->surface_width = surface_width * sizeof(dtype);
136157
this->surface_height = surface_height;
137158
this->surface_pitch = surface_pitch * sizeof(dtype);
@@ -152,7 +173,7 @@ struct mem_payload_t<
152173
}
153174

154175
__XETLA_API void init(mem_desc_t& mem_desc) {
155-
this->base_ptr = (mem_dtype*)mem_desc.base.base;
176+
this->base_ptr = (dtype*)mem_desc.base.base;
156177
this->surface_width =
157178
(mem_transpose ? mem_desc.shape.y : mem_desc.shape.x) * sizeof(dtype);
158179
this->surface_height =
@@ -185,7 +206,7 @@ struct mem_payload_t<
185206
uint32_t surface_pitch,
186207
int32_t surface_offset_x = 0,
187208
int32_t surface_offset_y = 0) {
188-
this->base_ptr = (mem_dtype*)p;
209+
this->base_ptr = p;
189210
this->surface_width = surface_width * sizeof(dtype);
190211
this->surface_height = surface_height;
191212
this->surface_pitch = surface_pitch * sizeof(dtype);

0 commit comments

Comments
 (0)