Loading...
Searching...
No Matches
test_util.cuh
/*
 * Copyright (c) 2022-2025, NVIDIA CORPORATION.
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 * http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

#pragma once

#include <cuspatial/traits.hpp>

#include <rmm/device_uvector.hpp>

#include <thrust/for_each.h>
#include <thrust/host_vector.h>

#include <algorithm>
#include <cstdio>
#include <iomanip>
#include <iostream>
#include <stdexcept>
#include <string>
#include <string_view>
#include <type_traits>
30
namespace cuspatial {

namespace test {

/**
 * @brief Copy the contents of a device container into a new host vector.
 *
 * When `Vector` is exactly `rmm::device_uvector<T>`, the elements are copied with
 * `cudaMemcpyAsync` on the vector's own stream, and that stream is synchronized before
 * returning, so the result can be read immediately. Every other container type is
 * converted with `thrust::host_vector<T>`'s converting constructor.
 *
 * @tparam T      Element type of the returned host vector.
 * @tparam Vector Type of the source container.
 * @param dvec    The container to copy from.
 * @return A host vector holding a copy of every element of `dvec`.
 * @throws std::runtime_error if `cudaMemcpyAsync` reports an error when enqueuing the copy.
 */
template <typename T, typename Vector>
thrust::host_vector<T> to_host(Vector const& dvec)
{
  if constexpr (std::is_same_v<Vector, rmm::device_uvector<T>>) {
    thrust::host_vector<T> hvec(dvec.size());
    // Only issue a copy when there is something to copy, so the error check below can
    // never trip on a zero-byte request involving null data pointers.
    if (!hvec.empty()) {
      cudaError_t const status = cudaMemcpyAsync(hvec.data(),
                                                 dvec.data(),
                                                 dvec.size() * sizeof(T),
                                                 cudaMemcpyKind::cudaMemcpyDeviceToHost,
                                                 dvec.stream());
      // An error returned by this call (e.g. an invalid argument) must not be dropped:
      // otherwise the caller silently receives a vector that was never filled from the device.
      if (status != cudaSuccess) {
        throw std::runtime_error(
          std::string{"cuspatial::test::to_host: cudaMemcpyAsync failed: "} +
          cudaGetErrorString(status));
      }
    }
    // The copy is asynchronous; wait for it to finish before handing hvec back.
    dvec.stream().synchronize();
    return hvec;
  } else {
    return thrust::host_vector<T>(dvec);
  }
}

/**
 * @brief Copy the iterator range `[begin, end)` into a new host vector.
 *
 * @tparam Iter Iterator type of the range.
 * @tparam T    Element type of the returned vector; defaults to
 *              `cuspatial::iterator_value_type<Iter>`.
 * @param begin Iterator to the first element to copy.
 * @param end   Iterator one past the last element to copy.
 * @return A host vector holding a copy of every element in the range.
 */
template <typename Iter, typename T = cuspatial::iterator_value_type<Iter>>
thrust::host_vector<T> to_host(Iter begin, Iter end)
{
  return thrust::host_vector<T>(begin, end);
}

/**
 * @brief Write `pre`, then every element of `hvec` followed by a single space, then `post`,
 * to `std::cout`.
 *
 * Shared by `print_device_range` and `print_device_vector` so both produce identical output.
 *
 * @tparam HostVector A host-side container iterable with a range-based for loop.
 * @param hvec The elements to print.
 * @param pre  Text written before the first element.
 * @param post Text written after the last element.
 */
template <typename HostVector>
void print_host_elements(HostVector const& hvec, std::string_view pre, std::string_view post)
{
  std::cout << pre;
  for (auto const& x : hvec) {
    std::cout << x << " ";
  }
  std::cout << post;
}

/**
 * @brief Print the elements of the range `[begin, end)` to `std::cout`.
 *
 * The range is first copied to the host with `to_host(begin, end)`. Output is `pre`, each
 * element followed by a single space, then `post`.
 *
 * @tparam Iter Iterator type of the range.
 * @param begin Iterator to the first element to print.
 * @param end   Iterator one past the last element to print.
 * @param pre   Text written before the elements (default: empty).
 * @param post  Text written after the elements (default: a newline).
 */
template <typename Iter>
void print_device_range(Iter begin,
                        Iter end,
                        std::string_view pre  = "",
                        std::string_view post = "\n")
{
  auto hvec = to_host(begin, end);
  print_host_elements(hvec, pre, post);
}

/**
 * @brief Print the elements of a container to `std::cout`.
 *
 * The container is first copied to the host with `to_host<Vector::value_type>(vec)`.
 * Output is `pre`, each element followed by a single space, then `post`.
 *
 * @tparam Vector Container type; must expose a `value_type` member type.
 * @param vec  The container whose elements are printed.
 * @param pre  Text written before the elements (default: empty).
 * @param post Text written after the elements (default: a newline).
 */
template <typename Vector>
void print_device_vector(Vector const& vec, std::string_view pre = "", std::string_view post = "\n")
{
  using T   = typename Vector::value_type;
  auto hvec = to_host<T>(vec);
  print_host_elements(hvec, pre, post);
}

}  // namespace test
}  // namespace cuspatial