Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
39 changes: 39 additions & 0 deletions src/SamplingReduction.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,37 @@ function reduce_samples(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss::
return reduce(vcat, v)
end

function reduce_samples_hypergeometric(ps::AbstractArray, rngs, t::Union{TypeS,TypeUnion}, ss::AbstractArray...)
nt = length(ss)
v = Vector{Vector{get_type_rs(t, ss...)}}(undef, nt)
n = minimum(length.(ss))

# For hypergeometric sampling, we need to sample without replacement from finite populations
# The number of samples from each reservoir depends on hypergeometric distribution
# Total population size is sum of all reservoir sizes
total_pop = sum(length.(ss))

# Sample using hypergeometric distribution for each reservoir
ns = Vector{Int}(undef, nt)
remaining = n
remaining_pop = total_pop

for i in 1:(nt-1)
pop_i = length(ss[i])
# Use hypergeometric distribution: drawing `remaining` items from population `remaining_pop`
# where `pop_i` items are of the type we want
ns[i] = rand(extract_rng(rngs, 1), Hypergeometric(pop_i, remaining_pop - pop_i, remaining))
remaining -= ns[i]
remaining_pop -= pop_i
end
ns[nt] = remaining # Remainder goes to last reservoir

Threads.@threads for i in 1:nt
v[i] = sample(extract_rng(rngs, i), ss[i], ns[i]; replace = false)
end
return reduce(vcat, v)
end

extract_rng(v::AbstractArray, i) = v[i]
extract_rng(v::AbstractRNG, i) = v

Expand All @@ -31,6 +62,14 @@ function get_ps(ss::MultiAlgWRSWRSKIPSampler...)
sum_w = sum(getfield(s, :state) for s in ss)
return [s.state/sum_w for s in ss]
end
function get_ps(ss::MultiAlgRSampler...)
sum_w = sum(getfield(s, :seen_k) for s in ss)
return [s.seen_k/sum_w for s in ss]
end
function get_ps(ss::MultiAlgLSampler...)
sum_w = sum(getfield(s, :seen_k) for s in ss)
return [s.seen_k/sum_w for s in ss]
end

get_type_rs(::TypeS, s1::T, ss::T...) where {T} = eltype(s1)
function get_type_rs(::TypeUnion, s1::T, ss::T...) where {T}
Expand Down
47 changes: 37 additions & 10 deletions src/UnweightedSamplingMulti.jl
Original file line number Diff line number Diff line change
Expand Up @@ -209,11 +209,19 @@ end
is_ordered(s::MultiOrdAlgRSWRSKIPSampler) = true
is_ordered(s::MultiAlgRSWRSKIPSampler) = false

function Base.merge(ss::MultiAlgRSampler...)
error("To Be Implemented")
end
function Base.merge(ss::MultiAlgLSampler...)
error("To Be Implemented")
function Base.merge(ss::MultiAlgRSampler...)
newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...)
seen_k = sum(getfield(s, :seen_k) for s in ss)
n = minimum(s.n for s in ss)
return MultiAlgRSampler_Mut(n, seen_k, ss[1].rng, newvalue, nothing)
end
function Base.merge(ss::MultiAlgLSampler...)
newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...)
seen_k = sum(getfield(s, :seen_k) for s in ss)
# For AlgL, we need to initialize state and skip_k appropriately
# state should be 0.0 for new merged sampler, skip_k should be 0
n = minimum(s.n for s in ss)
return MultiAlgLSampler_Mut(n, 0.0, 0, seen_k, ss[1].rng, newvalue, nothing)
end
function Base.merge(ss::MultiAlgRSWRSKIPSampler...)
newvalue = reduce_samples(get_ps(ss...), [s.rng for s in ss], TypeUnion(), value.(ss)...)
Expand All @@ -223,11 +231,30 @@ function Base.merge(ss::MultiAlgRSWRSKIPSampler...)
return MultiAlgRSWRSKIPSampler_Mut(n, skip_k, seen_k, ss[1].rng, newvalue, nothing)
end

function Base.merge!(ss::MultiAlgRSampler...)
error("To Be Implemented")
end
function Base.merge!(ss::MultiAlgLSampler...)
error("To Be Implemented")
function Base.merge!(ss::MultiAlgRSampler...)
s1 = ss[1]
rest = ss[2:end]
s1.n > minimum(s.n for s in rest) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir")
newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeS(), value(s1), value.(rest)...)
for i in 1:length(newvalue)
@inbounds s1.value[i] = newvalue[i]
end
s1.seen_k += sum(getfield(s, :seen_k) for s in rest)
return s1
end
function Base.merge!(ss::MultiAlgLSampler...)
s1 = ss[1]
rest = ss[2:end]
s1.n > minimum(s.n for s in rest) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir")
newvalue = reduce_samples_hypergeometric(get_ps(ss...), [s.rng for s in ss], TypeS(), value(s1), value.(rest)...)
for i in 1:length(newvalue)
@inbounds s1.value[i] = newvalue[i]
end
s1.seen_k += sum(getfield(s, :seen_k) for s in rest)
# Reset state and skip_k for the merged sampler
s1.state = 0.0
s1.skip_k = 0
return s1
end
function Base.merge!(s1::MultiAlgRSWRSKIPSampler{<:Nothing}, ss::MultiAlgRSWRSKIPSampler...)
s1.n > minimum(s.n for s in ss) && error("The size of the mutated reservoir should be the minimum size between all merged reservoir")
Expand Down
59 changes: 55 additions & 4 deletions test/merge_tests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,18 +15,54 @@
s_all = (s1, s2)
for (s, it) in zip(s_all, iters)
for x in it
m1 == AlgRSWRSKIP() ? fit!(s, x) : fit!(s, x, 1.0)
# Handle unweighted vs weighted algorithms
if m1 == AlgRSWRSKIP()
fit!(s, x)
else
fit!(s, x, 1.0)
end
end
end
s_merged = merge(s1, s2)
res[shuffle!(rng, value(s_merged))...] += 1
end
cases = (m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP()) ? 10^size : factorial(10)/factorial(10-size)
# Adjust expected number of cases for different algorithms
if m1 == AlgRSWRSKIP() || m1 == AlgWRSWRSKIP()
cases = 10^size
else
cases = factorial(10)/factorial(10-size)
end
ps_exact = [1/cases for _ in 1:cases]
count_est = [x for x in vec(res) if x != 0]
chisq_test = ChisqTest(count_est, ps_exact)
@test pvalue(chisq_test) > 0.05
end

# Separate basic tests for AlgR and AlgL (not statistical)
@testset "AlgR and AlgL basic merge tests" begin
for m in (AlgR(), AlgL())
s1 = ReservoirSampler{Int}(rng, size, m)
s2 = ReservoirSampler{Int}(rng, size, m)

# Add some data
for x in 1:2; fit!(s1, x); end
for x in 3:4; fit!(s2, x); end

# Test that merge works
merged = merge(s1, s2)
@test merged isa Union{StreamSampling.MultiAlgRSampler_Mut, StreamSampling.MultiAlgLSampler_Mut}
@test merged.n == size

# Test that merge! works
s3 = ReservoirSampler{Int}(rng, size, m)
s4 = ReservoirSampler{Int}(rng, size, m)
for x in 5:6; fit!(s3, x); end
for x in 7:8; fit!(s4, x); end

result = merge!(s3, s4)
@test result === s3
end
end
s1 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP())
s2 = ReservoirSampler{Int}(rng, 2, AlgRSWRSKIP())
s_all = (s1, s2)
Expand All @@ -39,8 +75,23 @@
for m in (AlgRSWRSKIP(), AlgWRSWRSKIP())
s1 = ReservoirSampler{Int}(rng, m)
s2 = ReservoirSampler{Int}(rng, m)
m == AlgRSWRSKIP() ? fit!(s1, 1) : fit!(s1, 1, 1.0)
m == AlgRSWRSKIP() ? fit!(s2, 2) : fit!(s2, 2, 1.0)
if m == AlgRSWRSKIP()
fit!(s1, 1)
fit!(s2, 2)
else
fit!(s1, 1, 1.0)
fit!(s2, 2, 1.0)
end
@test value(merge!(s1, s2)) in (1, 2)
end

# Test merge! for multi-element unweighted samplers (AlgR and AlgL)
for m in (AlgR(), AlgL())
s1 = ReservoirSampler{Int}(rng, 1, m) # Single element reservoir
s2 = ReservoirSampler{Int}(rng, 1, m)
fit!(s1, 1)
fit!(s2, 2)
result = value(merge!(s1, s2))
@test length(result) == 1 && result[1] in (1, 2)
end
end
Loading