forked from arrayfire/arrayfire-python
-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathutil.py
117 lines (94 loc) · 3.21 KB
/
util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
#######################################################
# Copyright (c) 2015, ArrayFire
# All rights reserved.
#
# This file is distributed under 3-clause BSD license.
# The complete license agreement can be obtained at:
# http://arrayfire.com/licenses/BSD-3-Clause
########################################################
from .library import *
import numbers
def dim4(d0=1, d1=1, d2=1, d3=1):
    """Pack up to four dimension sizes into a ``ct.c_longlong * 4`` array.

    Any of ``d0``..``d3`` given as ``None`` is replaced by 1, so a caller can
    pass ``None`` to mean "leave this dimension at its default size".

    Returns:
        A new ctypes array of four ``c_longlong`` values.
    """
    sizes = [1 if extent is None else extent for extent in (d0, d1, d2, d3)]
    return (ct.c_longlong * 4)(*sizes)
def _is_number(a):
    """Return True if *a* is a numeric scalar.

    Anything registered with the ``numbers.Number`` ABC qualifies; that
    includes ``bool``, ``int``, ``float`` and ``complex``.
    """
    number_abc = numbers.Number
    return isinstance(a, number_abc)
def number_dtype(a):
    """Return the ArrayFire ``Dtype`` that matches the scalar *a*.

    Python ``bool``/``int``/``float``/``complex`` map to b8/s64/f64/c64.
    Any other object is looked up by its ``dtype.char`` in ``to_dtype``
    (presumably a numpy scalar). Objects without a ``dtype`` raise
    AttributeError, and unknown type characters raise KeyError.
    """
    # Order matters: bool is a subclass of int, so it must be tested first.
    python_scalars = ((bool, Dtype.b8),
                      (int, Dtype.s64),
                      (float, Dtype.f64),
                      (complex, Dtype.c64))
    for py_type, af_type in python_scalars:
        if isinstance(a, py_type):
            return af_type
    return to_dtype[a.dtype.char]
def implicit_dtype(number, a_dtype):
    """Choose the dtype for scalar *number* when combined with an array.

    Args:
        number: The scalar operand. Its natural dtype comes from
            ``number_dtype``.
        a_dtype: The array's dtype as a raw ``Dtype.<x>.value`` integer,
            not a ``Dtype`` member.

    Returns:
        A ``Dtype``. If the array is single precision (f32 or c32), an f64
        scalar is demoted to f32 and a c64 scalar to c32. This is presumably
        done so the scalar does not promote the whole array to double
        precision. Otherwise the scalar's own dtype is returned.
    """
    scalar_type = number_dtype(number)
    single_precision = (Dtype.f32.value, Dtype.c32.value)
    if a_dtype in single_precision:
        if scalar_type.value == Dtype.f64.value:
            return Dtype.f32
        if scalar_type.value == Dtype.c64.value:
            return Dtype.c32
    return scalar_type
def dim4_to_tuple(dims, default=1):
    """Pad a tuple of up to four dimensions out to exactly four entries.

    Args:
        dims: A tuple of at most four dimension values.
        default: The fill value for missing trailing entries. It must be a
            number, or ``None`` to skip that check.

    Returns:
        A 4-tuple. It starts with the entries of *dims*, followed by
        *default* for each missing trailing position.

    Raises:
        AssertionError: If *dims* is not a tuple, or *default* is neither
            ``None`` nor a number.
        IndexError: If *dims* has more than four entries.
    """
    assert(isinstance(dims, tuple))
    if default is not None:
        assert(_is_number(default))

    slots = [default] * 4
    for position, extent in enumerate(dims):
        # Indexed assignment (not concatenation) so that more than four
        # entries still fails with IndexError.
        slots[position] = extent
    return tuple(slots)
def to_str(c_str):
    """Decode a ctypes C string (e.g. ``ct.c_char_p``) into a Python ``str``.

    The bytes are decoded as UTF-8.

    Args:
        c_str: A ctypes object whose ``.value`` is ``bytes`` or ``None``.

    Returns:
        The decoded text, or ``""`` when ``c_str.value`` is ``None`` (a NULL
        ``char*``).

    Raises:
        UnicodeDecodeError: If the bytes are not valid UTF-8.
    """
    raw = c_str.value
    if raw is None:
        # A NULL pointer (e.g. an error message the C library never filled
        # in) would otherwise fail with AttributeError on .decode. In
        # safe_call that would hide the real ArrayFire error.
        return ""
    return str(raw.decode('utf-8'))
def safe_call(af_error):
    """Check an ArrayFire C return code and raise if it signals failure.

    Args:
        af_error: The integer status returned by an ArrayFire C function.

    Raises:
        RuntimeError: If *af_error* differs from ``ERR.NONE.value``. The
            exception args are ``(message, af_error)``. The message is
            fetched with ``af_get_last_error``.
    """
    if af_error == ERR.NONE.value:
        return

    message = ct.c_char_p(0)
    message_len = ct.c_longlong(0)
    # NOTE(review): the buffer returned by af_get_last_error may need to be
    # released with af_free_host per the ArrayFire C API; it is not freed
    # here. Confirm against the targeted ArrayFire version.
    backend.get().af_get_last_error(ct.pointer(message), ct.pointer(message_len))
    raise RuntimeError(to_str(message), af_error)
def get_version():
    """Query the loaded ArrayFire library for its version.

    Returns:
        A ``(major, minor, patch)`` tuple of ``ct.c_int`` objects, not plain
        ints. Read ``.value`` on each to get Python integers.

    Raises:
        RuntimeError: Via ``safe_call``, if ``af_get_version`` reports an
            error.
    """
    parts = [ct.c_int(0) for _ in range(3)]
    safe_call(backend.get().af_get_version(*(ct.pointer(p) for p in parts)))
    return tuple(parts)
# Single-character typecodes accepted for host data (the same characters used
# as keys in to_dtype below, apart from the numpy-only aliases).
typecodes = ['f', 'F', 'd', 'D', 'b', 'B', 'i', 'I', 'l', 'L']

# Typecode character -> ArrayFire Dtype. number_dtype() indexes this with a
# scalar's ``dtype.char``.
# '?', 'q' and 'Q' are numpy's type characters for bool_, longlong and
# ulonglong. 'q'/'Q' are also how int64/uint64 report on Windows. They map to
# same-width ArrayFire types; without them those scalars raise KeyError.
# NOTE(review): 'l'/'L' are mapped to 64-bit types, but C long is 32 bits on
# Windows (LLP64). Confirm the intended behavior on that platform.
to_dtype = {'f' : Dtype.f32,
            'd' : Dtype.f64,
            'b' : Dtype.b8,
            'B' : Dtype.u8,
            'i' : Dtype.s32,
            'I' : Dtype.u32,
            'l' : Dtype.s64,
            'L' : Dtype.u64,
            'F' : Dtype.c32,
            'D' : Dtype.c64,
            '?' : Dtype.b8,
            'q' : Dtype.s64,
            'Q' : Dtype.u64}

# Dtype integer value -> canonical typecode character (one per dtype; the
# numpy-only aliases above are intentionally not reflected here).
to_typecode = {Dtype.f32.value : 'f',
               Dtype.f64.value : 'd',
               Dtype.b8.value : 'b',
               Dtype.u8.value : 'B',
               Dtype.s32.value : 'i',
               Dtype.u32.value : 'I',
               Dtype.s64.value : 'l',
               Dtype.u64.value : 'L',
               Dtype.c32.value : 'F',
               Dtype.c64.value : 'D'}

# Dtype integer value -> ctypes element type. Complex types are represented
# as a pair of floats/doubles (real, imaginary).
to_c_type = {Dtype.f32.value : ct.c_float,
             Dtype.f64.value : ct.c_double,
             Dtype.b8.value : ct.c_char,
             Dtype.u8.value : ct.c_ubyte,
             Dtype.s32.value : ct.c_int,
             Dtype.u32.value : ct.c_uint,
             Dtype.s64.value : ct.c_longlong,
             Dtype.u64.value : ct.c_ulonglong,
             Dtype.c32.value : ct.c_float * 2,
             Dtype.c64.value : ct.c_double * 2}