Skip to content

Commit a595116

Browse files
committed
graph: interface: fix memory size for u4/s4
1 parent 6efa272 commit a595116

File tree

3 files changed

+18
-3
lines changed

3 files changed

+18
-3
lines changed

src/graph/interface/logical_tensor.cpp

+4-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2020-2023 Intel Corporation
2+
* Copyright 2020-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -38,7 +38,9 @@ size_t logical_tensor_wrapper_t::size() const {
3838
static_cast<size_t>(strided_pdim * effective_stride));
3939
}
4040

41-
return max_size * data_type_size();
41+
size_t data_size = utils::div_up(
42+
max_size * data_type_size(), sub_byte_data_type_multiplier());
43+
return data_size;
4244
} else if (is_opaque()) {
4345
size_t layout_id = lt->layout.layout_id;
4446
auto backend

src/graph/interface/logical_tensor.hpp

+8-1
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
/*******************************************************************************
2-
* Copyright 2020-2023 Intel Corporation
2+
* Copyright 2020-2024 Intel Corporation
33
*
44
* Licensed under the Apache License, Version 2.0 (the "License");
55
* you may not use this file except in compliance with the License.
@@ -154,6 +154,13 @@ struct logical_tensor_wrapper_t {
154154
/* check_dtype = */ true);
155155
}
156156

157+
/** For the sub-byte data types (s4 and u4) returns the number of elements per byte, i.e. 2.
158+
* For all other data types returns 1, so dividing by the result is a no-op. */
159+
size_t sub_byte_data_type_multiplier() const {
160+
if (utils::one_of(data_type(), data_type::s4, data_type::u4)) return 2;
161+
return 1;
162+
}
163+
157164
// Returns the size of this tensor's data type, as reported by types::data_type_size().
158165
size_t data_type_size() const { return types::data_type_size(data_type()); }
159166

tests/gtests/graph/api/test_cpp_api_logical_tensor.cpp

+6
Original file line numberDiff line numberDiff line change
@@ -311,4 +311,10 @@ TEST(APILogicalTensor, LogicalTensorSize) {
311311
ASSERT_EQ(lt_3.get_id(), id);
312312
ASSERT_EQ(lt_3.get_data_type(), data_type::s8);
313313
ASSERT_EQ(lt_3.get_mem_size(), num_elem * sizeof(int8_t));
314+
315+
logical_tensor lt_4 {id, data_type::s4, shape, layout_type::strided};
316+
ASSERT_EQ(lt_4.get_id(), id);
317+
ASSERT_EQ(lt_4.get_data_type(), data_type::s4);
318+
// Two s4 elements share one byte; round up in case num_elem is odd.
319+
ASSERT_EQ(lt_4.get_mem_size(), (num_elem + 1) / 2);
314320
}

0 commit comments

Comments
 (0)