aboutsummaryrefslogtreecommitdiff
diff options
context:
space:
mode:
authorMikhail Dvorskiy <mikhail.dvorskiy@intel.com>2021-03-01 12:09:42 +0300
committerMikhail Dvorskiy <mikhail.dvorskiy@intel.com>2021-03-04 14:55:20 +0300
commitc5d827dd22376f425eb979eac50ee3f68a2e59fd (patch)
treed79f4716584a8e85a277625ea9e0382299f1127a
parentMerge branch 'main' into dev/mdvorski/range_api_sycl_buf (diff)
downloadllvm-project-c5d827dd22376f425eb979eac50ee3f68a2e59fd.tar.gz
llvm-project-c5d827dd22376f425eb979eac50ee3f68a2e59fd.tar.bz2
llvm-project-c5d827dd22376f425eb979eac50ee3f68a2e59fd.zip
[dpc++][ranges] + range API for any_of, all_of, none_of, adjacent_find, count
-rw-r--r--include/oneapi/dpl/pstl/glue_algorithm_ranges_defs.h34
-rw-r--r--include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h60
-rw-r--r--include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h10
-rw-r--r--include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h21
4 files changed, 121 insertions, 4 deletions
diff --git a/include/oneapi/dpl/pstl/glue_algorithm_ranges_defs.h b/include/oneapi/dpl/pstl/glue_algorithm_ranges_defs.h
index 3739e51fad27..163602d0eea7 100644
--- a/include/oneapi/dpl/pstl/glue_algorithm_ranges_defs.h
+++ b/include/oneapi/dpl/pstl/glue_algorithm_ranges_defs.h
@@ -27,6 +27,24 @@ namespace experimental
namespace ranges
{
+// [alg.any_of]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, bool>
+any_of(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __pred);
+
+// [alg.all_of]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, bool>
+all_of(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __pred);
+
+// [alg.none_of]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, bool>
+none_of(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __pred);
+
// [alg.foreach]
template <typename _ExecutionPolicy, typename _Range, typename _Function>
@@ -71,6 +89,22 @@ oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy,
oneapi::dpl::__internal::__difference_t<_Range1>>
find_first_of(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2);
+// [alg.adjacent_find]
+
+template <typename _ExecutionPolicy, typename _Range>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+adjacent_find(_ExecutionPolicy&& __exec, _Range&& __rng);
+
+template <typename _ExecutionPolicy, typename _Range, typename _BinaryPredicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+adjacent_find(_ExecutionPolicy&& __exec, _Range&& __rng, _BinaryPredicate __pred);
+
+// [alg.count]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Tp>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+count(_ExecutionPolicy&& __exec, _Range&& __rng, const _Tp& __value);
+
// [alg.search]
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _BinaryPredicate>
diff --git a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
index 46ca4a9454f8..4cb5991eb772 100644
--- a/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
+++ b/include/oneapi/dpl/pstl/glue_algorithm_ranges_impl.h
@@ -33,6 +33,36 @@ namespace experimental
namespace ranges
{
+// [alg.any_of]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, bool>
+any_of(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __pred)
+{
+ return oneapi::dpl::__internal::__ranges::__pattern_any_of(::std::forward<_ExecutionPolicy>(__exec),
+ views::all(::std::forward<_Range>(__rng)), __pred);
+}
+
+// [alg.all_of]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, bool>
+all_of(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __pred)
+{
+ return !any_of(
+ ::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng),
+ oneapi::dpl::__internal::__not_pred<oneapi::dpl::__internal::__ref_or_copy<_ExecutionPolicy, _Pred>>(__pred));
+}
+
+// [alg.none_of]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, bool>
+none_of(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __pred)
+{
+ return !any_of(::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), __pred);
+}
+
// [alg.foreach]
template <typename _ExecutionPolicy, typename _Range, typename _Function>
@@ -115,6 +145,36 @@ find_first_of(_ExecutionPolicy&& __exec, _Range1&& __rng1, _Range2&& __rng2)
::std::forward<_Range2>(__rng2), oneapi::dpl::__internal::__pstl_equal());
}
+// [alg.adjacent_find]
+
+template <typename _ExecutionPolicy, typename _Range, typename _BinaryPredicate>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+adjacent_find(_ExecutionPolicy&& __exec, _Range&& __rng, _BinaryPredicate __pred)
+{
+ return oneapi::dpl::__internal::__ranges::__pattern_adjacent_find(
+ ::std::forward<_ExecutionPolicy>(__exec), views::all_read(::std::forward<_Range>(__rng)),
+ __pred, oneapi::dpl::__internal::__first_semantic());
+}
+
+template <typename _ExecutionPolicy, typename _Range>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+adjacent_find(_ExecutionPolicy&& __exec, _Range&& __rng)
+{
+ using _ValueType = oneapi::dpl::__internal::__value_t<_Range>;
+ return adjacent_find(::std::forward<_ExecutionPolicy>(__exec), ::std::forward<_Range>(__rng), ::std::equal_to<_ValueType>());
+}
+
+// [alg.count]
+
+template <typename _ExecutionPolicy, typename _Range, typename _Tp>
+oneapi::dpl::__internal::__enable_if_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+count(_ExecutionPolicy&& __exec, _Range&& __rng, const _Tp& __value)
+{
+ return oneapi::dpl::__internal::__ranges::__pattern_count(::std::forward<_ExecutionPolicy>(__exec),
+ views::all_read(::std::forward<_Range>(__rng)),
+ oneapi::dpl::__internal::__equal_value<oneapi::dpl::__internal::__ref_or_copy<_ExecutionPolicy, const _Tp>>(__value));
+}
+
// [alg.search]
template <typename _ExecutionPolicy, typename _Range1, typename _Range2, typename _BinaryPredicate>
diff --git a/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
index eb41600c2ad6..302a2803fe0e 100644
--- a/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
+++ b/include/oneapi/dpl/pstl/hetero/algorithm_impl_hetero.h
@@ -640,10 +640,12 @@ __pattern_any_of(_ExecutionPolicy&& __exec, _Iterator __first, _Iterator __last,
using _Predicate = oneapi::dpl::unseq_backend::single_match_pred<_ExecutionPolicy, _Pred>;
- return __par_backend_hetero::__parallel_or(
- ::std::forward<_ExecutionPolicy>(__exec),
- __par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::read>(__first),
- __par_backend_hetero::make_iter_mode<__par_backend_hetero::access_mode::read>(__last), _Predicate{__pred});
+ auto __keep = oneapi::dpl::__ranges::__get_sycl_range<__par_backend_hetero::access_mode::read, _Iterator>();
+ auto __buf = __keep(__first, __last);
+
+ return oneapi::dpl::__par_backend_hetero::__parallel_find_or(
+ __par_backend_hetero::make_wrapped_policy<__or_policy_wrapper>(::std::forward<_ExecutionPolicy>(__exec)),
+ _Predicate{__pred}, __parallel_or_tag{}, __buf.all_view());
}
//------------------------------------------------------------------------
diff --git a/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h b/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
index 8ecde3d294aa..7f2214ecb1b5 100644
--- a/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
+++ b/include/oneapi/dpl/pstl/hetero/algorithm_ranges_impl_hetero.h
@@ -294,6 +294,27 @@ __pattern_adjacent_find(_ExecutionPolicy&& __exec, _Range&& __rng, _BinaryPredic
return return_value(result, __rng.size(), __is__or_semantic);
}
+template <typename _ExecutionPolicy, typename _Range, typename _Predicate>
+oneapi::dpl::__internal::__enable_if_hetero_execution_policy<_ExecutionPolicy, oneapi::dpl::__internal::__difference_t<_Range>>
+__pattern_count(_ExecutionPolicy&& __exec, _Range&& __rng, _Predicate __predicate)
+{
+ if (__rng.size() == 0)
+ return 0;
+
+ using _ReduceValueType = oneapi::dpl::__internal::__difference_t<_Range>;
+
+ auto __identity_init_fn = acc_handler_count<_Predicate>{__predicate};
+ auto __identity_reduce_fn = ::std::plus<_ReduceValueType>{};
+
+ return oneapi::dpl::__par_backend_hetero::__parallel_transform_reduce<_ReduceValueType>(
+ ::std::forward<_ExecutionPolicy>(__exec),
+ unseq_backend::transform_init<_ExecutionPolicy, decltype(__identity_reduce_fn), decltype(__identity_init_fn)>{
+ __identity_reduce_fn, __identity_init_fn},
+ __identity_reduce_fn,
+ unseq_backend::reduce<_ExecutionPolicy, decltype(__identity_reduce_fn), _ReduceValueType>{__identity_reduce_fn},
+ ::std::forward<_Range>(__rng));
+}
+
//------------------------------------------------------------------------
// merge
//------------------------------------------------------------------------