opm-simulators
cusparse_wrapper.hpp
1 /*
2  Copyright 2022-2023 SINTEF AS
3 
4  This file is part of the Open Porous Media project (OPM).
5 
6  OPM is free software: you can redistribute it and/or modify
7  it under the terms of the GNU General Public License as published by
8  the Free Software Foundation, either version 3 of the License, or
9  (at your option) any later version.
10 
11  OPM is distributed in the hope that it will be useful,
12  but WITHOUT ANY WARRANTY; without even the implied warranty of
13  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
14  GNU General Public License for more details.
15 
16  You should have received a copy of the GNU General Public License
17  along with OPM. If not, see <http://www.gnu.org/licenses/>.
18 */
19 
26 #include <cusparse.h>
27 #include <type_traits>
28 #ifndef OPM_CUSPARSE_WRAPPER_HPP
29 #define OPM_CUSPARSE_WRAPPER_HPP
30 namespace Opm::gpuistl::detail
31 {
32 
33 inline cusparseStatus_t
34 cusparseBsrilu02_analysis(cusparseHandle_t handle,
35  cusparseDirection_t dirA,
36  int mb,
37  int nnzb,
38  const cusparseMatDescr_t descrA,
39  double* bsrSortedVal,
40  const int* bsrSortedRowPtr,
41  const int* bsrSortedColInd,
42  int blockDim,
43  bsrilu02Info_t info,
44  cusparseSolvePolicy_t policy,
45  void* pBuffer)
46 {
47  return cusparseDbsrilu02_analysis(handle,
48  dirA,
49  mb,
50  nnzb,
51  descrA,
52  bsrSortedVal,
53  bsrSortedRowPtr,
54  bsrSortedColInd,
55  blockDim,
56  info,
57  policy,
58  pBuffer);
59 }
60 
61 inline cusparseStatus_t
62 cusparseBsrsv2_analysis(cusparseHandle_t handle,
63  cusparseDirection_t dirA,
64  cusparseOperation_t transA,
65  int mb,
66  int nnzb,
67  const cusparseMatDescr_t descrA,
68  const double* bsrSortedValA,
69  const int* bsrSortedRowPtrA,
70  const int* bsrSortedColIndA,
71  int blockDim,
72  bsrsv2Info_t info,
73  cusparseSolvePolicy_t policy,
74  void* pBuffer)
75 {
76  return cusparseDbsrsv2_analysis(handle,
77  dirA,
78  transA,
79  mb,
80  nnzb,
81  descrA,
82  bsrSortedValA,
83  bsrSortedRowPtrA,
84  bsrSortedColIndA,
85  blockDim,
86  info,
87  policy,
88  pBuffer);
89 }
90 
91 inline cusparseStatus_t
92 cusparseBsrsv2_analysis(cusparseHandle_t handle,
93  cusparseDirection_t dirA,
94  cusparseOperation_t transA,
95  int mb,
96  int nnzb,
97  const cusparseMatDescr_t descrA,
98  const float* bsrSortedValA,
99  const int* bsrSortedRowPtrA,
100  const int* bsrSortedColIndA,
101  int blockDim,
102  bsrsv2Info_t info,
103  cusparseSolvePolicy_t policy,
104  void* pBuffer)
105 {
106  return cusparseSbsrsv2_analysis(handle,
107  dirA,
108  transA,
109  mb,
110  nnzb,
111  descrA,
112  bsrSortedValA,
113  bsrSortedRowPtrA,
114  bsrSortedColIndA,
115  blockDim,
116  info,
117  policy,
118  pBuffer);
119 }
120 
121 inline cusparseStatus_t
122 cusparseBsrilu02_analysis(cusparseHandle_t handle,
123  cusparseDirection_t dirA,
124  int mb,
125  int nnzb,
126  const cusparseMatDescr_t descrA,
127  float* bsrSortedVal,
128  const int* bsrSortedRowPtr,
129  const int* bsrSortedColInd,
130  int blockDim,
131  bsrilu02Info_t info,
132  cusparseSolvePolicy_t policy,
133  void* pBuffer)
134 {
135  return cusparseSbsrilu02_analysis(handle,
136  dirA,
137  mb,
138  nnzb,
139  descrA,
140  bsrSortedVal,
141  bsrSortedRowPtr,
142  bsrSortedColInd,
143  blockDim,
144  info,
145  policy,
146  pBuffer);
147 }
148 
149 inline cusparseStatus_t
150 cusparseBsrsv2_solve(cusparseHandle_t handle,
151  cusparseDirection_t dirA,
152  cusparseOperation_t transA,
153  int mb,
154  int nnzb,
155  const double* alpha,
156  const cusparseMatDescr_t descrA,
157  const double* bsrSortedValA,
158  const int* bsrSortedRowPtrA,
159  const int* bsrSortedColIndA,
160  int blockDim,
161  bsrsv2Info_t info,
162  const double* f,
163  double* x,
164  cusparseSolvePolicy_t policy,
165  void* pBuffer)
166 {
167  return cusparseDbsrsv2_solve(handle,
168  dirA,
169  transA,
170  mb,
171  nnzb,
172  alpha,
173  descrA,
174  bsrSortedValA,
175  bsrSortedRowPtrA,
176  bsrSortedColIndA,
177  blockDim,
178  info,
179  f,
180  x,
181  policy,
182  pBuffer);
183 }
184 
185 
186 inline cusparseStatus_t
187 cusparseBsrsv2_solve(cusparseHandle_t handle,
188  cusparseDirection_t dirA,
189  cusparseOperation_t transA,
190  int mb,
191  int nnzb,
192  const float* alpha,
193  const cusparseMatDescr_t descrA,
194  const float* bsrSortedValA,
195  const int* bsrSortedRowPtrA,
196  const int* bsrSortedColIndA,
197  int blockDim,
198  bsrsv2Info_t info,
199  const float* f,
200  float* x,
201  cusparseSolvePolicy_t policy,
202  void* pBuffer)
203 {
204  return cusparseSbsrsv2_solve(handle,
205  dirA,
206  transA,
207  mb,
208  nnzb,
209  alpha,
210  descrA,
211  bsrSortedValA,
212  bsrSortedRowPtrA,
213  bsrSortedColIndA,
214  blockDim,
215  info,
216  f,
217  x,
218  policy,
219  pBuffer);
220 }
221 
222 
223 inline cusparseStatus_t
224 cusparseBsrilu02_bufferSize(cusparseHandle_t handle,
225  cusparseDirection_t dirA,
226  int mb,
227  int nnzb,
228  const cusparseMatDescr_t descrA,
229  double* bsrSortedVal,
230  const int* bsrSortedRowPtr,
231  const int* bsrSortedColInd,
232  int blockDim,
233  bsrilu02Info_t info,
234  int* pBufferSizeInBytes)
235 {
236  return cusparseDbsrilu02_bufferSize(handle,
237  dirA,
238  mb,
239  nnzb,
240  descrA,
241  bsrSortedVal,
242  bsrSortedRowPtr,
243  bsrSortedColInd,
244  blockDim,
245  info,
246  pBufferSizeInBytes);
247 }
248 
249 
250 inline cusparseStatus_t
251 cusparseBsrilu02_bufferSize(cusparseHandle_t handle,
252  cusparseDirection_t dirA,
253  int mb,
254  int nnzb,
255  const cusparseMatDescr_t descrA,
256  float* bsrSortedVal,
257  const int* bsrSortedRowPtr,
258  const int* bsrSortedColInd,
259  int blockDim,
260  bsrilu02Info_t info,
261  int* pBufferSizeInBytes)
262 {
263  return cusparseSbsrilu02_bufferSize(handle,
264  dirA,
265  mb,
266  nnzb,
267  descrA,
268  bsrSortedVal,
269  bsrSortedRowPtr,
270  bsrSortedColInd,
271  blockDim,
272  info,
273  pBufferSizeInBytes);
274 }
275 
276 inline cusparseStatus_t
277 cusparseBsrsv2_bufferSize(cusparseHandle_t handle,
278  cusparseDirection_t dirA,
279  cusparseOperation_t transA,
280  int mb,
281  int nnzb,
282  const cusparseMatDescr_t descrA,
283  double* bsrSortedValA,
284  const int* bsrSortedRowPtrA,
285  const int* bsrSortedColIndA,
286  int blockDim,
287  bsrsv2Info_t info,
288  int* pBufferSizeInBytes)
289 {
290  return cusparseDbsrsv2_bufferSize(handle,
291  dirA,
292  transA,
293  mb,
294  nnzb,
295  descrA,
296  bsrSortedValA,
297  bsrSortedRowPtrA,
298  bsrSortedColIndA,
299  blockDim,
300  info,
301  pBufferSizeInBytes);
302 }
303 inline cusparseStatus_t
304 cusparseBsrsv2_bufferSize(cusparseHandle_t handle,
305  cusparseDirection_t dirA,
306  cusparseOperation_t transA,
307  int mb,
308  int nnzb,
309  const cusparseMatDescr_t descrA,
310  float* bsrSortedValA,
311  const int* bsrSortedRowPtrA,
312  const int* bsrSortedColIndA,
313  int blockDim,
314  bsrsv2Info_t info,
315  int* pBufferSizeInBytes)
316 {
317  return cusparseSbsrsv2_bufferSize(handle,
318  dirA,
319  transA,
320  mb,
321  nnzb,
322  descrA,
323  bsrSortedValA,
324  bsrSortedRowPtrA,
325  bsrSortedColIndA,
326  blockDim,
327  info,
328  pBufferSizeInBytes);
329 }
330 
331 inline cusparseStatus_t
332 cusparseBsrilu02(cusparseHandle_t handle,
333  cusparseDirection_t dirA,
334  int mb,
335  int nnzb,
336  const cusparseMatDescr_t descrA,
337  double* bsrSortedVal,
338  const int* bsrSortedRowPtr,
339  const int* bsrSortedColInd,
340  int blockDim,
341  bsrilu02Info_t info,
342  cusparseSolvePolicy_t policy,
343  void* pBuffer)
344 {
345  return cusparseDbsrilu02(handle,
346  dirA,
347  mb,
348  nnzb,
349  descrA,
350  bsrSortedVal,
351  bsrSortedRowPtr,
352  bsrSortedColInd,
353  blockDim,
354  info,
355  policy,
356  pBuffer);
357 }
358 inline cusparseStatus_t
359 cusparseBsrilu02(cusparseHandle_t handle,
360  cusparseDirection_t dirA,
361  int mb,
362  int nnzb,
363  const cusparseMatDescr_t descrA,
364  float* bsrSortedVal,
365  const int* bsrSortedRowPtr,
366  const int* bsrSortedColInd,
367  int blockDim,
368  bsrilu02Info_t info,
369  cusparseSolvePolicy_t policy,
370  void* pBuffer)
371 {
372  return cusparseSbsrilu02(handle,
373  dirA,
374  mb,
375  nnzb,
376  descrA,
377  bsrSortedVal,
378  bsrSortedRowPtr,
379  bsrSortedColInd,
380  blockDim,
381  info,
382  policy,
383  pBuffer);
384 }
385 
386 inline cusparseStatus_t
387 cusparseBsrmv(cusparseHandle_t handle,
388  cusparseDirection_t dirA,
389  cusparseOperation_t transA,
390  int mb,
391  int nb,
392  int nnzb,
393  const double* alpha,
394  const cusparseMatDescr_t descrA,
395  const double* bsrSortedValA,
396  const int* bsrSortedRowPtrA,
397  const int* bsrSortedColIndA,
398  int blockDim,
399  const double* x,
400  const double* beta,
401  double* y)
402 {
403  return cusparseDbsrmv(handle,
404  dirA,
405  transA,
406  mb,
407  nb,
408  nnzb,
409  alpha,
410  descrA,
411  bsrSortedValA,
412  bsrSortedRowPtrA,
413  bsrSortedColIndA,
414  blockDim,
415  x,
416  beta,
417  y);
418 }
419 
420 inline cusparseStatus_t
421 cusparseBsrmv(cusparseHandle_t handle,
422  cusparseDirection_t dirA,
423  cusparseOperation_t transA,
424  int mb,
425  int nb,
426  int nnzb,
427  const float* alpha,
428  const cusparseMatDescr_t descrA,
429  const float* bsrSortedValA,
430  const int* bsrSortedRowPtrA,
431  const int* bsrSortedColIndA,
432  int blockDim,
433  const float* x,
434  const float* beta,
435  float* y)
436 {
437  return cusparseSbsrmv(handle,
438  dirA,
439  transA,
440  mb,
441  nb,
442  nnzb,
443  alpha,
444  descrA,
445  bsrSortedValA,
446  bsrSortedRowPtrA,
447  bsrSortedColIndA,
448  blockDim,
449  x,
450  beta,
451  y);
452 }
453 } // namespace Opm::gpuistl::detail
454 #endif
Contains wrappers to make the CuBLAS library behave as a modern C++ library with function overlading...
Definition: autotuner.hpp:29