Skip to content

Commit

Permalink
Correctly the return type of TimedArray functions in generated code
Browse files Browse the repository at this point in the history
  • Loading branch information
mstimberg committed Oct 30, 2023
1 parent b476451 commit 285dffb
Showing 1 changed file with 13 additions and 6 deletions.
19 changes: 13 additions & 6 deletions brian2/input/timedarray.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

import numpy as np

from brian2.codegen.generators import c_data_type
from brian2.core.clocks import defaultclock
from brian2.core.functions import Function
from brian2.core.names import Nameable
Expand Down Expand Up @@ -45,7 +46,7 @@ def cpp_impl(owner):
K = _find_K(owner.clock.dt_, dt)
code = (
"""
static inline double %NAME%(const double t)
static inline %TYPE% %NAME%(const double t)
{
const double epsilon = %DT% / %K%;
int i = (int)((t/epsilon + 0.5)/%K%);
Expand All @@ -61,6 +62,7 @@ def cpp_impl(owner):
.replace("%DT%", f"{dt:.18f}")
.replace("%K%", str(K))
.replace("%NUM_VALUES%", str(len(values)))
.replace("%TYPE%", c_data_type(values.dtype))
)

return code
Expand All @@ -72,7 +74,7 @@ def _generate_cpp_code_2d(values, dt, name):
def cpp_impl(owner):
K = _find_K(owner.clock.dt_, dt)
support_code = """
static inline double %NAME%(const double t, const int i)
static inline %TYPE% %NAME%(const double t, const int i)
{
const double epsilon = %DT% / %K%;
if (i < 0 || i >= %COLS%)
Expand All @@ -93,6 +95,7 @@ def cpp_impl(owner):
"%K%": str(K),
"%COLS%": str(values.shape[1]),
"%ROWS%": str(values.shape[0]),
"%TYPE%": c_data_type(values.dtype),
},
)
return code
Expand All @@ -105,7 +108,7 @@ def cython_impl(owner):
K = _find_K(owner.clock.dt_, dt)
code = (
"""
cdef double %NAME%(const double t):
cdef %TYPE% %NAME%(const double t):
global _namespace%NAME%_values
cdef double epsilon = %DT% / %K%
cdef int i = (int)((t/epsilon + 0.5)/%K%)
Expand All @@ -120,6 +123,7 @@ def cython_impl(owner):
.replace("%DT%", f"{dt:.18f}")
.replace("%K%", str(K))
.replace("%NUM_VALUES%", str(len(values)))
.replace("%TYPE%", c_data_type(values.dtype))
)

return code
Expand All @@ -131,7 +135,7 @@ def _generate_cython_code_2d(values, dt, name):
def cython_impl(owner):
K = _find_K(owner.clock.dt_, dt)
code = """
cdef double %NAME%(const double t, const int i):
cdef %TYPE% %NAME%(const double t, const int i):
global _namespace%NAME%_values
cdef double epsilon = %DT% / %K%
if i < 0 or i >= %COLS%:
Expand All @@ -151,6 +155,7 @@ def cython_impl(owner):
"%K%": str(K),
"%COLS%": str(values.shape[1]),
"%ROWS%": str(values.shape[0]),
"%TYPE%": c_data_type(values.dtype),
},
)
return code
Expand Down Expand Up @@ -236,7 +241,7 @@ def __init__(self, values, dt, name=None):
dimensions = get_dimensions(values)
self.dim = dimensions
values = np.asarray(values) # infer dtype
if values.dtype == np.object:
if values.dtype == object:
raise TypeError("TimedArray does not support arrays with dtype 'object'")
elif (
values.dtype == np.float64 and prefs.core.default_float_dtype != np.float64
Expand Down Expand Up @@ -347,7 +352,9 @@ def unitless_timed_array_func(t, i):
self.implementations.add_dynamic_implementation(
"numpy", create_numpy_implementation
)
values_flat = self.values.astype(np.double, order="C", copy=False).ravel()
values_flat = self.values.astype(
self.values.dtype, order="C", copy=False
).ravel()
namespace = lambda owner: {f"{self.name}_values": values_flat}

for target, (_, func_2d) in TimedArray.implementations.items():
Expand Down

0 comments on commit 285dffb

Please sign in to comment.