Skip to content

Commit 3d07901

Browse files
parameter-aware function wrappers
1 parent c846501 commit 3d07901

File tree

1 file changed

+18
-15
lines changed

1 file changed

+18
-15
lines changed

src/function_wrappers.jl

+18-15
Original file line numberDiff line numberDiff line change
@@ -1,42 +1,45 @@
1-
mutable struct TimeGradientWrapper{fType,uType} <: Function
1+
mutable struct TimeGradientWrapper{fType,uType,P} <: Function
22
f::fType
33
uprev::uType
4+
p::P
45
end
5-
(p::TimeGradientWrapper)(t) = (du2 = similar(p.uprev); p.f(t,p.uprev,du2); du2)
6-
(p::TimeGradientWrapper)(du2,t) = p.f(t,p.uprev,du2)
6+
(ff::TimeGradientWrapper)(t) = (du2 = similar(ff.uprev); ff.f(du2,ff.uprev,ff.p,t); du2)
7+
(ff::TimeGradientWrapper)(du2,t) = ff.f(du2,ff.uprev,ff.p,t)
78

8-
mutable struct UJacobianWrapper{fType,tType} <: Function
9+
mutable struct UJacobianWrapper{fType,tType,P} <: Function
910
f::fType
1011
t::tType
12+
p::P
1113
end
1214

13-
(p::UJacobianWrapper)(du1,uprev) = p.f(p.t,uprev,du1)
14-
(p::UJacobianWrapper)(uprev) = (du1 = similar(uprev); p.f(p.t,uprev,du1); du1)
15+
(ff::UJacobianWrapper)(du1,uprev) = ff.f(du1,uprev,ff.p,ff.t)
16+
(ff::UJacobianWrapper)(uprev) = (du1 = similar(uprev); ff.f(du1,uprev,ff.p,ff.t); du1)
1517

16-
mutable struct TimeDerivativeWrapper{F,uType} <: Function
18+
mutable struct TimeDerivativeWrapper{F,uType,P} <: Function
1719
f::F
1820
u::uType
21+
p::P
1922
end
20-
(p::TimeDerivativeWrapper)(t) = p.f(t,p.u)
23+
(ff::TimeDerivativeWrapper)(t) = ff.f(ff.u,ff.p,t)
2124

22-
mutable struct UDerivativeWrapper{F,tType} <: Function
25+
mutable struct UDerivativeWrapper{F,tType,P} <: Function
2326
f::F
2427
t::tType
28+
p::P
2529
end
26-
(p::UDerivativeWrapper)(u) = p.f(p.t,u)
30+
(ff::UDerivativeWrapper)(u) = ff.f(u,ff.p,ff.t)
2731

2832
mutable struct ParamJacobianWrapper{fType,tType,uType} <: Function
2933
f::fType
3034
t::tType
3135
u::uType
3236
end
3337

34-
function (pf::ParamJacobianWrapper)(du1,p)
35-
pf.f(pf.t,pf.u,p,du1)
38+
function (ff::ParamJacobianWrapper)(du1,p)
39+
ff.f(du1,ff.u,p,ff.t)
3640
end
3741

38-
function (pf::ParamJacobianWrapper)(p)
42+
function (ff::ParamJacobianWrapper)(p)
3943
du1 = similar(uprev)
40-
set_param_values!(pf.f,p)
41-
pf.f(pf.t,pf.u,du1)
44+
ff.f(du1,ff.u,p,ff.t)
4245
end

0 commit comments

Comments
 (0)