Skip to content

Commit 65e6df4

Browse files
authored
Refactor fallbacks (#27)
1 parent 37f1e06 commit 65e6df4

File tree

7 files changed

+24
-14
lines changed

7 files changed

+24
-14
lines changed

src/models.jl

+14-2
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,13 @@ Predict one or multiple variables `vars` at geometry `gₒ` with
3333
given geostatistical `model`.
3434
"""
3535
predict(model::FittedGeoStatsModel, var::AbstractString, gₒ) = predict(model, Symbol(var), gₒ)
36-
predict(model::FittedGeoStatsModel, vars, gₒ) = [predict(model, var, gₒ) for var in vars]
36+
function predict(model::FittedGeoStatsModel, vars, gₒ)
37+
if length(vars) > 1
38+
throw(ArgumentError("cannot use univariate model to predict multiple variables"))
39+
else
40+
[predict(model, first(vars), gₒ)]
41+
end
42+
end
3743

3844
"""
3945
predictprob(model, vars, gₒ)
@@ -42,7 +48,13 @@ Predict distribution of one or multiple variables `vars` at
4248
geometry `gₒ` with given geostatistical `model`.
4349
"""
4450
predictprob(model::FittedGeoStatsModel, var::AbstractString, gₒ) = predictprob(model, Symbol(var), gₒ)
45-
predictprob(model::FittedGeoStatsModel, vars, gₒ) = [predictprob(model, var, gₒ) for var in vars]
51+
function predictprob(model::FittedGeoStatsModel, vars, gₒ)
52+
if length(vars) > 1
53+
throw(ArgumentError("cannot use univariate model to predict multiple variables"))
54+
else
55+
[predictprob(model, first(vars), gₒ)]
56+
end
57+
end
4658

4759
"""
4860
status(fitted)

test/idw.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
@test pred isa Composition
2424
end
2525

26-
@testset "Single/Multiple" begin
26+
@testset "Fallbacks" begin
2727
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
2828
idw = GeoStatsModels.fit(IDW(), d)
2929
pred1 = GeoStatsModels.predict(idw, :z, Point(0.0, 0.0))

test/krig.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -395,7 +395,7 @@
395395
@test all((0), var.(dkdist))
396396
end
397397

398-
@testset "Single/Multiple" begin
398+
@testset "Fallbacks" begin
399399
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
400400
γ = GaussianVariogram()
401401
ok = GeoStatsModels.fit(OK(γ), d)

test/lwr.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
@test pred isa Composition
2424
end
2525

26-
@testset "Single/Multiple" begin
26+
@testset "Fallbacks" begin
2727
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
2828
lwr = GeoStatsModels.fit(LWR(), d)
2929
pred1 = GeoStatsModels.predict(lwr, :z, Point(0.0, 0.0))

test/misc.jl

+2-3
Original file line numberDiff line numberDiff line change
@@ -3,10 +3,9 @@
33

44
# fitpredict with IDW
55
pset = PointSet(rand(rng, Point, 3))
6-
gtb = georef((a=[1, 2, 3], b=[4, 5, 6]), pset)
6+
gtb = georef((z=[1, 2, 3],), pset)
77
pred = GeoStatsModels.fitpredict(IDW(), gtb, pset, neighbors=false)
8-
@test pred.a == gtb.a
9-
@test pred.b == gtb.b
8+
@test pred.z == gtb.z
109
@test pred.geometry == gtb.geometry
1110

1211
# also works with views

test/nn.jl

+1-1
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@
3434
@test pred isa Composition
3535
end
3636

37-
@testset "Single/Multiple" begin
37+
@testset "Fallbacks" begin
3838
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
3939
nn = GeoStatsModels.fit(NN(), d)
4040
pred1 = GeoStatsModels.predict(nn, :z, Point(0.0, 0.0))

test/poly.jl

+4-5
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,11 @@
3636

3737
# correct schema
3838
rng = StableRNG(42)
39-
d = georef((a=rand(rng, 10), b=rand(rng, 10)), rand(rng, Point, 10))
39+
d = georef((z=rand(rng, 10),), rand(rng, Point, 10))
4040
= fitpredict(Polynomial(), d)
4141
= values(d̄)
42-
@test propertynames(t̄) == (:a, :b)
43-
@test eltype(t̄.a) == Float64
44-
@test eltype(t̄.b) == Float64
42+
@test propertynames(t̄) == (:z,)
43+
@test eltype(t̄.z) == Float64
4544

4645
# latlon coordinates
4746
d = georef((; z=[1, 2, 3]), Point.([LatLon(0, 0), LatLon(0, 1), LatLon(1, 0)]))
@@ -68,7 +67,7 @@
6867
@test pred isa Composition
6968
end
7069

71-
@testset "Single/Multiple" begin
70+
@testset "Fallbacks" begin
7271
d = georef((; z=[1.0, 0.0, 1.0]), [(0.0, 0.0), (1.0, 0.0), (1.0, 1.0)])
7372
poly = GeoStatsModels.fit(Polynomial(), d)
7473
pred1 = GeoStatsModels.predict(poly, :z, Point(0.0, 0.0))

0 commit comments

Comments
 (0)