From 6bb3c35d20e94f1ae6e5570707ada1cbc3c7dbfb Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Thu, 13 Jan 2022 10:43:52 +1100 Subject: [PATCH 1/9] symmetric quadratic OT newton --- src/OptimalTransport.jl | 1 + src/quadratic_newton_symm.jl | 223 ++++++++++++++++++++++++++++++++++ src/quadratic_newton_symm.jl~ | 0 test.jl | 8 ++ 4 files changed, 232 insertions(+) create mode 100644 src/quadratic_newton_symm.jl create mode 100644 src/quadratic_newton_symm.jl~ create mode 100644 test.jl diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 1653431e..b471bfb4 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -39,6 +39,7 @@ include("entropic/sinkhorn_solve.jl") include("quadratic.jl") include("quadratic_newton.jl") +include("quadratic_newton_symm.jl") include("dual/entropic_dual.jl") diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl new file mode 100644 index 00000000..30a362ff --- /dev/null +++ b/src/quadratic_newton_symm.jl @@ -0,0 +1,223 @@ +struct SymmetricQuadraticOTNewton{T<:Real,K<:Real,D<:Real} <: QuadraticOT + θ::T + κ::K + δ::D + armijo_max::Int +end + +function SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50) + return SymmetricQuadraticOTNewton(θ, κ, δ, armijo_max) +end + +struct SymmetricQuadraticOTNewtonCache{U,C,P,GT,X} + u::U + δu::U + σ::C + γ::P + G::GT + x::X + M::Int +end + +function build_cache( + ::Type{T}, + ::SymmetricQuadraticOTNewton, + μ::AbstractVector, + C::AbstractMatrix, + ε::Real, +) where {T} + # create and initialize dual potentials + u = similar(μ, T, size(μ, 1)) + fill!(u, zero(T)) + δu = similar(u, T) + # intermediate variables (don't need to be initialised) + σ = similar(C, T) + γ = similar(C, T) + M = size(μ, 1) + N = M + G = similar(u, T, M + N, M + N) + fill!(G, zero(T)) + # initial guess for conjugate gradient + x = similar(u, T, M + N) + fill!(x, zero(T)) + return SymmetricQuadraticOTNewtonCache(u, δu, σ, γ, G, x, M) +end + +function check_convergence( + μ::AbstractVector, + cache::SymmetricQuadraticOTNewtonCache, + convergence_cache::QuadraticOTConvergenceCache, + atol::Real, + rtol::Real, +) + γ = cache.γ + norm_diff = norm(vec(sum(γ; dims=2)) .- μ, Inf) + isconverged = + norm_diff < + max(atol, rtol * max(convergence_cache.norm_source, convergence_cache.norm_target)) + return isconverged, norm_diff +end + +function descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) + # unpack solver + eps = solver.eps + C = solver.C + μ = solver.source + cache = solver.cache + # unpack cache + u = cache.u + δu = cache.δu + σ = cache.σ + γ = cache.γ + G = cache.G + x = cache.x + M = cache.M + N = M + # Armijo parameters + δ = solver.alg.δ + + # setup intermediate variables + @. γ = u + u' - C + @. σ = γ ≥ 0 + @. γ = NNlib.relu(γ) / eps + + # setup kernel matrix G + G = Diagonal(vec(sum(σ; dims=2))) + σ + δ*I + + # cg step + b = -eps * (vec(sum(γ; dims=2)) .- μ) + cg!(x, G, b) + δu .= x +end + +function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) + # unpack solver + eps = solver.eps + C = solver.C + μ = solver.source + cache = solver.cache + # unpack cache + u = cache.u + δu = cache.δu + γ = cache.γ + + # Armijo parameters + θ = solver.alg.θ + κ = solver.alg.κ + armijo_max = solver.alg.armijo_max + armijo_counter = 0 + + # dual objective + function Φ(u, μ, C, ε) + return norm(NNlib.relu.(u .+ u' .- C))^2 / 2 - 2*ε * dot(μ, u) + end + + # compute directional derivative + d = -eps * (2*dot(δu, μ)) + eps * dot(γ, δu .+ δu') + t = 1 + Φ0 = Φ(u, μ, C, eps) + while (armijo_counter < armijo_max) && + (Φ(u + t * δu, μ, C, eps) ≥ Φ0 + t * θ * d) + t = κ * t + armijo_counter += 1 + end + return u .= u + t * δu +end + +function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) + # unpack solver + μ = solver.source + atol = solver.atol + rtol = solver.rtol + maxiter = solver.maxiter + check_convergence = solver.check_convergence + cache = solver.cache + convergence_cache = solver.convergence_cache + + isconverged = false + to_check_step = check_convergence + for iter in 1:maxiter + # compute descent direction + descent_dir!(solver) + # Newton step + descent_step!(solver) + # check source marginal + # always check convergence after the final iteration + to_check_step -= 1 + if to_check_step == 0 || iter == maxiter + # reset counter + to_check_step = check_convergence + + isconverged, abserror = OptimalTransport.check_convergence( + μ, μ, cache, convergence_cache, atol, rtol + ) + @debug string(solver.alg) * + " (" * + string(iter) * + "/" * + string(maxiter) * + ": absolute error of source marginal = " * + string(maximum(abserror)) + + if isconverged + @debug "$(solver.alg) ($iter/$maxiter): converged" + break + end + end + end + + if !isconverged + @warn "$(solver.alg) ($maxiter/$maxiter): not converged" + end + + return nothing +end + +function build_solver( + μ::AbstractVector, + C::AbstractMatrix, + ε::Real, + alg::QuadraticOT; + atol=nothing, + rtol=nothing, + check_convergence=1, + maxiter::Int=100, +) + # check that source and target marginals have the correct size + checksize(μ, μ, C) + # do not use checksize2 since for quadratic OT (at least for now) we do not support batch computations + + # compute type + T = float(Base.promote_eltype(μ, one(eltype(C)) / ε)) + + # build caches + cache = build_cache(T, alg, μ, C, ε) + convergence_cache = build_convergence_cache(T, μ, μ) + + # set tolerances + _atol = atol === nothing ? 0 : atol + _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol + + # create solver + solver = QuadraticOTSolver( + μ, μ, C, ε, alg, _atol, _rtol, maxiter, check_convergence, cache, convergence_cache + ) + return solver +end + + +# interface + +function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewton; kwargs...) + solver = build_solver(μ, C, ε, alg; kwargs...) + solve!(solver) + γ = plan(solver) + return γ +end + +function plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) + cache = solver.cache + γ = NNlib.relu.(cache.u .+ cache.u' .- solver.C) / solver.eps + return γ +end + diff --git a/src/quadratic_newton_symm.jl~ b/src/quadratic_newton_symm.jl~ new file mode 100644 index 00000000..e69de29b diff --git a/test.jl b/test.jl new file mode 100644 index 00000000..341f49a1 --- /dev/null +++ b/test.jl @@ -0,0 +1,8 @@ +using OptimalTransport + +μ = rand(10, 5) +ν = rand(10, 5) +C = rand(10, 10) + +sinkhorn_unbalanced2(μ, ν, C, 1.0, 1.0, 1.0) + From 7dc5a51512a12f16214722f7c2036d8f93a697c9 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Thu, 13 Jan 2022 10:44:40 +1100 Subject: [PATCH 2/9] update version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 5c16c2db..4fc0ec72 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "OptimalTransport" uuid = "7e02d93a-ae51-4f58-b602-d97af76e3b33" authors = ["zsteve "] -version = "0.3.18" +version = "0.3.19" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" From d9ce05a81b63de0e9fdf22e54260a1eec943a3ab Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Thu, 13 Jan 2022 10:54:02 +1100 Subject: [PATCH 3/9] fix bug --- src/quadratic_newton_symm.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl index 30a362ff..775f6d30 100644 --- a/src/quadratic_newton_symm.jl +++ b/src/quadratic_newton_symm.jl @@ -35,10 +35,10 @@ function build_cache( γ = similar(C, T) M = size(μ, 1) N = M - G = similar(u, T, M + N, M + N) + G = similar(u, T, M, M) fill!(G, zero(T)) # initial guess for conjugate gradient - x = similar(u, T, M + N) + x = similar(u, T, M) fill!(x, zero(T)) return SymmetricQuadraticOTNewtonCache(u, δu, σ, γ, G, x, M) end From 06f0680a49988ab5d6b4f8a8d648c4f424891089 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Thu, 13 Jan 2022 10:54:02 +1100 Subject: [PATCH 4/9] fix bug --- src/quadratic_newton_symm.jl | 68 +++++++++++++++++------------------- 1 file changed, 32 insertions(+), 36 deletions(-) diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl index 30a362ff..1e9284fc 100644 --- a/src/quadratic_newton_symm.jl +++ b/src/quadratic_newton_symm.jl @@ -9,43 +9,9 @@ function SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50) return SymmetricQuadraticOTNewton(θ, κ, δ, armijo_max) end -struct SymmetricQuadraticOTNewtonCache{U,C,P,GT,X} - u::U - δu::U - σ::C - γ::P - G::GT - x::X - M::Int -end - -function build_cache( - ::Type{T}, - ::SymmetricQuadraticOTNewton, - μ::AbstractVector, - C::AbstractMatrix, - ε::Real, -) where {T} - # create and initialize dual potentials - u = similar(μ, T, size(μ, 1)) - fill!(u, zero(T)) - δu = similar(u, T) - # intermediate variables (don't need to be initialised) - σ = similar(C, T) - γ = similar(C, T) - M = size(μ, 1) - N = M - G = similar(u, T, M + N, M + N) - fill!(G, zero(T)) - # initial guess for conjugate gradient - x = similar(u, T, M + N) - fill!(x, zero(T)) - return SymmetricQuadraticOTNewtonCache(u, δu, σ, γ, G, x, M) -end - function check_convergence( μ::AbstractVector, - cache::SymmetricQuadraticOTNewtonCache, + cache::QuadraticOTNewtonCache, convergence_cache::QuadraticOTConvergenceCache, atol::Real, rtol::Real, @@ -173,6 +139,36 @@ function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) return nothing end +function build_cache( + ::Type{T}, + ::SymmetricQuadraticOTNewton, + μ::AbstractVector, + ν::AbstractVector, + C::AbstractMatrix, + ε::Real, +) where {T} + # create and initialize dual potentials + u = similar(μ, T, size(μ, 1)) + v = similar(ν, T, size(ν, 1)) + fill!(u, zero(T)) + fill!(v, zero(T)) + δu = similar(u, T) + δv = similar(v, T) + # intermediate variables (don't need to be initialised) + σ = similar(C, T) + γ = similar(C, T) + M = size(μ, 1) + N = size(ν, 1) + G = similar(u, T, M + N, M + N) + fill!(G, zero(T)) + # initial guess for conjugate gradient + x = similar(u, T, M + N) + fill!(x, zero(T)) + return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N) +end + + + function build_solver( μ::AbstractVector, C::AbstractMatrix, @@ -191,7 +187,7 @@ function build_solver( T = float(Base.promote_eltype(μ, one(eltype(C)) / ε)) # build caches - cache = build_cache(T, alg, μ, C, ε) + cache = build_cache(T, alg, μ, μ, C, ε) convergence_cache = build_convergence_cache(T, μ, μ) # set tolerances From 0e1cf68c75b3dd56c7660d4529ec849ee69a9ca5 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Thu, 27 Jan 2022 10:20:46 +1100 Subject: [PATCH 5/9] update --- src/quadratic_newton_symm.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl index 1e9284fc..4fceef92 100644 --- a/src/quadratic_newton_symm.jl +++ b/src/quadratic_newton_symm.jl @@ -159,10 +159,10 @@ function build_cache( γ = similar(C, T) M = size(μ, 1) N = size(ν, 1) - G = similar(u, T, M + N, M + N) + G = similar(u, T, M, M) fill!(G, zero(T)) # initial guess for conjugate gradient - x = similar(u, T, M + N) + x = similar(u, T, M) fill!(x, zero(T)) return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N) end From 8f7f899ba6b2bba8d9dd38e52286017d834ceaac Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 25 Jun 2023 13:14:20 +1000 Subject: [PATCH 6/9] test implementation of active set method for QROT in symmetric setting --- Project.toml | 4 + src/#quadratic_newton_symm.jl# | 219 --------------------------------- src/OptimalTransport.jl | 1 + src/quadratic_newton_symm.jl | 5 + test.jl | 105 +++++++++++++++- 5 files changed, 111 insertions(+), 223 deletions(-) delete mode 100644 src/#quadratic_newton_symm.jl# diff --git a/Project.toml b/Project.toml index 3c57e214..bb55895e 100644 --- a/Project.toml +++ b/Project.toml @@ -6,10 +6,14 @@ version = "0.3.20" [deps] ExactOptimalTransport = "24df6009-d856-477c-ac5c-91f668376b31" IterativeSolvers = "42fd0dbc-a981-5370-80f2-aaf504508153" +LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" +RandomMatrix = "0af1cf96-9b30-454e-9d9e-87908f700846" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" +SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf" [compat] ExactOptimalTransport = "0.1, 0.2" diff --git a/src/#quadratic_newton_symm.jl# b/src/#quadratic_newton_symm.jl# deleted file mode 100644 index 1e9284fc..00000000 --- a/src/#quadratic_newton_symm.jl# +++ /dev/null @@ -1,219 +0,0 @@ -struct SymmetricQuadraticOTNewton{T<:Real,K<:Real,D<:Real} <: QuadraticOT - θ::T - κ::K - δ::D - armijo_max::Int -end - -function SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50) - return SymmetricQuadraticOTNewton(θ, κ, δ, armijo_max) -end - -function check_convergence( - μ::AbstractVector, - cache::QuadraticOTNewtonCache, - convergence_cache::QuadraticOTConvergenceCache, - atol::Real, - rtol::Real, -) - γ = cache.γ - norm_diff = norm(vec(sum(γ; dims=2)) .- μ, Inf) - isconverged = - norm_diff < - max(atol, rtol * max(convergence_cache.norm_source, convergence_cache.norm_target)) - return isconverged, norm_diff -end - -function descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) - # unpack solver - eps = solver.eps - C = solver.C - μ = solver.source - cache = solver.cache - # unpack cache - u = cache.u - δu = cache.δu - σ = cache.σ - γ = cache.γ - G = cache.G - x = cache.x - M = cache.M - N = M - # Armijo parameters - δ = solver.alg.δ - - # setup intermediate variables - @. γ = u + u' - C - @. σ = γ ≥ 0 - @. γ = NNlib.relu(γ) / eps - - # setup kernel matrix G - G = Diagonal(vec(sum(σ; dims=2))) + σ + δ*I - - # cg step - b = -eps * (vec(sum(γ; dims=2)) .- μ) - cg!(x, G, b) - δu .= x -end - -function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) - # unpack solver - eps = solver.eps - C = solver.C - μ = solver.source - cache = solver.cache - # unpack cache - u = cache.u - δu = cache.δu - γ = cache.γ - - # Armijo parameters - θ = solver.alg.θ - κ = solver.alg.κ - armijo_max = solver.alg.armijo_max - armijo_counter = 0 - - # dual objective - function Φ(u, μ, C, ε) - return norm(NNlib.relu.(u .+ u' .- C))^2 / 2 - 2*ε * dot(μ, u) - end - - # compute directional derivative - d = -eps * (2*dot(δu, μ)) + eps * dot(γ, δu .+ δu') - t = 1 - Φ0 = Φ(u, μ, C, eps) - while (armijo_counter < armijo_max) && - (Φ(u + t * δu, μ, C, eps) ≥ Φ0 + t * θ * d) - t = κ * t - armijo_counter += 1 - end - return u .= u + t * δu -end - -function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) - # unpack solver - μ = solver.source - atol = solver.atol - rtol = solver.rtol - maxiter = solver.maxiter - check_convergence = solver.check_convergence - cache = solver.cache - convergence_cache = solver.convergence_cache - - isconverged = false - to_check_step = check_convergence - for iter in 1:maxiter - # compute descent direction - descent_dir!(solver) - # Newton step - descent_step!(solver) - # check source marginal - # always check convergence after the final iteration - to_check_step -= 1 - if to_check_step == 0 || iter == maxiter - # reset counter - to_check_step = check_convergence - - isconverged, abserror = OptimalTransport.check_convergence( - μ, μ, cache, convergence_cache, atol, rtol - ) - @debug string(solver.alg) * - " (" * - string(iter) * - "/" * - string(maxiter) * - ": absolute error of source marginal = " * - string(maximum(abserror)) - - if isconverged - @debug "$(solver.alg) ($iter/$maxiter): converged" - break - end - end - end - - if !isconverged - @warn "$(solver.alg) ($maxiter/$maxiter): not converged" - end - - return nothing -end - -function build_cache( - ::Type{T}, - ::SymmetricQuadraticOTNewton, - μ::AbstractVector, - ν::AbstractVector, - C::AbstractMatrix, - ε::Real, -) where {T} - # create and initialize dual potentials - u = similar(μ, T, size(μ, 1)) - v = similar(ν, T, size(ν, 1)) - fill!(u, zero(T)) - fill!(v, zero(T)) - δu = similar(u, T) - δv = similar(v, T) - # intermediate variables (don't need to be initialised) - σ = similar(C, T) - γ = similar(C, T) - M = size(μ, 1) - N = size(ν, 1) - G = similar(u, T, M + N, M + N) - fill!(G, zero(T)) - # initial guess for conjugate gradient - x = similar(u, T, M + N) - fill!(x, zero(T)) - return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N) -end - - - -function build_solver( - μ::AbstractVector, - C::AbstractMatrix, - ε::Real, - alg::QuadraticOT; - atol=nothing, - rtol=nothing, - check_convergence=1, - maxiter::Int=100, -) - # check that source and target marginals have the correct size - checksize(μ, μ, C) - # do not use checksize2 since for quadratic OT (at least for now) we do not support batch computations - - # compute type - T = float(Base.promote_eltype(μ, one(eltype(C)) / ε)) - - # build caches - cache = build_cache(T, alg, μ, μ, C, ε) - convergence_cache = build_convergence_cache(T, μ, μ) - - # set tolerances - _atol = atol === nothing ? 0 : atol - _rtol = rtol === nothing ? (_atol > zero(_atol) ? zero(T) : sqrt(eps(T))) : rtol - - # create solver - solver = QuadraticOTSolver( - μ, μ, C, ε, alg, _atol, _rtol, maxiter, check_convergence, cache, convergence_cache - ) - return solver -end - - -# interface - -function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewton; kwargs...) - solver = build_solver(μ, C, ε, alg; kwargs...) - solve!(solver) - γ = plan(solver) - return γ -end - -function plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) - cache = solver.cache - γ = NNlib.relu.(cache.u .+ cache.u' .- solver.C) / solver.eps - return γ -end - diff --git a/src/OptimalTransport.jl b/src/OptimalTransport.jl index 0af9733a..a7385391 100644 --- a/src/OptimalTransport.jl +++ b/src/OptimalTransport.jl @@ -13,6 +13,7 @@ using LinearAlgebra using IterativeSolvers using LogExpFunctions: LogExpFunctions using NNlib: NNlib +using SparseArrays export SinkhornGibbs, SinkhornStabilized, SinkhornEpsilonScaling export SinkhornBarycenterGibbs diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl index 4fceef92..76b86dac 100644 --- a/src/quadratic_newton_symm.jl +++ b/src/quadratic_newton_symm.jl @@ -5,6 +5,11 @@ struct SymmetricQuadraticOTNewton{T<:Real,K<:Real,D<:Real} <: QuadraticOT armijo_max::Int end +struct SymmetricQuadraticOTNewtonActiveSet{A<:SymmetricQuadraticOTNewton, ST<:AbstractSparseArray} + alg::A + S::ST +end + function SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50) return SymmetricQuadraticOTNewton(θ, κ, δ, armijo_max) end diff --git a/test.jl b/test.jl index 341f49a1..0b34cb1d 100644 --- a/test.jl +++ b/test.jl @@ -1,8 +1,105 @@ +using Revise +using Pkg +Pkg.activate(".") using OptimalTransport +using LinearAlgebra +using SparseArrays +using Distances +using StatsBase -μ = rand(10, 5) -ν = rand(10, 5) -C = rand(10, 10) +N = 500 +d = 500 +μ_spt = randn(N, d) +X = μ_spt +μ = ones(N) +C = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X' +C[diagind(C)] .= Inf -sinkhorn_unbalanced2(μ, ν, C, 1.0, 1.0, 1.0) +ε = 5.0 +π = sparse(quadreg(μ, C, ε, OptimalTransport.SymmetricQuadraticOTNewton(); maxiter = 25)) +using RandomMatrix + +# KNN initialisation +using NearestNeighbors +function knn_adj(X, k) + # indices, _ = knn_matrices(nndescent(X, k, Euclidean())); + indices, _ = knn(KDTree(X), X, k); + A = spzeros(size(X, 2), size(X, 2)); + @inbounds for i = 1:size(A, 1) + A[i, i] = 1 + @inbounds for j in indices[i] + A[i, j] = 1 + end + end + return A +end + +S = knn_adj(X', 25) +S += sign.(mean([randPermutation(N) for _ in 1:25])) +S = sign.(S) +S[diagind(S)] .= 0 +dropzeros!(S) + +using NNlib +using LazyArrays +using IterativeSolvers + +D = sum(X.^2 ; dims = 2)/2 + +## start +u = zeros(N) + +function Φ(u, μ, C, ε, S) + return norm(NNlib.relu.((u .* S) .+ (u' .* S) .- C))^2 / 2 - 2*ε * dot(μ, u) +end + +armijo_max = 25 +for it = 1:10 + @info it + Csp = similar(S); + I, J, V = findnz(Csp); + for (i, j) in zip(I, J) Csp[i, j] = D[i] + D[j] - dot(X[i, :], X[j, :]) end + + γ = (u .* S) + (u' .* S) - Csp + σ = similar(S) + I, J, V = findnz(σ) + for (i, j) in zip(I, J) σ[i, j] = (γ[i, j] ≥ 0) end + dropzeros!(σ) + γ = relu.(γ) / ε + δ = 1e-5 + + G = σ + Diagonal(vec(sum(σ; dims = 2)) .+ δ) + b = -ε*(vec(sum(γ; dims = 2)) .- 1) + + δu = similar(u) + fill!(δu, 0) + cg!(δu, G, b) + + + d = ε*sum(γ .* ((δu .* S) + (δu' .* S))) - 2ε*dot(δu, μ) + t = 1 + Φ0 = Φ(u, μ, Csp, ε, S) + armijo_counter = 0 + θ = alg.alg.θ + κ = alg.alg.κ + while (armijo_counter < armijo_max) && + (Φ(u + t * δu, μ, Csp, ε, S) ≥ Φ0 + t * θ * d) + t = κ * t + armijo_counter += 1 + end + u .= u + t * δu +end +γ = relu.(u .+ u' - C) / ε +S = max.(S, sign.(γ)) +norm(γ - π, 1) / norm(γ, 1) + +sum(γ; dims = 1) +sum(γ; dims = 2) + +sparse(γ) + + + +using Plots +plot(heatmap(Array(γ)), heatmap(Array(π))) From e757f7916ec10351f0aaf65c7d6b5ae46169fe20 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Sun, 25 Jun 2023 16:37:26 +1000 Subject: [PATCH 7/9] draft implementation of active set method for QROT --- Project.toml | 1 + src/quadratic_newton_symm.jl | 223 ++++++++++++++++++++++++++++++++++- test.jl | 113 ++++++++++++------ 3 files changed, 296 insertions(+), 41 deletions(-) diff --git a/Project.toml b/Project.toml index bb55895e..79d9e5fe 100644 --- a/Project.toml +++ b/Project.toml @@ -10,6 +10,7 @@ LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" LogExpFunctions = "2ab3a3ac-af41-5b50-aa03-7779005ae688" NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd" +NearestNeighborDescent = "dd2c4c9e-a32f-5b2f-b342-08c2f244fce8" NearestNeighbors = "b8a86587-4115-5ab1-83bc-aa920d37bbce" RandomMatrix = "0af1cf96-9b30-454e-9d9e-87908f700846" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl index 76b86dac..9f04e201 100644 --- a/src/quadratic_newton_symm.jl +++ b/src/quadratic_newton_symm.jl @@ -5,15 +5,27 @@ struct SymmetricQuadraticOTNewton{T<:Real,K<:Real,D<:Real} <: QuadraticOT armijo_max::Int end -struct SymmetricQuadraticOTNewtonActiveSet{A<:SymmetricQuadraticOTNewton, ST<:AbstractSparseArray} - alg::A +struct SymmetricQuadraticOTNewtonAS{T<:Real,K<:Real,D<:Real,ST<:AbstractSparseArray, IT, JT} <: QuadraticOT + θ::T + κ::K + δ::D + armijo_max::Int S::ST + I::IT + J::JT end +Base.show(io::IO, ::SymmetricQuadraticOTNewtonAS) = print(io, "Symmetric semi-smooth Newton algorithm (active set)") + function SymmetricQuadraticOTNewton(; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50) return SymmetricQuadraticOTNewton(θ, κ, δ, armijo_max) end +function SymmetricQuadraticOTNewtonAS(S; θ=0.1, κ=0.5, δ=1e-5, armijo_max=50) + I, J, V = findnz(S) + return SymmetricQuadraticOTNewtonAS(θ, κ, δ, armijo_max, S, I, J) +end + function check_convergence( μ::AbstractVector, cache::QuadraticOTNewtonCache, @@ -61,6 +73,45 @@ function descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) δu .= x end +function descent_dir!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}) + # unpack solver + eps = solver.eps + C = solver.C + μ = solver.source + cache = solver.cache + # unpack cache + u = cache.u + δu = cache.δu + σ = cache.σ + γ = cache.γ + G = cache.G + x = cache.x + M = cache.M + N = M + # Armijo parameters + δ = solver.alg.δ + S = solver.alg.S + I = solver.alg.I + J = solver.alg.J + + # setup intermediate variables + @. γ.nzval = u[I] + u[J] - C.nzval + + @. σ.nzval = γ.nzval ≥ 0 + @. γ.nzval = NNlib.relu(γ.nzval) / eps + + # setup kernel matrix G + # G = Diagonal(vec(sum(σ; dims=2))) + σ + δ*I + fill!(G.nzval, 0) + G += σ + G += Diagonal(vec(sum(σ; dims=2)) .+ δ) + + # cg step + b = -eps * (vec(sum(γ; dims=2)) .- μ) + cg!(x, G, b) + δu .= x +end + function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) # unpack solver eps = solver.eps @@ -95,6 +146,44 @@ function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) return u .= u + t * δu end +function descent_step!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}) + # unpack solver + eps = solver.eps + C = solver.C + μ = solver.source + cache = solver.cache + # unpack cache + u = cache.u + δu = cache.δu + γ = cache.γ + S = solver.alg.S + I = solver.alg.I + J = solver.alg.J + + # Armijo parameters + θ = solver.alg.θ + κ = solver.alg.κ + armijo_max = solver.alg.armijo_max + armijo_counter = 0 + + # dual objective + function Φ(u, μ, C, ε, I, J) + return norm(NNlib.relu.(u[I] + u[J] - C.nzval))^2 / 2 - 2*ε * dot(μ, u) + end + + # compute directional derivative + d = -eps * (2*dot(δu, μ)) + eps * dot(γ.nzval, δu[I] + δu[J]) + t = 1 + Φ0 = Φ(u, μ, C, eps, I, J) + while (armijo_counter < armijo_max) && + (Φ(u + t * δu, μ, C, eps, I, J) ≥ Φ0 + t * θ * d) + t = κ * t + armijo_counter += 1 + end + return u .= u + t * δu +end + + function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) # unpack solver μ = solver.source @@ -144,6 +233,68 @@ function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) return nothing end + +function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}) + # unpack solver + μ = solver.source + atol = solver.atol + rtol = solver.rtol + maxiter = solver.maxiter + check_convergence = solver.check_convergence + cache = solver.cache + convergence_cache = solver.convergence_cache + u_prev = similar(solver.cache.u) + copy!(u_prev, solver.cache.u) + + function check_convergence_dual(u, u_prev, atol, rtol) + norm_diff = norm(u - u_prev) + isconverged = + norm_diff < + max(atol, rtol * max(norm(u), norm(u_prev))) + return isconverged, norm_diff + end + + isconverged = false + to_check_step = check_convergence + for iter in 1:maxiter + # compute descent direction + descent_dir!(solver) + # Newton step + descent_step!(solver) + # check source marginal + # always check convergence after the final iteration + to_check_step -= 1 + if to_check_step == 0 || iter == maxiter + # reset counter + to_check_step = check_convergence + + # isconverged, abserror = OptimalTransport.check_convergence( + # μ, μ, cache, convergence_cache, atol, rtol + # ) + isconverged, abserror = check_convergence_dual(solver.cache.u, u_prev, atol, rtol) + @debug string(solver.alg) * + " (" * + string(iter) * + "/" * + string(maxiter) * + ": absolute error of source marginal = " * + string(maximum(abserror)) + + if isconverged + @debug "$(solver.alg) ($iter/$maxiter): converged" + break + end + copy!(u_prev, solver.cache.u) + end + end + + if !isconverged + @warn "$(solver.alg) ($maxiter/$maxiter): not converged" + end + + return nothing +end + function build_cache( ::Type{T}, ::SymmetricQuadraticOTNewton, @@ -172,7 +323,33 @@ function build_cache( return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N) end - +function build_cache( + ::Type{T}, + alg::SymmetricQuadraticOTNewtonAS, + μ::AbstractVector, + ν::AbstractVector, + C::AbstractMatrix, + ε::Real + ) where {T} + # create and initialize dual potentials + u = similar(μ, T, size(μ, 1)) + v = similar(ν, T, size(ν, 1)) + fill!(u, zero(T)) + fill!(v, zero(T)) + δu = similar(u, T) + δv = similar(v, T) + # intermediate variables (don't need to be initialised) + σ = similar(C, T) + γ = similar(C, T) + M = size(μ, 1) + N = size(ν, 1) + # G = similar(u, T, M, M) + G = similar(alg.S + sparse(SparseArrays.I, M, M), T) + # initial guess for conjugate gradient + x = similar(u, T, M) + fill!(x, zero(T)) + return QuadraticOTNewtonCache(u, v, δu, δv, σ, γ, G, x, M, N) +end function build_solver( μ::AbstractVector, @@ -216,9 +393,49 @@ function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewton; kwargs...) return γ end +function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewtonAS; maxiter_as = 5, kwargs...) + function check_convergence(γ, cache, convergence_cache, atol, rtol) + norm_diff = norm(vec(sum(γ; dims=2)) .- μ, Inf) + isconverged = + norm_diff < + max(atol, rtol * max(convergence_cache.norm_source, convergence_cache.norm_target)) + return isconverged, norm_diff + end + S = alg.S + I = alg.I + J = alg.J + γ = spzeros(size(C)...) + u = similar(μ) + fill!(u, 0) + for iter = 1:maxiter_as + _alg = SymmetricQuadraticOTNewtonAS(alg.θ, alg.κ, alg.δ, alg.armijo_max, S, I, J) + Csp_v = similar(C, length(_alg.I)); @inbounds for k = 1:length(_alg.I) Csp_v[k] = C[_alg.I[k], _alg.J[k]] end + solver = build_solver(μ, sparse(_alg.I, _alg.J, Csp_v), ε, _alg; kwargs...) + copy!(solver.cache.u, u) + solve!(solver) + copy!(u, solver.cache.u) + γ = plan(solver, C) + isconverged, norm_diff = check_convergence(γ, solver.cache, solver.convergence_cache, solver.atol, solver.rtol) + if isconverged + @debug "$(solver.alg) AS step ($iter/$maxiter_as): converged" + break + else + S = max.(_alg.S, sign.(γ)) + I, J, _ = findnz(S) + @debug "$(solver.alg) AS growing support to $(nnz(S)) (sparsity $(nnz(S)/length(S)))" + end + end + return γ +end + function plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewton}) cache = solver.cache γ = NNlib.relu.(cache.u .+ cache.u' .- solver.C) / solver.eps return γ end +function plan(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}, Cfull) + cache = solver.cache + γ = sparse(NNlib.relu.(cache.u .+ cache.u' .- Cfull) / solver.eps) + return γ +end diff --git a/test.jl b/test.jl index 0b34cb1d..af7f7fd6 100644 --- a/test.jl +++ b/test.jl @@ -6,22 +6,28 @@ using LinearAlgebra using SparseArrays using Distances using StatsBase +using NNlib +using LazyArrays +using IterativeSolvers -N = 500 -d = 500 +N = 10_000 +d = 250 μ_spt = randn(N, d) X = μ_spt μ = ones(N) C = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X' +Cmean = mean(C) C[diagind(C)] .= Inf +ENV["JULIA_DEBUG"] = "OptimalTransport" -ε = 5.0 -π = sparse(quadreg(μ, C, ε, OptimalTransport.SymmetricQuadraticOTNewton(); maxiter = 25)) +ε = 1.0 +π = sparse(quadreg(μ, C / Cmean, ε, OptimalTransport.SymmetricQuadraticOTNewton(); maxiter = 25)) -using RandomMatrix +nnz(π) / N^2 -# KNN initialisation +using NearestNeighborDescent using NearestNeighbors + function knn_adj(X, k) # indices, _ = knn_matrices(nndescent(X, k, Euclidean())); indices, _ = knn(KDTree(X), X, k); @@ -35,56 +41,87 @@ function knn_adj(X, k) return A end -S = knn_adj(X', 25) -S += sign.(mean([randPermutation(N) for _ in 1:25])) -S = sign.(S) -S[diagind(S)] .= 0 -dropzeros!(S) +S = knn_adj(X', 50); +S += sign.(mean([randPermutation(N) for _ in 1:50])); +S = sign.(S); +S[diagind(S)] .= 0; +dropzeros!(S); -using NNlib -using LazyArrays -using IterativeSolvers +nnz(S) / N^2 -D = sum(X.^2 ; dims = 2)/2 +ε = 1 +alg = OptimalTransport.SymmetricQuadraticOTNewtonAS(S) +γ = quadreg(μ, C / Cmean, ε, alg; maxiter = 100) + +sum(γ; dims = 2) +sum(γ; dims = 1) +diag(γ) + +norm(π - γ, 1) + +D = sum(X.^2 ; dims = 2)/2 +alg = OptimalTransport.SymmetricQuadraticOTNewtonAS(S) +# form partial cost matrix +I, J, V = findnz(S) +C = D[I] + D[J]; @time @inbounds for k = 1:length(V) C[k] += -dot(X[I[k], :], X[J[k], :]) end +C_sp = sparse(I, J, relu.(C)) +solver = OptimalTransport.build_solver(μ, C_sp, ε, alg; maxiter = 50) +OptimalTransport.solve!(solver); +# form full cost matrix +Cfull = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X'; +Cfull[diagind(Cfull)] .= Inf; +plan = OptimalTransport.plan(solver, Cfull) +S = max.(S, sign.(plan)) + +norm(sum(plan; dims = 1) .- 1) +norm(sum(plan; dims = 2) .- 1) + +nnz(plan) / N^2 + +plot(heatmap(Array(plan)), heatmap(Array(π))) + +""" +D = sum(X.^2 ; dims = 2)/2 +Csp = D[I] + D[J]; @time @inbounds for k = 1:length(V) Csp[k] += -dot(X[I[k], :], X[J[k], :]) end +C = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X' +C[diagind(C)] .= Inf ## start u = zeros(N) -function Φ(u, μ, C, ε, S) - return norm(NNlib.relu.((u .* S) .+ (u' .* S) .- C))^2 / 2 - 2*ε * dot(μ, u) +function Φ(u, μ, C, ε, I, J) + return norm(NNlib.relu.(u[I] + u[J] - C))^2 / 2 - 2*ε * dot(μ, u) end armijo_max = 25 +δ = 1e-5 for it = 1:10 @info it - Csp = similar(S); - I, J, V = findnz(Csp); - for (i, j) in zip(I, J) Csp[i, j] = D[i] + D[j] - dot(X[i, :], X[j, :]) end - - γ = (u .* S) + (u' .* S) - Csp - σ = similar(S) - I, J, V = findnz(σ) - for (i, j) in zip(I, J) σ[i, j] = (γ[i, j] ≥ 0) end - dropzeros!(σ) - γ = relu.(γ) / ε - δ = 1e-5 - - G = σ + Diagonal(vec(sum(σ; dims = 2)) .+ δ) - b = -ε*(vec(sum(γ; dims = 2)) .- 1) + I, J, V = findnz(S) + Csp = similar(X, length(V)); + @inbounds for k = 1:length(V) Csp[k] = D[I[k]] + D[J[k]] - dot(X[I[k], :], X[J[k], :]) end + γ = u[I] + u[J] - Csp + σ = similar(Csp) + @. σ = γ ≥ 0 + @. γ = relu(γ) / ε + σ_sp = sparse(I, J, σ) + γ_sp = sparse(I, J, γ) + G = σ_sp + Diagonal(vec(sum(σ_sp; dims = 2)) .+ δ) + b = -ε*(vec(sum(γ_sp; dims = 2)) .- 1) δu = similar(u) fill!(δu, 0) cg!(δu, G, b) - - d = ε*sum(γ .* ((δu .* S) + (δu' .* S))) - 2ε*dot(δu, μ) + d = ε*dot(γ, δu[I] + δu[J]) - 2ε*dot(δu, μ) + @info d t = 1 - Φ0 = Φ(u, μ, Csp, ε, S) + Φ0 = Φ(u, μ, Csp, ε, I, J) armijo_counter = 0 - θ = alg.alg.θ - κ = alg.alg.κ + θ = alg.θ + κ = alg.κ while (armijo_counter < armijo_max) && - (Φ(u + t * δu, μ, Csp, ε, S) ≥ Φ0 + t * θ * d) + (Φ(u + t * δu, μ, Csp, ε, I, J) ≥ Φ0 + t * θ * d) t = κ * t armijo_counter += 1 end @@ -100,6 +137,6 @@ sum(γ; dims = 2) sparse(γ) - using Plots plot(heatmap(Array(γ)), heatmap(Array(π))) +""" From b309a67d7a630f0f12b94c4ba1d58a33d48d918a Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Wed, 28 Jun 2023 17:11:30 +1000 Subject: [PATCH 8/9] clean up --- src/quadratic_newton_symm.jl | 5 +- test.jl | 142 ----------------------------------- 2 files changed, 4 insertions(+), 143 deletions(-) delete mode 100644 test.jl diff --git a/src/quadratic_newton_symm.jl b/src/quadratic_newton_symm.jl index 9f04e201..415494d1 100644 --- a/src/quadratic_newton_symm.jl +++ b/src/quadratic_newton_symm.jl @@ -246,6 +246,10 @@ function solve!(solver::QuadraticOTSolver{<:SymmetricQuadraticOTNewtonAS}) u_prev = similar(solver.cache.u) copy!(u_prev, solver.cache.u) + # active-set method only works for symmetric case with uniform weights. + # verify this is indeed the case + if !all(maximum(μ) .== μ) throw(ArgumentError("Active set method only works for uniform weights.")) end + function check_convergence_dual(u, u_prev, atol, rtol) norm_diff = norm(u - u_prev) isconverged = @@ -385,7 +389,6 @@ end # interface - function quadreg(μ, C, ε, alg::SymmetricQuadraticOTNewton; kwargs...) solver = build_solver(μ, C, ε, alg; kwargs...) solve!(solver) diff --git a/test.jl b/test.jl deleted file mode 100644 index af7f7fd6..00000000 --- a/test.jl +++ /dev/null @@ -1,142 +0,0 @@ -using Revise -using Pkg -Pkg.activate(".") -using OptimalTransport -using LinearAlgebra -using SparseArrays -using Distances -using StatsBase -using NNlib -using LazyArrays -using IterativeSolvers - -N = 10_000 -d = 250 -μ_spt = randn(N, d) -X = μ_spt -μ = ones(N) -C = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X' -Cmean = mean(C) -C[diagind(C)] .= Inf -ENV["JULIA_DEBUG"] = "OptimalTransport" - -ε = 1.0 -π = sparse(quadreg(μ, C / Cmean, ε, OptimalTransport.SymmetricQuadraticOTNewton(); maxiter = 25)) - -nnz(π) / N^2 - -using NearestNeighborDescent -using NearestNeighbors - -function knn_adj(X, k) - # indices, _ = knn_matrices(nndescent(X, k, Euclidean())); - indices, _ = knn(KDTree(X), X, k); - A = spzeros(size(X, 2), size(X, 2)); - @inbounds for i = 1:size(A, 1) - A[i, i] = 1 - @inbounds for j in indices[i] - A[i, j] = 1 - end - end - return A -end - -S = knn_adj(X', 50); -S += sign.(mean([randPermutation(N) for _ in 1:50])); -S = sign.(S); -S[diagind(S)] .= 0; -dropzeros!(S); - -nnz(S) / N^2 - -ε = 1 -alg = OptimalTransport.SymmetricQuadraticOTNewtonAS(S) -γ = quadreg(μ, C / Cmean, ε, alg; maxiter = 100) - -sum(γ; dims = 2) -sum(γ; dims = 1) - -diag(γ) - -norm(π - γ, 1) - -D = sum(X.^2 ; dims = 2)/2 -alg = OptimalTransport.SymmetricQuadraticOTNewtonAS(S) -# form partial cost matrix -I, J, V = findnz(S) -C = D[I] + D[J]; @time @inbounds for k = 1:length(V) C[k] += -dot(X[I[k], :], X[J[k], :]) end -C_sp = sparse(I, J, relu.(C)) -solver = OptimalTransport.build_solver(μ, C_sp, ε, alg; maxiter = 50) -OptimalTransport.solve!(solver); -# form full cost matrix -Cfull = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X'; -Cfull[diagind(Cfull)] .= Inf; -plan = OptimalTransport.plan(solver, Cfull) -S = max.(S, sign.(plan)) - -norm(sum(plan; dims = 1) .- 1) -norm(sum(plan; dims = 2) .- 1) - -nnz(plan) / N^2 - -plot(heatmap(Array(plan)), heatmap(Array(π))) - -""" -D = sum(X.^2 ; dims = 2)/2 -Csp = D[I] + D[J]; @time @inbounds for k = 1:length(V) Csp[k] += -dot(X[I[k], :], X[J[k], :]) end -C = sum(X.^2 ; dims = 2)/2 .+ sum(X.^2 ; dims = 2)'/2 - X * X' -C[diagind(C)] .= Inf -## start -u = zeros(N) - -function Φ(u, μ, C, ε, I, J) - return norm(NNlib.relu.(u[I] + u[J] - C))^2 / 2 - 2*ε * dot(μ, u) -end - -armijo_max = 25 -δ = 1e-5 -for it = 1:10 - @info it - I, J, V = findnz(S) - Csp = similar(X, length(V)); - @inbounds for k = 1:length(V) Csp[k] = D[I[k]] + D[J[k]] - dot(X[I[k], :], X[J[k], :]) end - γ = u[I] + u[J] - Csp - σ = similar(Csp) - @. σ = γ ≥ 0 - @. γ = relu(γ) / ε - σ_sp = sparse(I, J, σ) - γ_sp = sparse(I, J, γ) - G = σ_sp + Diagonal(vec(sum(σ_sp; dims = 2)) .+ δ) - b = -ε*(vec(sum(γ_sp; dims = 2)) .- 1) - - δu = similar(u) - fill!(δu, 0) - cg!(δu, G, b) - - d = ε*dot(γ, δu[I] + δu[J]) - 2ε*dot(δu, μ) - @info d - t = 1 - Φ0 = Φ(u, μ, Csp, ε, I, J) - armijo_counter = 0 - θ = alg.θ - κ = alg.κ - while (armijo_counter < armijo_max) && - (Φ(u + t * δu, μ, Csp, ε, I, J) ≥ Φ0 + t * θ * d) - t = κ * t - armijo_counter += 1 - end - u .= u + t * δu -end -γ = relu.(u .+ u' - C) / ε -S = max.(S, sign.(γ)) -norm(γ - π, 1) / norm(γ, 1) - -sum(γ; dims = 1) -sum(γ; dims = 2) - -sparse(γ) - - -using Plots -plot(heatmap(Array(γ)), heatmap(Array(π))) -""" From bb223f64d328db423198bb0e5c36702bce374c71 Mon Sep 17 00:00:00 2001 From: Stephen Zhang Date: Thu, 17 Aug 2023 20:17:51 +1000 Subject: [PATCH 9/9] clean up --- src/quadratic_newton_symm.jl~ | 0 1 file changed, 0 insertions(+), 0 deletions(-) delete mode 100644 src/quadratic_newton_symm.jl~ diff --git a/src/quadratic_newton_symm.jl~ b/src/quadratic_newton_symm.jl~ deleted file mode 100644 index e69de29b..00000000