Skip to content

Commit

Permalink
improvements for device_vector
Browse files Browse the repository at this point in the history
  • Loading branch information
danhoeflinger committed Jul 8, 2024
1 parent c3fde81 commit c25deda
Showing 1 changed file with 5 additions and 23 deletions.
28 changes: 5 additions & 23 deletions help_function/src/onedpl_test_sort_by_key.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -201,40 +201,22 @@ int main() {

{
// Test Two, test calls to dpct::sort using device vectors
dpct::device_vector<int> keys_vec(10);
dpct::device_vector<int> values_vec(10);

std::vector<int> keys_data{4, 8, 5, 3, 0, 9, 7, 2, 1, 6};
std::vector<int> values_data{13, 16, 17, 11, 19, 14, 12, 18, 10, 15};

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(keys_vec.data(), keys_data.data(), 10 * sizeof(int));
});

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(values_vec.data(), values_data.data(), 10 * sizeof(int));
});
dpct::get_default_queue().wait();
dpct::device_vector<int> keys_vec(keys_data);
dpct::device_vector<int> values_vec(values_data);

auto keys_it = keys_vec.begin();
auto keys_it_end = keys_vec.end();
auto values_it = values_vec.begin();
{
// call algorithm
dpct::sort(oneapi::dpl::execution::make_device_policy<>(dpct::get_default_queue()), keys_it, keys_it_end, values_it);
dpct::sort(oneapi::dpl::execution::dpcpp_default, keys_it, keys_it_end, values_it);
// keys is now = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}
// values is now = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14}
}

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(keys_data.data(), keys_vec.data(), 10 * sizeof(int));
});

dpct::get_default_queue().submit([&](sycl::handler& h) {
h.memcpy(values_data.data(), values_vec.data(), 10 * sizeof(int));
});
dpct::get_default_queue().wait();

{
int check_keys[10] = {0, 1, 2, 3, 4, 5, 6, 7, 8, 9};
int check_values[10] = {19, 10, 18, 11, 13, 17, 15, 12, 16, 14};
Expand All @@ -243,8 +225,8 @@ int main() {
// check that values and keys are correct

for (int i = 0; i != 10; ++i) {
num_failing += ASSERT_EQUAL(test_name, values_data[i], check_values[i]);
num_failing += ASSERT_EQUAL(test_name, keys_data[i], check_keys[i]);
num_failing += ASSERT_EQUAL(test_name, values_vec[i], check_values[i]);
num_failing += ASSERT_EQUAL(test_name, keys_vec[i], check_keys[i]);
}

failed_tests += test_passed(num_failing, test_name);
Expand Down

0 comments on commit c25deda

Please sign in to comment.