cublas_wrapper.hpp
Go to the documentation of this file.
1/*
2 Copyright 2022-2023 SINTEF AS
3
4 This file is part of the Open Porous Media project (OPM).
5
6 OPM is free software: you can redistribute it and/or modify
7 it under the terms of the GNU General Public License as published by
8 the Free Software Foundation, either version 3 of the License, or
9 (at your option) any later version.
10
11 OPM is distributed in the hope that it will be useful,
12 but WITHOUT ANY WARRANTY; without even the implied warranty of
13 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14 GNU General Public License for more details.
15
16 You should have received a copy of the GNU General Public License
17 along with OPM. If not, see <http://www.gnu.org/licenses/>.
18*/
19
27#ifndef OPM_CUBLASWRAPPER_HEADER_INCLUDED
28#define OPM_CUBLASWRAPPER_HEADER_INCLUDED
#include <stdexcept>

#include <cublas_v2.h>
#include <opm/common/ErrorMacros.hpp>
31
32namespace Opm::cuistl::detail
33{
34
35inline cublasStatus_t
36cublasScal(cublasHandle_t handle,
37 int n,
38 const double* alpha, /* host or device pointer */
39 double* x,
40 int incx)
41{
42 return cublasDscal(handle,
43 n,
44 alpha, /* host or device pointer */
45 x,
46 incx);
47}
48
49inline cublasStatus_t
50cublasScal(cublasHandle_t handle,
51 int n,
52 const float* alpha, /* host or device pointer */
53 float* x,
54 int incx)
55{
56 return cublasSscal(handle,
57 n,
58 alpha, /* host or device pointer */
59 x,
60 incx);
61}
62
63inline cublasStatus_t
64cublasScal([[maybe_unused]] cublasHandle_t handle,
65 [[maybe_unused]] int n,
66 [[maybe_unused]] const int* alpha, /* host or device pointer */
67 [[maybe_unused]] int* x,
68 [[maybe_unused]] int incx)
69{
70 OPM_THROW(std::runtime_error, "cublasScal multiplication for integer vectors is not implemented yet.");
71}
72inline cublasStatus_t
73cublasAxpy(cublasHandle_t handle,
74 int n,
75 const double* alpha, /* host or device pointer */
76 const double* x,
77 int incx,
78 double* y,
79 int incy)
80{
81 return cublasDaxpy(handle,
82 n,
83 alpha, /* host or device pointer */
84 x,
85 incx,
86 y,
87 incy);
88}
89
90inline cublasStatus_t
91cublasAxpy(cublasHandle_t handle,
92 int n,
93 const float* alpha, /* host or device pointer */
94 const float* x,
95 int incx,
96 float* y,
97 int incy)
98{
99 return cublasSaxpy(handle,
100 n,
101 alpha, /* host or device pointer */
102 x,
103 incx,
104 y,
105 incy);
106}
107
108inline cublasStatus_t
109cublasAxpy([[maybe_unused]] cublasHandle_t handle,
110 [[maybe_unused]] int n,
111 [[maybe_unused]] const int* alpha, /* host or device pointer */
112 [[maybe_unused]] const int* x,
113 [[maybe_unused]] int incx,
114 [[maybe_unused]] int* y,
115 [[maybe_unused]] int incy)
116{
117 OPM_THROW(std::runtime_error, "axpy multiplication for integer vectors is not implemented yet.");
118}
119
120inline cublasStatus_t
121cublasDot(cublasHandle_t handle, int n, const double* x, int incx, const double* y, int incy, double* result)
122{
123 return cublasDdot(handle, n, x, incx, y, incy, result);
124}
125
126inline cublasStatus_t
127cublasDot(cublasHandle_t handle, int n, const float* x, int incx, const float* y, int incy, float* result)
128{
129 return cublasSdot(handle, n, x, incx, y, incy, result);
130}
131
132inline cublasStatus_t
133cublasDot([[maybe_unused]] cublasHandle_t handle,
134 [[maybe_unused]] int n,
135 [[maybe_unused]] const int* x,
136 [[maybe_unused]] int incx,
137 [[maybe_unused]] const int* y,
138 [[maybe_unused]] int incy,
139 [[maybe_unused]] int* result)
140{
141 OPM_THROW(std::runtime_error, "inner product for integer vectors is not implemented yet.");
142}
143
144inline cublasStatus_t
145cublasNrm2(cublasHandle_t handle, int n, const double* x, int incx, double* result)
146{
147 return cublasDnrm2(handle, n, x, incx, result);
148}
149
150
151inline cublasStatus_t
152cublasNrm2(cublasHandle_t handle, int n, const float* x, int incx, float* result)
153{
154 return cublasSnrm2(handle, n, x, incx, result);
155}
156
157inline cublasStatus_t
158cublasNrm2([[maybe_unused]] cublasHandle_t handle,
159 [[maybe_unused]] int n,
160 [[maybe_unused]] const int* x,
161 [[maybe_unused]] int incx,
162 [[maybe_unused]] int* result)
163{
164 OPM_THROW(std::runtime_error, "norm2 for integer vectors is not implemented yet.");
165}
166
167} // namespace Opm::cuistl::detail
168#endif
Definition: cublas_safe_call.hpp:32
cublasStatus_t cublasDot(cublasHandle_t handle, int n, const double *x, int incx, const double *y, int incy, double *result)
Definition: cublas_wrapper.hpp:121
cublasStatus_t cublasScal(cublasHandle_t handle, int n, const double *alpha, double *x, int incx)
Definition: cublas_wrapper.hpp:36
cublasStatus_t cublasNrm2(cublasHandle_t handle, int n, const double *x, int incx, double *result)
Definition: cublas_wrapper.hpp:145
cublasStatus_t cublasAxpy(cublasHandle_t handle, int n, const double *alpha, const double *x, int incx, double *y, int incy)
Definition: cublas_wrapper.hpp:73