Skip to content

Commit

Permalink
Add test case for dpct::argmin and dpct::argmax
Browse files Browse the repository at this point in the history
Signed-off-by: Matthew Michel <matthew.michel@intel.com>
  • Loading branch information
mmichel11 committed Jun 4, 2024
1 parent 3aba14b commit d000025
Showing 1 changed file with 67 additions and 1 deletion.
68 changes: 67 additions & 1 deletion help_function/src/onedpl_test_reduce.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@

#include <iostream>
#include <vector>

#include <limits>



Expand Down Expand Up @@ -378,6 +378,72 @@ int main() {

}

// Testing calls to dpct::argmin and dpct::argmax functors with unique and equivalent values
{
auto queue = dpct::get_default_queue();
test_name = "oneapi::dpl::reduce with dpct::argmin functor - All values are unique";
std::size_t n = 10;
sycl::buffer<dpct::key_value_pair<int, float>> input(n);
{
auto host_acc = input.get_host_access();
for (std::size_t i = 0; i < n; ++i)
host_acc[i].key = i + 10, host_acc[i].value = i + 20;
// inject min and max
host_acc[4].key = 9;
host_acc[4].value = 8;
host_acc[6].key = 101;
host_acc[6].value = 99;
}
auto argmin_res = oneapi::dpl::reduce(dpl::execution::make_device_policy(queue),
oneapi::dpl::begin(input),
oneapi::dpl::end(input),
dpct::key_value_pair<int, float>(std::numeric_limits<int>::max(),
std::numeric_limits<float>::max()),
dpct::argmin());

failed_tests += ASSERT_EQUAL(test_name, argmin_res.key, 9);
failed_tests += ASSERT_EQUAL(test_name, argmin_res.value, 8);

test_name = "oneapi::dpl::reduce with dpct::argmax functor - All values are unique";
auto argmax_res = oneapi::dpl::reduce(dpl::execution::make_device_policy(queue),
oneapi::dpl::begin(input),
oneapi::dpl::end(input),
dpct::key_value_pair<int, float>(std::numeric_limits<int>::min(),
std::numeric_limits<float>::min()),
dpct::argmax());

failed_tests += ASSERT_EQUAL(test_name, argmax_res.key, 101);
failed_tests += ASSERT_EQUAL(test_name, argmax_res.value, 99);

test_name = "oneapi::dpl::reduce with dpct::argmin functor - All values are the same";
{
auto host_acc = input.get_host_access();
for (std::size_t i = 0; i < n; ++i)
host_acc[i].key = i + 30, host_acc[i].value = 2;
}
// Expect the key_value_pair with the lower key to be returned when value compares equal
argmin_res = oneapi::dpl::reduce(dpl::execution::make_device_policy(queue),
oneapi::dpl::begin(input),
oneapi::dpl::end(input),
dpct::key_value_pair<int, float>(std::numeric_limits<int>::max(),
std::numeric_limits<float>::max()),
dpct::argmin());

failed_tests += ASSERT_EQUAL(test_name, argmin_res.key, 30);
failed_tests += ASSERT_EQUAL(test_name, argmin_res.value, 2);

argmax_res = oneapi::dpl::reduce(dpl::execution::make_device_policy(queue),
oneapi::dpl::begin(input),
oneapi::dpl::end(input),
dpct::key_value_pair<int, float>(std::numeric_limits<int>::min(),
std::numeric_limits<float>::min()),
dpct::argmax());

// Expect the key_value_pair with the lower key to be returned when value compares equal
failed_tests += ASSERT_EQUAL(test_name, argmax_res.key, 30);
failed_tests += ASSERT_EQUAL(test_name, argmax_res.value, 2);
}

std::cout << std::endl << failed_tests << " failing test(s) detected." << std::endl;
if (failed_tests == 0) {
return 0;
Expand Down

0 comments on commit d000025

Please sign in to comment.