Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Optimize model fit to avoid allocations in hot loops #28

Open
wants to merge 22 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion src/idw.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ end
IDW(exponent) = IDW(exponent, Euclidean())
IDW() = IDW(1)

struct IDWState{D<:AbstractGeoTable}
mutable struct IDWState{D<:AbstractGeoTable}
data::D
end

Expand All @@ -48,6 +48,13 @@ function fit(model::IDW, data)
FittedIDW(model, state)
end

function fit!(fitted::FittedIDW, newdata)
# update state data
fitted.state.data = newdata

nothing
end

#-----------------
# PREDICTION STEP
#-----------------
Expand Down
191 changes: 136 additions & 55 deletions src/krig.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,17 @@ A Kriging model (e.g. Simple Kriging, Ordinary Kriging).
abstract type KrigingModel <: GeoStatsModel end

"""
KrigingState(data, LHS, RHS, ncon)
KrigingState(data, LHS, RHS, FHS, nfun, miss)

A Kriging state stores information needed
to perform estimation at any given geometry.
"""
mutable struct KrigingState{D<:AbstractGeoTable,F,A}
mutable struct KrigingState{D<:AbstractGeoTable,L,R,F}
data::D
LHS::F
RHS::A
ncon::Int
LHS::L
RHS::R
FHS::F
nfun::Int
miss::Vector{Int}
end

Expand All @@ -44,51 +45,119 @@ struct FittedKriging{M<:KrigingModel,S<:KrigingState} <: FittedGeoStatsModel
state::S
end

status(fitted::FittedKriging) = _status(fitted.state.LHS)
status(fitted::FittedKriging) = _status(fitted.state.FHS)

_status(LHS) = issuccess(LHS)
_status(LHS::SVD) = true
_status(FHS) = issuccess(FHS)
_status(FHS::SVD) = true

#--------------
# FITTING STEP
#--------------

function fit(model::KrigingModel, data)
# initialize Kriging system
LHS, RHS, ncon, miss = initkrig(model, data)
# check compatibility of model and data
checkcompat(model, data)

# pre-allocate memory for LHS and RHS
LHS, RHS = prealloc(model, data)

# set LHS of Kriging system
nfun, miss = setlhs!(model, LHS, data)

# factorize LHS
FLHS = lhsfactorize(model, LHS)
VHS = @view LHS[begin:end, begin:end]
FHS = lhsfactorize(model, VHS)

# record Kriging state
state = KrigingState(data, FLHS, RHS, ncon, miss)
state = KrigingState(data, LHS, RHS, FHS, nfun, miss)

FittedKriging(model, state)
end

# initialize Kriging system
function initkrig(model::KrigingModel, data)
fun = model.fun
dom = domain(data)
tab = values(data)
function fit!(fitted::FittedKriging, newdata)
model = fitted.model
state = fitted.state

# retrieve matrix parameters
nobs = nelements(dom)
nvar = nvariates(fun)
# check compatibility of data size
checkdatasize(fitted, newdata)

# check compatibility of model and data
checkcompat(model, newdata)

# update state data
state.data = newdata

# set LHS of Kriging system
state.nfun, state.miss = setlhs!(model, state.LHS, newdata)

# number of modified rows
ncon = nconstraints(model)
nrow = nobs * nvar + ncon
nrow = state.nfun + ncon

# factorize LHS
VHS = @view state.LHS[1:nrow, 1:nrow]
state.FHS = lhsfactorize(model, VHS)

nothing
end

# make sure data is compatible with model
# make sure data is compatible with model
function checkcompat(model::KrigingModel, data)
fun = model.fun
nvar = nvariates(fun)
nfeat = ncol(data) - 1
if nfeat != nvar
throw(ArgumentError("$nfeat data column(s) provided to $nvar-variate Kriging model"))
end
end

# make sure data size is compatible
function checkdatasize(fitted::FittedKriging, data)
LHS = fitted.state.LHS
nlhs = size(LHS, 1)
nobs = nrow(data)
if nobs > nlhs
throw(ArgumentError("in-place fit called with $nobs data row(s) and $nlhs maximum size"))
end
end

# pre-allocate memory for LHS and RHS
function prealloc(model::KrigingModel, data)
fun = model.fun
dom = domain(data)

# retrieve matrix parameters
nobs = nelements(dom)
nvar = nvariates(fun)
ncon = nconstraints(model)
nfun = nobs * nvar
nrow = nfun + ncon

# pre-allocate memory for LHS
F = fun(dom[1], dom[1])
V = eltype(ustrip.(F))
LHS = Matrix{V}(undef, nrow, nrow)

# pre-allocate memory for RHS
RHS = similar(LHS, nrow, nvar)

LHS, RHS
end

# set LHS of Kriging system
function setlhs!(model::KrigingModel, LHS, data)
fun = model.fun
dom = domain(data)
tab = values(data)

# number of function evaluations
nobs = nelements(dom)
nvar = nvariates(fun)
nfun = nobs * nvar

# find locations with missing values
miss = missingindices(tab)

# set main block with pairwise evaluation
GeoStatsFunctions.pairwise!(LHS, fun, dom)

Expand All @@ -100,16 +169,10 @@ function initkrig(model::KrigingModel, data)
# set blocks of constraints
lhsconstraints!(model, LHS, dom)

# find locations with missing values
miss = missingindices(tab)

# knock out entries with missing values
lhsmissings!(LHS, ncon, miss)
lhsmissings!(LHS, nfun, miss)

# pre-allocate memory for RHS
RHS = similar(LHS, nrow, nvar)

LHS, RHS, ncon, miss
nfun, miss
end

# choose appropriate factorization of LHS
Expand Down Expand Up @@ -140,7 +203,7 @@ function lhsbanded!(LHS, fun, dom)

# retrieve matrix paramaters
nobs = nelements(dom)
nvar = size(S, 1)
nvar = nvariates(fun)
nfun = nobs * nvar

@inbounds for j in 1:nfun, i in 1:nfun
Expand Down Expand Up @@ -169,9 +232,7 @@ function missingindices(tab)
end

# knock out entries with missing values
function lhsmissings!(LHS, ncon, miss)
nrow = size(LHS, 1)
nfun = nrow - ncon
function lhsmissings!(LHS, nfun, miss)
@inbounds for j in miss, i in 1:nfun
LHS[i, j] = 0
end
Expand Down Expand Up @@ -236,11 +297,20 @@ end
⦿(λ, z::Missing) = 0

function predictvar(fitted::FittedKriging, weights::KrigingWeights, gₒ)
model = fitted.model
RHS = fitted.state.RHS
fun = fitted.model.fun
nfun = fitted.state.nfun
ncon = nconstraints(model)
nrow = nfun + ncon

# geostatistical function
fun = model.fun

# view valid rows of RHS
VHS = @view RHS[1:nrow, :]

# covariance formula for given function
Σ = krigvar(fun, weights, RHS, gₒ)
Σ = krigvar(fun, weights, VHS, gₒ)

# treat numerical issues
ϵ = eltype(Σ)(1e-10)
Expand Down Expand Up @@ -280,42 +350,53 @@ function wmul(weights::KrigingWeights, RHS)
Kλ, Kν
end

# solve Kriging system at target geometry
function weights(fitted::FittedKriging, gₒ)
LHS = fitted.state.LHS
data = fitted.state.data
FHS = fitted.state.FHS
RHS = fitted.state.RHS
ncon = fitted.state.ncon
miss = fitted.state.miss
dom = domain(fitted.state.data)
fun = fitted.model.fun
nfun = fitted.state.nfun
nrow = size(FHS, 1)

# retrieve domain of data
dom = domain(data)

# adjust CRS of gₒ
gₒ′ = gₒ |> Proj(crs(dom))

# set RHS of Kriging system
setrhs!(fitted, RHS, dom, gₒ′)

# solve Kriging system
W = FHS \ @view RHS[1:nrow, :]

# split weights and Lagrange multipliers
λ = @view W[begin:nfun, :]
ν = @view W[(nfun+1):end, :]

KrigingWeights(λ, ν)
end

# set RHS of Kriging system
function setrhs!(fitted::FittedKriging, RHS, dom, gₒ)
fun = fitted.model.fun
miss = fitted.state.miss

# set main blocks with pairwise evaluation
GeoStatsFunctions.pairwise!(RHS, fun, dom, [gₒ])
GeoStatsFunctions.pairwise!(RHS, fun, dom, [gₒ])

# adjustments for numerical stability
if isstationary(fun) && !isbanded(fun)
rhsbanded!(RHS, fun, dom)
end

# set blocks of constraints
rhsconstraints!(fitted, gₒ)
rhsconstraints!(fitted, gₒ)

# knock out entries with missing values
rhsmissings!(RHS, miss)

# solve Kriging system
W = LHS \ RHS

# index of first constraint
ind = size(LHS, 1) - ncon + 1

# split weights and Lagrange multipliers
λ = @view W[begin:(ind - 1), :]
ν = @view W[ind:end, :]

KrigingWeights(λ, ν)
nothing
end

# convert RHS into banded matrix
Expand All @@ -325,7 +406,7 @@ function rhsbanded!(RHS, fun, dom)

# retrieve matrix paramaters
nobs = nelements(dom)
nvar = size(S, 1)
nvar = nvariates(fun)
nfun = nobs * nvar

@inbounds for j in 1:nvar, i in 1:nfun
Expand Down
23 changes: 11 additions & 12 deletions src/krig/ordinary.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,17 +14,15 @@ end
nconstraints(model::OrdinaryKriging) = nvariates(model.fun)

function lhsconstraints!(model::OrdinaryKriging, LHS::AbstractMatrix, domain)
# number of variables
nobs = nelements(domain)
nvar = nvariates(model.fun)

# number of constraints
ncon = nconstraints(model)

# retrieve size of LHS
nrow, ncol = size(LHS)
nfun = nobs * nvar
nrow = nfun + ncon
ncol = nrow

# index of first constraint
ind = nrow - ncon + 1
ind = nfun + 1

# set identity blocks
@inbounds for j in 1:(ind - 1), i in ind:nrow
Expand All @@ -44,13 +42,14 @@ end

function rhsconstraints!(fitted::FittedKriging{<:OrdinaryKriging}, gₒ)
RHS = fitted.state.RHS
ncon = fitted.state.ncon

# retrieve size of RHS
nrow, ncol = size(RHS)
nvar = nvariates(fitted.model.fun)
ncon = nconstraints(fitted.model)
nfun = fitted.state.nfun
nrow = nfun + ncon
ncol = nvar

# index of first constraint
ind = nrow - ncon + 1
ind = nfun + 1

# set identity block
@inbounds for j in 1:ncol, i in ind:nrow
Expand Down
Loading
Loading