Skip to content

Commit 226dedd

Browse files
committed
Return MvNormal in multivariate Kriging
1 parent e3924e2 commit 226dedd

File tree

4 files changed

+62
-46
lines changed

4 files changed

+62
-46
lines changed

src/krig.jl

+15-8
Original file line numberDiff line numberDiff line change
@@ -178,14 +178,20 @@ predict(fitted::FittedKriging, vars, gₒ) = predictmean(fitted, weights(fitted,
178178

179179
predictprob(fitted::FittedKriging, var::AbstractString, gₒ) = predictprob(fitted, Symbol(var), gₒ)
180180

181-
predictprob(fitted::FittedKriging, var::Symbol, gₒ) = predictprob(fitted, (var,), gₒ) |> first
181+
function predictprob(fitted::FittedKriging, var::Symbol, gₒ)
182+
w = weights(fitted, gₒ)
183+
μ = predictmean(fitted, w, (var,)) |> first
184+
σ² = predictvar(fitted, w, gₒ) |> first
185+
# https://github.com/JuliaStats/Distributions.jl/issues/1413
186+
Normal(ustrip.(μ), σ²)
187+
end
182188

183189
function predictprob(fitted::FittedKriging, vars, gₒ)
184190
w = weights(fitted, gₒ)
185191
μ = predictmean(fitted, w, vars)
186-
σ² = predictvar(fitted, w, gₒ)
192+
Σ = predictvar(fitted, w, gₒ)
187193
# https://github.com/JuliaStats/Distributions.jl/issues/1413
188-
@. Normal(ustrip(μ), σ²)
194+
MvNormal(ustrip.(μ), Σ)
189195
end
190196

191197
predictmean(fitted::FittedKriging, weights::KrigingWeights, vars) = krigmean(fitted, weights, vars)
@@ -213,11 +219,12 @@ function predictvar(fitted::FittedKriging, weights::KrigingWeights, gₒ)
213219
RHS = fitted.state.RHS
214220
fun = fitted.model.fun
215221

216-
# variance formula for given function
217-
σ² = krigvar(fun, weights, RHS, gₒ)
222+
# covariance formula for given function
223+
Σ = krigvar(fun, weights, RHS, gₒ)
218224

219225
# treat numerical issues
220-
max.(zero(σ²), σ²)
226+
ϵ = eltype(Σ)(1e-10)
227+
Symmetric+ ϵ * I)
221228
end
222229

223230
function krigvar(fun::GeoStatsFunction, weights::KrigingWeights, RHS, gₒ)
@@ -230,7 +237,7 @@ function krigvar(fun::GeoStatsFunction, weights::KrigingWeights, RHS, gₒ)
230237
# compute cov(0) considering change of support
231238
Cₒ = ustrip.(covzero(fun, gₒ)) * I(k)
232239

233-
diag(Cₒ) - diag(Cλ) - diag(Cν)
240+
Cₒ - -
234241
end
235242

236243
function krigvar(t::Transiogram, weights::KrigingWeights, RHS, gₒ)
@@ -252,7 +259,7 @@ function krigvar(t::Transiogram, weights::KrigingWeights, RHS, gₒ)
252259
Tₒ = t(gₒ, gₒ)
253260
Cₒ = @inbounds [p[i] * (Tₒ[i, j] - p[j]) for i in 1:k, j in 1:k]
254261

255-
diag(Cₒ) - diag(Cλ) - diag(Cν)
262+
Cₒ - -
256263
end
257264

258265
function covzero(fun::GeoStatsFunction, gₒ)

src/models.jl

+5-2
Original file line numberDiff line numberDiff line change
@@ -133,7 +133,7 @@ function fitpredictneigh(model, gtb, dom, path, point, prob, minneighbors, maxne
133133
inds = traverse(dom, path)
134134

135135
# prediction function
136-
predfun = prob ? predictprob : predict
136+
predfun = prob ? _marginals predictprob : predict
137137

138138
# predict variables
139139
cols = Tables.columns(values(gtb))
@@ -177,7 +177,7 @@ function fitpredictfull(model, gtb, dom, path, point, prob)
177177
inds = traverse(dom, path)
178178

179179
# prediction function
180-
predfun = prob ? predictprob : predict
180+
predfun = prob ? _marginals predictprob : predict
181181

182182
# fit model to data
183183
fmodel = fit(model, gtb)
@@ -197,6 +197,9 @@ function fitpredictfull(model, gtb, dom, path, point, prob)
197197
georef(predtab, dom)
198198
end
199199

200+
_marginals(dist::UnivariateDistribution) = (dist,)
201+
_marginals(dist::MvNormal) = Normal.(mean(dist), var(dist))
202+
200203
# ----------------
201204
# IMPLEMENTATIONS
202205
# ----------------

test/krig.jl

+32-36
Original file line numberDiff line numberDiff line change
@@ -331,14 +331,14 @@
331331
okdist = GeoStatsModels.predictprob(ok, (:a, :b, :c), pset[i])
332332
ukdist = GeoStatsModels.predictprob(uk, (:a, :b, :c), pset[i])
333333
dkdist = GeoStatsModels.predictprob(dk, (:a, :b, :c), pset[i])
334-
@test mean.(skdist) [j == i for j in 1:3]
335-
@test mean.(okdist) [j == i for j in 1:3]
336-
@test mean.(ukdist) [j == i for j in 1:3]
337-
@test mean.(dkdist) [j == i for j in 1:3]
338-
@test isapprox(var.(skdist), [0.0, 0.0, 0.0], atol=1e-10)
339-
@test isapprox(var.(okdist), [0.0, 0.0, 0.0], atol=1e-10)
340-
@test isapprox(var.(ukdist), [0.0, 0.0, 0.0], atol=1e-10)
341-
@test isapprox(var.(dkdist), [0.0, 0.0, 0.0], atol=1e-10)
334+
@test mean(skdist) [j == i for j in 1:3]
335+
@test mean(okdist) [j == i for j in 1:3]
336+
@test mean(ukdist) [j == i for j in 1:3]
337+
@test mean(dkdist) [j == i for j in 1:3]
338+
@test isapprox(var(skdist), [0.0, 0.0, 0.0], atol=1e-8)
339+
@test isapprox(var(okdist), [0.0, 0.0, 0.0], atol=1e-8)
340+
@test isapprox(var(ukdist), [0.0, 0.0, 0.0], atol=1e-8)
341+
@test isapprox(var(dkdist), [0.0, 0.0, 0.0], atol=1e-8)
342342
end
343343

344344
# predict on a specific point
@@ -347,14 +347,14 @@
347347
okdist = GeoStatsModels.predictprob(ok, (:a, :b, :c), pₒ)
348348
ukdist = GeoStatsModels.predictprob(uk, (:a, :b, :c), pₒ)
349349
dkdist = GeoStatsModels.predictprob(dk, (:a, :b, :c), pₒ)
350-
@test all-> 0 μ 1, mean.(skdist))
351-
@test all-> 0 μ 1, mean.(okdist))
352-
@test all-> 0 μ 1, mean.(ukdist))
353-
@test all-> 0 μ 1, mean.(dkdist))
354-
@test all((0), var.(skdist))
355-
@test all((0), var.(okdist))
356-
@test all((0), var.(ukdist))
357-
@test all((0), var.(dkdist))
350+
@test all-> 0 μ 1, mean(skdist))
351+
@test all-> 0 μ 1, mean(okdist))
352+
@test all-> 0 μ 1, mean(ukdist))
353+
@test all-> 0 μ 1, mean(dkdist))
354+
@test all((0), var(skdist))
355+
@test all((0), var(okdist))
356+
@test all((0), var(ukdist))
357+
@test all((0), var(dkdist))
358358
end
359359

360360
@testset "Transiogram" begin
@@ -369,30 +369,26 @@
369369

370370
# interpolation property
371371
for i in 1:3
372-
skdist = GeoStatsModels.predictprob(sk, (:a, :b, :c), pset[i])
373-
okdist = GeoStatsModels.predictprob(ok, (:a, :b, :c), pset[i])
374-
ukdist = GeoStatsModels.predictprob(uk, (:a, :b, :c), pset[i])
375-
dkdist = GeoStatsModels.predictprob(dk, (:a, :b, :c), pset[i])
376-
@test mean.(skdist) [j == i for j in 1:3]
377-
@test mean.(okdist) [j == i for j in 1:3]
378-
@test mean.(ukdist) [j == i for j in 1:3]
379-
@test mean.(dkdist) [j == i for j in 1:3]
372+
skmean = GeoStatsModels.predict(sk, (:a, :b, :c), pset[i])
373+
okmean = GeoStatsModels.predict(ok, (:a, :b, :c), pset[i])
374+
ukmean = GeoStatsModels.predict(uk, (:a, :b, :c), pset[i])
375+
dkmean = GeoStatsModels.predict(dk, (:a, :b, :c), pset[i])
376+
@test skmean [j == i for j in 1:3]
377+
@test okmean [j == i for j in 1:3]
378+
@test ukmean [j == i for j in 1:3]
379+
@test dkmean [j == i for j in 1:3]
380380
end
381381

382382
# predict on a specific point
383383
pₒ = Point(50.0, 50.0)
384-
skdist = GeoStatsModels.predictprob(sk, (:a, :b, :c), pₒ)
385-
okdist = GeoStatsModels.predictprob(ok, (:a, :b, :c), pₒ)
386-
ukdist = GeoStatsModels.predictprob(uk, (:a, :b, :c), pₒ)
387-
dkdist = GeoStatsModels.predictprob(dk, (:a, :b, :c), pₒ)
388-
@test all-> 0 μ 1, mean.(skdist))
389-
@test all-> 0 μ 1, mean.(okdist))
390-
@test all-> 0 μ 1, mean.(ukdist))
391-
@test all-> 0 μ 1, mean.(dkdist))
392-
@test all((0), var.(skdist))
393-
@test all((0), var.(okdist))
394-
@test all((0), var.(ukdist))
395-
@test all((0), var.(dkdist))
384+
skmean = GeoStatsModels.predict(sk, (:a, :b, :c), pₒ)
385+
okmean = GeoStatsModels.predict(ok, (:a, :b, :c), pₒ)
386+
ukmean = GeoStatsModels.predict(uk, (:a, :b, :c), pₒ)
387+
dkmean = GeoStatsModels.predict(dk, (:a, :b, :c), pₒ)
388+
@test all-> 0 μ 1, skmean)
389+
@test all-> 0 μ 1, okmean)
390+
@test all-> 0 μ 1, ukmean)
391+
@test all-> 0 μ 1, dkmean)
396392
end
397393

398394
@testset "Fallbacks" begin

test/misc.jl

+10
Original file line numberDiff line numberDiff line change
@@ -28,4 +28,14 @@
2828
@test isapprox(pred.c[inds[25, 25]], 0.0, atol=1e-3)
2929
@test isapprox(pred.c[inds[50, 75]], 0.0, atol=1e-3)
3030
@test isapprox(pred.c[inds[75, 50]], 1.0, atol=1e-3)
31+
pred = GeoStatsModels.fitpredict(model, gtb, grid, maxneighbors=3, prob=true)
32+
@test isapprox(mean(pred.a[inds[25, 25]]), 1.0, atol=1e-3)
33+
@test isapprox(mean(pred.a[inds[50, 75]]), 0.0, atol=1e-3)
34+
@test isapprox(mean(pred.a[inds[75, 50]]), 0.0, atol=1e-3)
35+
@test isapprox(mean(pred.b[inds[25, 25]]), 0.0, atol=1e-3)
36+
@test isapprox(mean(pred.b[inds[50, 75]]), 1.0, atol=1e-3)
37+
@test isapprox(mean(pred.b[inds[75, 50]]), 0.0, atol=1e-3)
38+
@test isapprox(mean(pred.c[inds[25, 25]]), 0.0, atol=1e-3)
39+
@test isapprox(mean(pred.c[inds[50, 75]]), 0.0, atol=1e-3)
40+
@test isapprox(mean(pred.c[inds[75, 50]]), 1.0, atol=1e-3)
3141
end

0 commit comments

Comments
 (0)