Skip to content

Commit 179a3ce

Browse files
committed
Use fit! in fitpredict
1 parent f8f5205 commit 179a3ce

File tree

1 file changed

+12
-7
lines changed

1 file changed

+12
-7
lines changed

src/models.jl

+12-7
Original file line numberDiff line numberDiff line change
@@ -127,7 +127,12 @@ function fitpredictneigh(model, gtb, dom, path, point, prob, minneighbors, maxne
127127
end
128128

129129
# pre-allocate memory for neighbors
130-
neighbors = Vector{Int}(undef, maxneighbors)
130+
neighbors = collect(1:maxneighbors)
131+
132+
# initialize fitted model with first samples
133+
ninds = view(neighbors, 1:maxneighbors)
134+
ndata = view(gtb, ninds)
135+
fmodel = fit(model, ndata)
131136

132137
# traverse domain with given path
133138
inds = traverse(dom, path)
@@ -143,18 +148,18 @@ function fitpredictneigh(model, gtb, dom, path, point, prob, minneighbors, maxne
143148
center = centroid(dom, ind)
144149

145150
# find neighbors with data
146-
nneigh = search!(neighbors, center, searcher)
151+
n = search!(neighbors, center, searcher)
147152

148153
# predict if enough neighbors
149-
if nneigh minneighbors
154+
if n minneighbors
150155
# final set of neighbors
151-
ninds = view(neighbors, 1:nneigh)
156+
ninds = view(neighbors, 1:n)
152157

153158
# view neighborhood with data
154-
samples = view(gtb, ninds)
159+
ndata = view(gtb, ninds)
155160

156-
# fit model to samples
157-
fmodel = fit(model, samples)
161+
# fit model to neighborhood
162+
fit!(fmodel, ndata)
158163

159164
# save prediction
160165
geom = point ? center : dom[ind]

0 commit comments

Comments
 (0)