Skip to content

Commit 37f1e06

Browse files
authored
Make sure data is compatible with Kriging model (#26)
* Make sure data is compatible with Kriging model * Refactor predict e predictprob for Kriging
1 parent e54eb11 commit 37f1e06

File tree

3 files changed

+17
-51
lines changed

3 files changed

+17
-51
lines changed

src/krig.jl

+14-21
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,12 @@ function initkrig(model::KrigingModel, data)
7474
ncon = nconstraints(model, nvar)
7575
nrow = nobs * nvar + ncon
7676

77+
# make sure data is compatible with model
78+
nfeat = ncol(data) - 1
79+
if nfeat != nvar
80+
throw(ArgumentError("$nfeat data column(s) provided to $nvar-variate Kriging model"))
81+
end
82+
7783
# pre-allocate memory for LHS
7884
LHS = Matrix{V}(undef, nrow, nrow)
7985

@@ -168,10 +174,14 @@ end
168174

169175
predict(fitted::FittedKriging, var::AbstractString, gₒ) = predict(fitted, Symbol(var), gₒ)
170176

177+
predict(fitted::FittedKriging, var::Symbol, gₒ) = predict(fitted, (var,), gₒ) |> first
178+
171179
predict(fitted::FittedKriging, vars, gₒ) = predictmean(fitted, weights(fitted, gₒ), vars)
172180

173181
predictprob(fitted::FittedKriging, var::AbstractString, gₒ) = predictprob(fitted, Symbol(var), gₒ)
174182

183+
predictprob(fitted::FittedKriging, var::Symbol, gₒ) = predictprob(fitted, (var,), gₒ) |> first
184+
175185
function predictprob(fitted::FittedKriging, vars, gₒ)
176186
w = weights(fitted, gₒ)
177187
μ = predictmean(fitted, w, vars)
@@ -182,29 +192,15 @@ end
182192

183193
predictmean(fitted::FittedKriging, weights::KrigingWeights, vars) = krigmean(fitted, weights, vars)
184194

185-
predictmean(fitted::FittedKriging, weights::KrigingWeights, var::Symbol) = first(predictmean(fitted, weights, (var,)))
186-
187195
function krigmean(fitted::FittedKriging, weights::KrigingWeights, vars)
188196
d = fitted.state.data
189197
λ = weights.λ
190198
k = size(λ, 2)
191-
n = length(vars)
192-
193-
@assert (k == n || k == 1) "invalid number of variables for Kriging model"
194199

195200
cols = Tables.columns(values(d))
196-
197-
if k == n
198-
@inbounds map(1:k) do j
199-
sum(1:n) do p
200-
λₚ = @view λ[p:k:end, j]
201-
zₚ = Tables.getcolumn(cols, vars[p])
202-
sum(i -> λₚ[i] ⦿ zₚ[i], eachindex(λₚ, zₚ))
203-
end
204-
end
205-
else # k == 1
206-
@inbounds map(1:n) do p
207-
λₚ = @view λ[:, 1]
201+
@inbounds map(1:k) do j
202+
sum(1:k) do p
203+
λₚ = @view λ[p:k:end, j]
208204
zₚ = Tables.getcolumn(cols, vars[p])
209205
sum(i -> λₚ[i] ⦿ zₚ[i], eachindex(λₚ, zₚ))
210206
end
@@ -223,10 +219,7 @@ function predictvar(fitted::FittedKriging, weights::KrigingWeights, gₒ)
223219
σ² = krigvar(fun, weights, RHS, gₒ)
224220

225221
# treat numerical issues
226-
σ²₊ = max.(zero(σ²), σ²)
227-
228-
# treat scalar case
229-
length(σ²₊) == 1 ? first(σ²₊) : σ²₊
222+
max.(zero(σ²), σ²)
230223
end
231224

232225
krigvar(fun::Variogram, weights::KrigingWeights, RHS, gₒ) = covvar(fun, weights, RHS, gₒ)

src/krig/simple.jl

+3-15
Original file line numberDiff line numberDiff line change
@@ -37,23 +37,11 @@ function krigmean(fitted::FittedKriging{<:SimpleKriging}, weights::KrigingWeight
3737
μ = fitted.model.mean
3838
λ = weights.λ
3939
k = size(λ, 2)
40-
n = length(vars)
41-
42-
@assert (k == n || k == 1) "invalid number of variables for Kriging model"
4340

4441
cols = Tables.columns(values(d))
45-
46-
if k == n
47-
@inbounds map(1:k) do j
48-
sum(1:n) do p
49-
λₚ = @view λ[p:k:end, j]
50-
zₚ = Tables.getcolumn(cols, vars[p])
51-
μ[p] + sum(i -> λₚ[i] ⦿ (zₚ[i] - μ[p]), eachindex(λₚ, zₚ))
52-
end
53-
end
54-
else # k == 1
55-
@inbounds map(1:n) do p
56-
λₚ = @view λ[:, 1]
42+
@inbounds map(1:k) do j
43+
sum(1:k) do p
44+
λₚ = @view λ[p:k:end, j]
5745
zₚ = Tables.getcolumn(cols, vars[p])
5846
μ[p] + sum(i -> λₚ[i] ⦿ (zₚ[i] - μ[p]), eachindex(λₚ, zₚ))
5947
end

test/misc.jl

-15
Original file line numberDiff line numberDiff line change
@@ -14,21 +14,6 @@
1414
vpred = GeoStatsModels.fitpredict(IDW(), vgtb, pset, neighbors=false)
1515
@test vpred == pred
1616

17-
# fitpredict with multiple variables and Kriging
18-
gtb = georef((; a=[1.0, 0.0, 0.0], b=[0.0, 1.0, 0.0], c=[0.0, 0.0, 1.0]), [(25.0, 25.0), (50.0, 75.0), (75.0, 50.0)])
19-
grid = CartesianGrid((100, 100), (0.5, 0.5), (1.0, 1.0))
20-
pred = GeoStatsModels.fitpredict(Kriging(SphericalVariogram(range=35.0)), gtb, grid, maxneighbors=3)
21-
inds = LinearIndices(size(grid))
22-
@test isapprox(pred.a[inds[25, 25]], 1.0, atol=1e-3)
23-
@test isapprox(pred.a[inds[50, 75]], 0.0, atol=1e-3)
24-
@test isapprox(pred.a[inds[75, 50]], 0.0, atol=1e-3)
25-
@test isapprox(pred.b[inds[25, 25]], 0.0, atol=1e-3)
26-
@test isapprox(pred.b[inds[50, 75]], 1.0, atol=1e-3)
27-
@test isapprox(pred.b[inds[75, 50]], 0.0, atol=1e-3)
28-
@test isapprox(pred.c[inds[25, 25]], 0.0, atol=1e-3)
29-
@test isapprox(pred.c[inds[50, 75]], 0.0, atol=1e-3)
30-
@test isapprox(pred.c[inds[75, 50]], 1.0, atol=1e-3)
31-
3217
# fitpredict with multiple variables and CoKriging
3318
gtb = georef((; a=[1.0, 0.0, 0.0], b=[0.0, 1.0, 0.0], c=[0.0, 0.0, 1.0]), [(25.0, 25.0), (50.0, 75.0), (75.0, 50.0)])
3419
grid = CartesianGrid((100, 100), (0.5, 0.5), (1.0, 1.0))

0 commit comments

Comments
 (0)