From e586c5b2268bc6c3cd9057f8f048ef69b07b5280 Mon Sep 17 00:00:00 2001
From: ngiann
Date: Thu, 27 Apr 2023 20:13:02 +0200
Subject: [PATCH] add mean field

---
 src/GaussianVariationalInference.jl |   6 +-
 src/VIdiag.jl                       | 161 +++++++++++++++++++++++++
 src/{coreVIrank1.jl => VIrank1.jl}  |  57 ++++-----
 src/interface.jl                    | 177 ++++++++++------------------
 4 files changed, 250 insertions(+), 151 deletions(-)
 create mode 100644 src/VIdiag.jl
 rename src/{coreVIrank1.jl => VIrank1.jl} (73%)

diff --git a/src/GaussianVariationalInference.jl b/src/GaussianVariationalInference.jl
index c2307ca..8407e5a 100644
--- a/src/GaussianVariationalInference.jl
+++ b/src/GaussianVariationalInference.jl
@@ -15,7 +15,8 @@ module GaussianVariationalInference
     include("interface.jl")
     include("VIfull.jl")
-    include("coreVIrank1.jl")
+    include("VIdiag.jl")
+    include("VIrank1.jl")
     include("entropy.jl")

     # include("VIdiag.jl")
@@ -27,6 +28,7 @@ module GaussianVariationalInference
     # Utilities

     # include("util/report.jl")
+    include("util/pickoptimiser.jl")
     include("util/generatelatentZ.jl")
     include("util/defaultgradient.jl")
     include("util/verifygradient.jl")
@@ -42,7 +44,7 @@ module GaussianVariationalInference


-    export VI, VIrank1 #, VIdiag, VIfixedcov, MVI, laplace
+    export VI, VIdiag, VIrank1 #, VIdiag, VIfixedcov, MVI, laplace

     export exampleproblem1

diff --git a/src/VIdiag.jl b/src/VIdiag.jl
new file mode 100644
index 0000000..7259f23
--- /dev/null
+++ b/src/VIdiag.jl
@@ -0,0 +1,161 @@
+function coreVIdiag(logp::Function, μ₀::AbstractArray{T, 1}, Σ₀diag::AbstractArray{T, 1}; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every) where T
+
+    D = length(μ₀)
+
+    #----------------------------------------------------
+    # generate latent variables
+    #----------------------------------------------------
+
+    Ztrain = generatelatentZ(S = S, D = D, seed = seed)
+
+
+    #----------------------------------------------------
+    # Auxiliary function for handling parameters
+    #----------------------------------------------------
+
+    function unpack(param)
+
+        @assert(length(param) == D+D)
+
+        local μ = param[1:D]
+
+        local Cdiag = reshape(param[D+1:D+D], D)
+
+        return μ, Cdiag
+
+    end
+
+
+    #----------------------------------------------------
+    # Objective and gradient functions for Optim.optimize
+    #----------------------------------------------------
+
+    function minauxiliary(param)
+
+        local μ, Cdiag = unpack(param)
+
+        local ℓ = elbo(μ, Cdiag, Ztrain)
+
+        update!(trackELBO; newelbo = ℓ, μ = μ, C = Cdiag)
+
+        return -1.0 * ℓ # Optim.optimize is minimising
+
+    end
+
+
+    function minauxiliary_grad(param)
+
+        local μ, Cdiag = unpack(param)
+
+        return -1.0 * elbo_grad(μ, Cdiag, Ztrain) # Optim.optimize is minimising
+
+    end
+
+
+    #----------------------------------------------------
+    # Functions for covariance and covariance root
+    #----------------------------------------------------
+
+
+    function getcov(Cdiag)
+
+        Diagonal(Cdiag.^2)
+
+    end
+
+
+    function getcovroot(Cdiag)
+
+        return Cdiag
+
+    end
+
+
+    #----------------------------------------------------
+    # Approximate evidence lower bound and its gradient
+    #----------------------------------------------------
+
+    function elbo(μ, Cdiag, Z)
+
+        local aux = z -> logp(μ .+ Cdiag.*z)
+
+        Transducers.foldxt(+, Map(aux), Z) / length(Z) + GaussianVariationalInference.entropy(Cdiag)
+
+    end
+
+
+    function partial_elbo_grad(μ, Cdiag, z)
+
+        local g = gradlogp(μ .+ Cdiag.*z)
+
+        [g; vec(g.*z)]
+
+    end
+
+
+    function elbo_grad(μ, Cdiag, Z)
+
+        local aux = z -> partial_elbo_grad(μ, Cdiag, z)
+
+        local gradμCdiag = Transducers.foldxt(+, Map(aux), Z) / length(Z)
+
+        # entropy contribution to covariance
+
+        gradμCdiag[D+1:end] .+= vec(1.0 ./ Cdiag)
+
+        return gradμCdiag
+
+    end
+
+
+    # Package Optim requires that the function for the gradient has the following signature
+
+    gradhelper(storage, param) = copyto!(storage, minauxiliary_grad(param))
+
+
+    #----------------------------------------------------
+    # Numerically verify gradient
+    #----------------------------------------------------
+
+    numerical_verification ? verifygradient(μ₀, Σ₀diag, elbo, minauxiliary_grad, unpack, Ztrain) : nothing
+
+
+    #----------------------------------------------------
+    # Define callback function called at each iteration
+    #----------------------------------------------------
+
+    # We want to keep track of the best variational
+    # parameters encountered during the optimisation of
+    # the elbo. Unfortunately, the otherwise superb
+    # package Optim.jl does not provide a consistent way
+    # across different optimisers to do this.
+
+
+    trackELBO = RecordELBOProgress(; μ = zeros(D), C = zeros(D),
+                                     Stest = Stest,
+                                     show_every = show_every,
+                                     test_every = test_every,
+                                     elbo = elbo, seed = seed)
+
+
+
+    #----------------------------------------------------
+    # Call optimiser to minimise *negative* elbo
+    #----------------------------------------------------
+
+    options = Optim.Options(extended_trace = false, store_trace = false, show_every = 1, show_trace = false, iterations = iterations, g_tol = 1e-6, callback = trackELBO)
+
+    result = Optim.optimize(minauxiliary, gradhelper, [μ₀; vec(sqrt.(Σ₀diag))], optimiser, options)
+
+    μopt, Copt = unpack(result.minimizer)
+
+
+    #----------------------------------------------------
+    # Return results
+    #----------------------------------------------------
+
+    Σopt = getcov(Copt)
+
+    return MvNormal(μopt, Σopt), elbo(μopt, Copt, Ztrain), Copt
+
+end

diff --git a/src/coreVIrank1.jl b/src/VIrank1.jl
similarity index 73%
rename from src/coreVIrank1.jl
rename to src/VIrank1.jl
index 5d9ad10..fdac4bc 100644
--- a/src/coreVIrank1.jl
+++ b/src/VIrank1.jl
@@ -1,6 +1,9 @@
-function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractArray{T, 2}; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, transform = transform, seedtest = seedtest) where T
+function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C::AbstractArray{T, 2}; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, transform = transform, seedtest = seedtest) where T
+
+    D = length(μ₀)
+
+    rg = MersenneTwister(seed)

-    D = length(μ₀); @assert(D == size(C₀, 1) == size(C₀, 2))

     #----------------------------------------------------
     # generate latent variables
@@ -13,7 +16,7 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
     # Define jacobian of transformation via AD
     #----------------------------------------------------

-    jac_transform = transform == identity ? Matrix(I, D, D) : x -> ForwardDiff.jacobian(transform, x)
+    # jac_transform = transform == identity ? Matrix(I, D, D) : x -> ForwardDiff.jacobian(transform, x)

     #----------------------------------------------------

@@ -41,11 +44,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA

         local μ, u, v = unpack(param)

-        local C = getcovroot(C₀, u, v)
+        local ℓ = elbo(μ, u, v, Ztrain)

-        local ℓ = elbo(μ, C, Ztrain)
-
-        update!(trackELBO; newelbo = ℓ, μ = μ, C = C)
+        update!(trackELBO; newelbo = ℓ, μ = μ, C = getcovroot(u, v))

         return -1.0 * ℓ # Optim.optimise is minimising

@@ -65,9 +66,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
     # Functions for covariance and covariance root
     #----------------------------------------------------

-    function getcov(C₀, u, v)
+    function getcov(u, v)

-        local aux = getcovroot(C₀, u, v)
+        local aux = getcovroot(u, v)

         local Σ = aux*aux'

@@ -76,9 +77,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
     end

-    function getcovroot(C₀, u, v)
+    function getcovroot(u, v)

-        C₀ + u*v'
+        C + u*v'

     end

@@ -87,9 +88,10 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
     # Approximate evidence lower bound and its gradient
     #----------------------------------------------------

-    function elbo(μ, C, Z)
+    function elbo(μ, u, v, Z)
+
+        local C = getcovroot(u, v)

-        local ℋ = GaussianVariationalInference.entropy(C)

         # if transform !== identity
@@ -126,7 +128,7 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA

     function elbo_grad(μ, u, v, Z)

-        local C = getcovroot(C₀, u, v)
+        local C = getcovroot(u, v)

         local aux = z -> partial_elbo_grad(μ, C, u, v, z)

@@ -152,26 +154,9 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
     # Numerically verify gradient
     #----------------------------------------------------

-    # COMMENT BACK IN AFTER VERIFICATION
-    #numerical_verification ? verifygradient(μ₀, Σ₀, elbo, minauxiliary_grad, unpack, Ztrain) : nothing
-
-    # DELETE AFTER VERIFICATION
-    # let
-
-    #     local u,v = randn(D), randn(D)
+    numerical_verification ? verifygradient(μ₀, 1e-2*randn(rg, D), 1e-2*randn(rg, D), elbo, minauxiliary_grad, unpack, Ztrain) : nothing
-
-    #     local angrad = minauxiliary_grad([μ₀;vec(u);vec(v)])
-
-    #     adgrad = ForwardDiff.gradient(minauxiliary, [μ₀; vec(u);vec(v)])
-
-    #     discrepancy = maximum(abs.(vec(adgrad) - vec(angrad)))
-    #     display([angrad adgrad])
-
-    #     @printf("Maximum absolute difference between AD and analytical gradient is %f\n", discrepancy)
-
-    # end
-
     #----------------------------------------------------
     # Define callback function called at each iteration
     #----------------------------------------------------
@@ -196,7 +181,7 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA

     options = Optim.Options(extended_trace = false, store_trace = false, show_trace = false, iterations = iterations, g_tol = 1e-6, callback = trackELBO)

-    result = Optim.optimize(minauxiliary, gradhelper, [μ₀; 1e-2*randn(2D)], optimiser, options)
+    result = Optim.optimize(minauxiliary, gradhelper, [μ₀; 1e-2*randn(rg, 2D)], optimiser, options)

     μopt, uopt, vopt = unpack(result.minimizer)

@@ -205,10 +190,8 @@ function coreVIrank1(logp::Function, μ₀::AbstractArray{T, 1}, C₀::AbstractA
     # Return results
     #----------------------------------------------------

-    Copt = getcovroot(C₀, uopt, vopt)
-
-    # Σopt = getcov(C₀, uopt, vopt)
+    Copt = getcovroot(uopt, vopt)

-    return μopt, Copt, elbo(μopt, Copt, Ztrain)
+    return MvNormal(μopt, getcov(uopt, vopt)), elbo(μopt, uopt, vopt, Ztrain), Copt

 end

diff --git a/src/interface.jl b/src/interface.jl
index 76c88a2..8d5320a 100644
--- a/src/interface.jl
+++ b/src/interface.jl
@@ -1,7 +1,7 @@
 """
 # Basic use:

-    q, logev = VI(logp, μ, σ²=0.1; S = 100, iterations = 1, show_every = -1)
+    q, logev, Croot = VI(logp, μ, σ²=0.1; S = 100, iterations = 1, show_every = -1)

 Returns approximate Gaussian posterior and log evidence.

@@ -19,8 +19,9 @@ A description of only the most basic arguments follows.

 # Outputs

-- `q` is the approximating posterior returned as a ```Distributions.MvNormal``` type
+- `q` is the approximating posterior returned as a ```Distributions.MvNormal``` type.
 - `logev` is the approximate log-evidence.
+- `Croot` is the matrix root of the posterior covariance.
 # Example

 ```
 julia> using LinearAlgebra, Distributions
 julia> D = 4; X = randn(D, 1000); W = randn(D); β = 0.3; α = 1.0;
 julia> Y = vec(W'*X); Y += randn(size(Y))/sqrt(β);
 julia> Sn = inv(α*I + β*(X*X')) ; mn = β*Sn*X*Y; # exact posterior
-julia> posterior, logev = VI( w -> logpdf(MvNormal(vec(w'*X), sqrt(1/β)), Y) + logpdf(MvNormal(zeros(D),sqrt(1/α)), w), randn(D); S = 1_000, iterations = 15);
+julia> posterior, logev, = VI( w -> logpdf(MvNormal(vec(w'*X), sqrt(1/β)), Y) + logpdf(MvNormal(zeros(D),sqrt(1/α)), w), randn(D); S = 1_000, iterations = 15);
 julia> display([mean(posterior) mn])
 julia> display([cov(posterior) Sn])
 julia> display(logev) # display negative log evidence
 ```
 """
-function VI(logp::Function, μ::AbstractVector, Σ::AbstractMatrix; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)
+function VI(logp::Function, μ::Vector, Σ::Matrix; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)

     # check validity of arguments

-    @argcheck seed > 0
+    checkcommonarguments(seed, iterations, S, Stest, μ)

-    @argcheck iterations > 0
-
-    @argcheck S > 0
-
-    @argcheck size(Σ, 1) == size(Σ, 2) "Σ must be a square matrix"
+    @argcheck size(Σ, 1) == size(Σ, 2) "Σ must be a square matrix"

     @argcheck length(μ) == size(Σ, 1) == size(Σ, 2) "dimensions of μ do not agree with dimensions of Σ"

-    @argcheck isposdef(Σ) "Σ must be positive definite"
-
-    @argcheck length(μ) >= 2 "VI works only for problems with two parameters and more"
+    @argcheck isposdef(Σ) "Σ must be positive definite"
+
+    # pick optimiser and (re)define gradient of logp

-    # check gradient arguments
+    optimiser, gradlogp = pickoptimiser(μ, logp, gradlogp, gradientmode)

-    optimiser = NelderMead() # default optimiser
+    # Call actual algorithm

-    if gradientmode == :forward
-
-        gradlogp = x -> ForwardDiff.gradient(logp, x)
-
-        optimiser = LBFGS() # optimiser to be used with gradient calculated wiht automatic differentiation
-
-    elseif gradientmode == :zygote
-
-        gradlogp = x -> Zygote.gradient(logp, x)[1]
-
-        optimiser = LBFGS() # optimiser to be used with gradient calculated wiht automatic differentiation
-
-    elseif gradientmode == :provided
+    @printf("Running VI with full covariance: seed=%d, S=%d, Stest=%d, D=%d for %d iterations\n", seed, S, Stest, length(μ), iterations)

-        if any(isnan.(gradlogp(μ)))
-
-            error("provided gradient returns NaN when evaluate at provided μ")
+    coreVIfull(logp, μ, Σ; gradlogp = gradlogp, seed = seed, S = S, optimiser=optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)

-        end
+end

-        optimiser = LBFGS() # optimiser to be used with user provided gradient

-    elseif gradientmode == :gradientfree
-
-        optimiser = NelderMead() # optimiser when no gradient provided
+function VI(logp::Function, μ::Vector, σ² = 0.1; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int=0, show_every::Int = -1, test_every::Int = -1)

-    else
+    @argcheck σ² > 0 "σ² must be > 0"

-        error("invalid specification of argument gradientmode")
+    Σ = Matrix(σ²*I, length(μ), length(μ)) # initial covariance

-    end
+    VI(logp, μ, Σ; gradlogp = gradlogp, gradientmode = gradientmode, seed = seed, S = S, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)

+end


-    # Call actual algorithm

-    @printf("Running VI with full covariance: seed=%d, S=%d, Stest=%d, D=%d for %d iterations\n", seed, S, Stest, length(μ), iterations)
+function VI(logp::Function, initgaussian::AbstractMvNormal; gradlogp = defaultgradient(mean(initgaussian)), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int = 1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)

-    coreVIfull(logp, μ, Σ; gradlogp = gradlogp, seed = seed, S = S, optimiser=optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)
+    VI(logp, mean(initgaussian), cov(initgaussian); gradlogp = gradlogp, gradientmode = gradientmode, seed = seed, S = S, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)

 end


-function VI(logp::Function, μ::AbstractVector, σ² = 0.1; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int=0, show_every::Int = -1, test_every::Int = -1)
-
-    @argcheck σ² > 0
-
-    # initial covariance
-    D = length(μ)
-
-    Σ = Matrix(σ²*I, length(μ), length(μ))
-    VI(logp, μ, Σ; gradlogp = gradlogp, gradientmode = gradientmode, seed = seed, S = S, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)
+#-----------------------------------#
+# Call mean field                   #
+#-----------------------------------#

-end
+function VIdiag(logp::Function, μ::Vector, Σdiag::Vector = 0.1*ones(length(μ)); gradlogp = defaultgradient(μ), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)

+    # check validity of arguments

-function VI(logp::Function, initgaussian::AbstractMvNormal; gradlogp = defaultgradient(mean(initgaussian)), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int = 1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)
+    checkcommonarguments(seed, iterations, S, Stest, μ)

-    VI(logp, mean(initgaussian), cov(initgaussian); gradlogp = gradlogp, gradientmode = gradientmode, seed = seed, S = S, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)
+    @argcheck length(Σdiag) == length(μ) "Σdiag must be a vector of the same length as the mean μ"
+
+    @argcheck isposdef(Diagonal(Σdiag)) "Σdiag must be positive definite"
+
-end
+    # pick optimiser and (re)define gradient of logp

+    optimiser, gradlogp = pickoptimiser(μ, logp, gradlogp, gradientmode)

+    # Call actual algorithm

-# #-----------------------------------#
-# # Call mean field                   #
-# #-----------------------------------#
+    @printf("Running VI with diagonal covariance (mean field): seed=%d, S=%d, Stest=%d, D=%d for %d iterations\n", seed, S, Stest, length(μ), iterations)

-# function VIdiag(logp::Function, μ::Array{Float64,1}, Σdiag = 0.1*ones(length(μ)); gradlogp = x -> ForwardDiff.gradient(logp, x), optimiser=Optim.LBFGS(), seed = 1, S = 100, iterations=1, numerical_verification = false, Stest=0, show_every=-1, inititerations=0)
+    coreVIdiag(logp, μ, Σdiag; gradlogp = gradlogp, seed = seed, S = S, test_every = test_every, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every)

-#     coreVIdiag(logp, [μ], [Σdiag]; gradlogp = gradlogp, seed = seed, S = S, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, inititerations=inititerations)
+end

-# end

+function VIdiag(logp::Function, initgaussian::MvNormal; gradlogp = defaultgradient(mean(initgaussian)), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)

-# function VIdiag(logp::Function, initgaussian::MvNormal; gradlogp = x -> ForwardDiff.gradient(logp, x), optimiser=Optim.LBFGS(), seed = 1, S = 100, iterations=1, numerical_verification = false, Stest=0, show_every=-1, inititerations=0)
+    VIdiag(logp, mean(initgaussian), diag(cov(initgaussian)); gradlogp = gradlogp, gradientmode = gradientmode, seed = seed, S = S, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)

-#     VIdiag(logp, mean(initgaussian), diag(cov(initgaussian)); gradlogp = gradlogp, seed = seed, S = S, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, inititerations=inititerations)
+end

-# end

-# function VIdiag(logp::Function, μ::Array{Array{Float64,1},1}, Σdiag = [0.1*ones(length(μ[1])) for _ in 1:length(μ)]; gradlogp = x -> ForwardDiff.gradient(logp, x), optimiser=Optim.LBFGS(), seed = 1, S = 100, iterations=1, numerical_verification = false, Stest=0, show_every=-1, inititerations=0)
+function VIdiag(logp::Function, μ::Vector, σ²::Float64 = 0.1; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, seed::Int = 1, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)

+    @argcheck σ² > 0 "σ² must be > 0"

-#     coreVIdiag(logp, μ, Σdiag; gradlogp = gradlogp, seed = seed, S = S, optimiser = optimiser, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, inititerations=inititerations)
+    Σdiag = σ²*ones(length(μ)) # initial diagonal covariance as vector

+    VIdiag(logp, μ, Σdiag; gradlogp = gradlogp, gradientmode = gradientmode, seed = seed, S = S, iterations = iterations, numerical_verification = numerical_verification, Stest = Stest, show_every = show_every, test_every = test_every)

-# end
+end


 # #-----------------------------------#
@@ -212,61 +191,21 @@

-function VIrank1(logp::Function, μ::AbstractVector, C::AbstractMatrix; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, transform = identity, seed::Int = 1, seedtest::Int = 2, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)
-
+function VIrank1(logp::Function, μ::Vector, C::Matrix; gradlogp = defaultgradient(μ), gradientmode = :gradientfree, transform = identity, seed::Int = 1, seedtest::Int = 2, S::Int = 100, iterations::Int=1, numerical_verification::Bool = false, Stest::Int = 0, show_every::Int = -1, test_every::Int = -1)

     # check validity of arguments

-    # @argcheck seed > 0
-
-    @argcheck iterations > 0
-
-    @argcheck S > 0
-
+    checkcommonarguments(seed, iterations, S, Stest, μ)
+
     @argcheck size(C, 1) == size(C, 2) "C must be a square matrix"

     @argcheck length(μ) == size(C, 1) == size(C, 2) "dimensions of μ do not agree with dimensions of C"

-    @argcheck length(μ) >= 2 "VIrank1 works only for problems with two parameters and more"
-
-
-    # check gradient arguments
-
-    optimiser = NelderMead() # default optimiser
-
-
-    if gradientmode == :forward
-
-        gradlogp = x -> ForwardDiff.gradient(logp, x)
-
-        optimiser = LBFGS() # optimiser to be used with gradient calculated wiht automatic differentiation
-
-    elseif gradientmode == :zygote
-
-        gradlogp = x -> Zygote.gradient(logp, x)[1]
-
-        optimiser = LBFGS() # optimiser to be used with gradient calculated wiht automatic differentiation
-
-    elseif gradientmode == :provided
-
-        if any(isnan.(gradlogp(μ)))
-
-            error("provided gradient returns NaN when evaluate at provided μ")
-
-        end
-
-        optimiser = LBFGS() # optimiser to be used with user provided gradient
-
-    elseif gradientmode == :gradientfree
-
-        optimiser = NelderMead() # optimiser when no gradient provided

-    else
+    # pick optimiser and (re)define gradient of logp

-        error("invalid specification of argument gradientmode")
-
-    end
+    optimiser, gradlogp = pickoptimiser(μ, logp, gradlogp, gradientmode)


     # Call actual algorithm

@@ -278,4 +217,18 @@

 end

+function checkcommonarguments(seed, iterations, S, Stest, μ)
+
+    # check validity of arguments
+
+    @argcheck seed >= 0 "seed must be ≥ 0"
+
+    @argcheck iterations > 0 "iterations must be > 0"
+
+    @argcheck S > 0 "S must be > 0"
+
+    @argcheck Stest >= 0 "Stest must be ≥ 0"
+
+    @argcheck length(μ) >= 2 "VI works only for problems with two or more parameters"
+
+end
\ No newline at end of file
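
Usage sketch (editor's addition, not part of the patch): the snippet below illustrates how the mean-field interface introduced by this commit is expected to be called, reusing the toy Bayesian linear regression problem from the VI docstring above. It assumes the patched package is installed and loaded as GaussianVariationalInference; the values of S and iterations are arbitrary illustrative choices, not recommended settings.

```
julia> using GaussianVariationalInference, LinearAlgebra, Distributions
julia> D = 4; X = randn(D, 1000); W = randn(D); β = 0.3; α = 1.0;
julia> Y = vec(W'*X) + randn(1000)/sqrt(β);                  # synthetic regression targets
julia> logp = w -> logpdf(MvNormal(vec(w'*X), sqrt(1/β)), Y) + logpdf(MvNormal(zeros(D), sqrt(1/α)), w);
julia> # mean-field (diagonal covariance) approximation: returns posterior, approximate log-evidence and covariance root
julia> qdiag, logevdiag, Cdiag = VIdiag(logp, randn(D); S = 200, iterations = 20);
julia> diag(cov(qdiag))                                       # posterior variances, i.e. Cdiag.^2
julia> # rank-1 variant: the third positional argument is the fixed initial covariance root C
julia> qr1, logevr1, Croot = VIrank1(logp, randn(D), Matrix(0.1*I, D, D); S = 200, iterations = 20);
```

Because coreVIdiag parameterises the covariance as Diagonal(Cdiag.^2), the third value returned by VIdiag is the vector whose squares are the posterior variances, the mean-field analogue of the Croot output documented for VI.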