1
- use crate :: { error:: * , sys} ;
2
- use cust:: stream:: Stream ;
3
1
use std:: ffi:: CString ;
4
2
use std:: mem:: { self , MaybeUninit } ;
5
3
use std:: os:: raw:: c_char;
6
4
use std:: ptr;
7
5
8
- type Result < T , E = Error > = std:: result:: Result < T , E > ;
6
+ use cust:: stream:: Stream ;
7
+ use cust_raw:: cublas_sys;
8
+ use cust_raw:: driver_sys;
9
+
10
+ use super :: error:: DropResult ;
11
+ use super :: error:: ToResult as _;
12
+
13
+ type Result < T , E = super :: error:: Error > = std:: result:: Result < T , E > ;
9
14
10
15
bitflags:: bitflags! {
11
16
/// Configures precision levels for the math in cuBLAS.
12
- #[ derive( Default ) ]
17
+ #[ derive( Debug , Default , Clone , Copy , PartialEq , Eq , Hash ) ]
13
18
pub struct MathMode : u32 {
14
19
/// Highest performance mode which uses compute and intermediate storage precisions
15
20
/// with at least the same number of mantissa and exponent bits as requested. Will
@@ -68,7 +73,7 @@ bitflags::bitflags! {
68
73
/// - [Matrix Multiplication <span style="float:right;">`gemm`</span>](CublasContext::gemm)
69
74
#[ derive( Debug ) ]
70
75
pub struct CublasContext {
71
- pub ( crate ) raw : sys :: v2 :: cublasHandle_t ,
76
+ pub ( crate ) raw : cublas_sys :: cublasHandle_t ,
72
77
}
73
78
74
79
impl CublasContext {
@@ -87,10 +92,10 @@ impl CublasContext {
87
92
pub fn new ( ) -> Result < Self > {
88
93
let mut raw = MaybeUninit :: uninit ( ) ;
89
94
unsafe {
90
- sys :: v2 :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
91
- sys :: v2 :: cublasSetPointerMode_v2 (
95
+ cublas_sys :: cublasCreate_v2 ( raw. as_mut_ptr ( ) ) . to_result ( ) ?;
96
+ cublas_sys :: cublasSetPointerMode_v2 (
92
97
raw. assume_init ( ) ,
93
- sys :: v2 :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
98
+ cublas_sys :: cublasPointerMode_t:: CUBLAS_POINTER_MODE_DEVICE ,
94
99
)
95
100
. to_result ( ) ?;
96
101
Ok ( Self {
@@ -107,7 +112,7 @@ impl CublasContext {
107
112
108
113
unsafe {
109
114
let inner = mem:: replace ( & mut ctx. raw , ptr:: null_mut ( ) ) ;
110
- match sys :: v2 :: cublasDestroy_v2 ( inner) . to_result ( ) {
115
+ match cublas_sys :: cublasDestroy_v2 ( inner) . to_result ( ) {
111
116
Ok ( ( ) ) => {
112
117
mem:: forget ( ctx) ;
113
118
Ok ( ( ) )
@@ -122,7 +127,7 @@ impl CublasContext {
122
127
let mut raw = MaybeUninit :: < u32 > :: uninit ( ) ;
123
128
unsafe {
124
129
// getVersion can't fail
125
- sys :: v2 :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
130
+ cublas_sys :: cublasGetVersion_v2 ( self . raw , raw. as_mut_ptr ( ) . cast ( ) )
126
131
. to_result ( )
127
132
. unwrap ( ) ;
128
133
@@ -140,17 +145,17 @@ impl CublasContext {
140
145
) -> Result < T > {
141
146
unsafe {
142
147
// cudaStream_t is the same as CUstream
143
- sys :: v2 :: cublasSetStream_v2 (
148
+ cublas_sys :: cublasSetStream_v2 (
144
149
self . raw ,
145
- mem:: transmute :: < * mut cust :: sys :: CUstream_st , * mut cublas_sys:: v2 :: CUstream_st > (
150
+ mem:: transmute :: < * mut driver_sys :: CUstream_st , * mut cublas_sys:: CUstream_st > (
146
151
stream. as_inner ( ) ,
147
152
) ,
148
153
)
149
154
. to_result ( ) ?;
150
155
let res = func ( self ) ?;
151
156
// reset the stream back to NULL just in case someone calls with_stream, then drops the stream, and tries to
152
157
// execute a raw sys function with the context's handle.
153
- sys :: v2 :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
158
+ cublas_sys :: cublasSetStream_v2 ( self . raw , ptr:: null_mut ( ) ) . to_result ( ) ?;
154
159
Ok ( res)
155
160
}
156
161
}
@@ -180,12 +185,12 @@ impl CublasContext {
180
185
/// ```
181
186
pub fn set_atomics_mode ( & self , allowed : bool ) -> Result < ( ) > {
182
187
unsafe {
183
- Ok ( sys :: v2 :: cublasSetAtomicsMode (
188
+ Ok ( cublas_sys :: cublasSetAtomicsMode (
184
189
self . raw ,
185
190
if allowed {
186
- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
191
+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED
187
192
} else {
188
- sys :: v2 :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
193
+ cublas_sys :: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED
189
194
} ,
190
195
)
191
196
. to_result ( ) ?)
@@ -210,10 +215,11 @@ impl CublasContext {
210
215
pub fn get_atomics_mode ( & self ) -> Result < bool > {
211
216
let mut mode = MaybeUninit :: uninit ( ) ;
212
217
unsafe {
213
- sys :: v2 :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
218
+ cublas_sys :: cublasGetAtomicsMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
214
219
Ok ( match mode. assume_init ( ) {
215
- sys:: v2:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
216
- sys:: v2:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
220
+ cublas_sys:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_ALLOWED => true ,
221
+ cublas_sys:: cublasAtomicsMode_t:: CUBLAS_ATOMICS_NOT_ALLOWED => false ,
222
+ _ => false ,
217
223
} )
218
224
}
219
225
}
@@ -233,9 +239,9 @@ impl CublasContext {
233
239
/// ```
234
240
pub fn set_math_mode ( & self , math_mode : MathMode ) -> Result < ( ) > {
235
241
unsafe {
236
- Ok ( sys :: v2 :: cublasSetMathMode (
242
+ Ok ( cublas_sys :: cublasSetMathMode (
237
243
self . raw ,
238
- mem:: transmute :: < u32 , cublas_sys:: v2 :: cublasMath_t > ( math_mode. bits ( ) ) ,
244
+ mem:: transmute :: < u32 , cublas_sys:: cublasMath_t > ( math_mode. bits ( ) ) ,
239
245
)
240
246
. to_result ( ) ?)
241
247
}
@@ -258,7 +264,7 @@ impl CublasContext {
258
264
pub fn get_math_mode ( & self ) -> Result < MathMode > {
259
265
let mut mode = MaybeUninit :: uninit ( ) ;
260
266
unsafe {
261
- sys :: v2 :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
267
+ cublas_sys :: cublasGetMathMode ( self . raw , mode. as_mut_ptr ( ) ) . to_result ( ) ?;
262
268
Ok ( MathMode :: from_bits ( mode. assume_init ( ) as u32 )
263
269
. expect ( "Invalid MathMode from cuBLAS" ) )
264
270
}
@@ -298,7 +304,7 @@ impl CublasContext {
298
304
let path = log_file_name. map ( |p| CString :: new ( p) . expect ( "nul in log_file_name" ) ) ;
299
305
let path_ptr = path. map_or ( ptr:: null ( ) , |s| s. as_ptr ( ) ) ;
300
306
301
- sys :: v2 :: cublasLoggerConfigure (
307
+ cublas_sys :: cublasLoggerConfigure (
302
308
enable as i32 ,
303
309
log_to_stdout as i32 ,
304
310
log_to_stderr as i32 ,
@@ -315,7 +321,7 @@ impl CublasContext {
315
321
///
316
322
/// The callback must not panic and unwind.
317
323
pub unsafe fn set_logger_callback ( callback : Option < unsafe extern "C" fn ( * const c_char ) > ) {
318
- sys :: v2 :: cublasSetLoggerCallback ( callback)
324
+ cublas_sys :: cublasSetLoggerCallback ( callback)
319
325
. to_result ( )
320
326
. unwrap ( ) ;
321
327
}
@@ -324,7 +330,7 @@ impl CublasContext {
324
330
pub fn get_logger_callback ( ) -> Option < unsafe extern "C" fn ( * const c_char ) > {
325
331
let mut cb = MaybeUninit :: uninit ( ) ;
326
332
unsafe {
327
- sys :: v2 :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
333
+ cublas_sys :: cublasGetLoggerCallback ( cb. as_mut_ptr ( ) )
328
334
. to_result ( )
329
335
. unwrap ( ) ;
330
336
cb. assume_init ( )
@@ -335,7 +341,7 @@ impl CublasContext {
335
341
impl Drop for CublasContext {
336
342
fn drop ( & mut self ) {
337
343
unsafe {
338
- sys :: v2 :: cublasDestroy_v2 ( self . raw ) ;
344
+ cublas_sys :: cublasDestroy_v2 ( self . raw ) ;
339
345
}
340
346
}
341
347
}
0 commit comments