Skip to content

Commit 3b96257

Browse files
[GPU] Reinterpret 1-dim memory as 0-dim memory instead of requesting a 0-byte layout allocation from OpenCL
1 parent bc505ba commit 3b96257

File tree

1 file changed

+26
-5
lines changed

1 file changed

+26
-5
lines changed

src/plugins/intel_gpu/src/runtime/ocl/ocl_engine.cpp

+26-5
Original file line number | Diff line number | Diff line change
@@ -172,25 +172,46 @@ bool ocl_engine::check_allocatable(const layout& layout, allocation_type type) {
172172
memory::ptr ocl_engine::allocate_memory(const layout& layout, allocation_type type, bool reset) {
173173
OPENVINO_ASSERT(!layout.is_dynamic() || layout.has_upper_bound(), "[GPU] Can't allocate memory for dynamic layout");
174174

175-
check_allocatable(layout, type);
175+
if (!check_allocatable(layout, type))
176+
return nullptr;
177+
178+
auto zero_bytes_layout = false;
179+
auto non_zero_layout = layout;
180+
if (layout.bytes_count() == 0) {
181+
cldnn::layout zero_dim_layout = layout;
182+
auto mem_ps = zero_dim_layout.get_partial_shape();
183+
for (size_t k = 0; k < mem_ps.size(); k++) {
184+
if (mem_ps[k] == 0)
185+
mem_ps[k] = 1;
186+
}
187+
188+
non_zero_layout = cldnn::layout(mem_ps, zero_dim_layout.data_type, zero_dim_layout.format);
189+
zero_bytes_layout = true;
190+
}
191+
192+
176193

177194
try {
178195
memory::ptr res = nullptr;
179196
if (layout.format.is_image_2d()) {
180-
res = std::make_shared<ocl::gpu_image2d>(this, layout);
197+
res = std::make_shared<ocl::gpu_image2d>(this, non_zero_layout);
181198
} else if (type == allocation_type::cl_mem) {
182-
res = std::make_shared<ocl::gpu_buffer>(this, layout);
199+
res = std::make_shared<ocl::gpu_buffer>(this, non_zero_layout);
183200
} else {
184-
res = std::make_shared<ocl::gpu_usm>(this, layout, type);
201+
res = std::make_shared<ocl::gpu_usm>(this, non_zero_layout, type);
185202
}
186203

187-
if (reset || res->is_memory_reset_needed(layout)) {
204+
if (reset || res->is_memory_reset_needed(non_zero_layout)) {
188205
auto ev = res->fill(get_service_stream());
189206
if (ev) {
190207
get_service_stream().wait_for_events({ev});
191208
}
192209
}
193210

211+
if (zero_bytes_layout) {
212+
res = reinterpret_buffer(*res, layout);
213+
}
214+
194215
return res;
195216
} catch (const cl::Error& clErr) {
196217
switch (clErr.err()) {

0 commit comments

Comments
 (0)