diff --git a/ext/TenetAdaptExt.jl b/ext/TenetAdaptExt.jl index 1331b86a4..c0e7d381b 100644 --- a/ext/TenetAdaptExt.jl +++ b/ext/TenetAdaptExt.jl @@ -9,5 +9,7 @@ Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tens Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites) Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x)) Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x))) +Adapt.adapt_structure(to, x::MPS) = MPS(adapt(to, Ansatz(x)), form(x)) +Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x)) end diff --git a/ext/TenetChainRulesCoreExt/frules.jl b/ext/TenetChainRulesCoreExt/frules.jl index d6bb3d43e..2fb918b4d 100644 --- a/ext/TenetChainRulesCoreExt/frules.jl +++ b/ext/TenetChainRulesCoreExt/frules.jl @@ -18,6 +18,8 @@ end # `AbstractAnsatz`-subtype constructors ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ) +ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, form=NoTangent()) +ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, form=NoTangent()) # `Base.conj` methods ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ) diff --git a/ext/TenetChainRulesCoreExt/rrules.jl b/ext/TenetChainRulesCoreExt/rrules.jl index a932dd062..fbe788fc7 100644 --- a/ext/TenetChainRulesCoreExt/rrules.jl +++ b/ext/TenetChainRulesCoreExt/rrules.jl @@ -25,6 +25,14 @@ Product_pullback(ȳ) = (NoTangent(), ȳ.tn) Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ)) ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback +MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback + +MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent()) +MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ)) +ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback + # `Base.conj` methods conj_pullback(Δ::Tensor) = (NoTangent(), conj(Δ)) conj_pullback(Δ::Tangent{Tensor}) = (NoTangent(), conj(Δ)) diff --git a/ext/TenetChainRulesTestUtilsExt.jl b/ext/TenetChainRulesTestUtilsExt.jl index 869465558..704152c63 100644 --- a/ext/TenetChainRulesTestUtilsExt.jl +++ b/ext/TenetChainRulesTestUtilsExt.jl @@ -33,6 +33,11 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz) end ChainRulesTestUtils.rand_tangent(::AbstractRNG, lattice::Tenet.Lattice) = NoTangent() +ChainRulesTestUtils.test_approx(::AbstractZero, form::Tenet.Lattice, msg=""; kwargs...) = true ChainRulesTestUtils.test_approx(actual::Tenet.Lattice, expected::Tenet.Lattice, msg; kwargs...) = actual == expected +ChainRulesTestUtils.rand_tangent(::AbstractRNG, form::Tenet.Form) = NoTangent() +ChainRulesTestUtils.test_approx(::AbstractZero, form::Tenet.Form, msg=""; kwargs...) = true +ChainRulesTestUtils.test_approx(actual::Tenet.Form, expected::Tenet.Form, msg; kwargs...) = actual == expected + end diff --git a/ext/TenetQuacExt.jl b/ext/TenetQuacExt.jl index 15f6141a5..6dd36b00d 100644 --- a/ext/TenetQuacExt.jl +++ b/ext/TenetQuacExt.jl @@ -3,13 +3,11 @@ module TenetQuacExt using Tenet using Quac: Gate, Circuit, lanes, arraytype, Swap -# function Tenet.Dense(gate::Gate) -# return Tenet.Dense( -# Operator(), arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...] -# ) -# end +function Tenet.Quantum(gate::Gate) + return Tenet.Quantum(arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...]) +end -# Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Tenet.Dense(gate); kwargs...) +Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Quantum(gate); kwargs...) function Tenet.Quantum(circuit::Circuit) n = lanes(circuit) diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 9e29248ca..520ffdc24 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -42,6 +42,13 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac return Tenet.Product(tracetn) end +for A in (MPS, MPO) + @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return $A(tracetn, form(prev)) + end +end + function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) return :($Tensor($data, $(inds(tocopy)))) @@ -70,6 +77,13 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu return :($(Tenet.Product)($tn)) end +for A in (MPS, MPO) + @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A} + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn, $(Tenet.form(tocopy)))) + end +end + # TODO try rely on generic fallback for ansatzes # function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) # tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) diff --git a/src/Ansatz.jl b/src/Ansatz.jl index 5a71fd188..91354a1a1 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -46,7 +46,7 @@ struct NonCanonical <: Form end [`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form. """ struct MixedCanonical <: Form - orthogonality_center::Union{Site,Vector{Site}} + orthog_center::Union{Site,Vector{Site}} end """ @@ -192,15 +192,52 @@ Contract the virtual bond between two [`Site`](@ref)s in a [`AbstractAnsatz`](@r """ @kwmethod contract!(tn::AbstractAnsatz; bond) = contract!(tn, inds(tn; bond)) +""" + canonize!(tn::AbstractAnsatz) + +Transform an [`AbstractAnsatz`](@ref) Tensor Network into the canonical form (aka Vidal gauge); i.e. the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ. +""" +function canonize! end + +""" + canonize(tn::AbstractAnsatz) + +Like [`canonize!`](@ref), but returns a new Tensor Network instead of modifying the original one. +""" canonize(tn::AbstractAnsatz, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...) + +""" + mixed_canonize!(tn::AbstractAnsatz, orthog_center) + +Transform an [`AbstractAnsatz`](@ref) Tensor Network into the mixed-canonical form, that is, +for `i < orthog_center` the tensors are left-canonical and for `i >= orthog_center` the tensors are right-canonical, +and in the `orthog_center` there is a tensor with the Schmidt coefficients in it. +""" +function mixed_canonize! end + +""" + mixed_canonize(tn::AbstractAnsatz, orthog_center) + +Like [`mixed_canonize!`](@ref), but returns a new Tensor Network instead of modifying the original one. +""" +mixed_canonize(tn::AbstractAnsatz, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...) + canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...) +""" + isisometry(tn::AbstractAnsatz, site; dir, kwargs...) + +Check if the tensor at a given [`Site`](@ref) in a [`AbstractAnsatz`](@ref) Tensor Network is an isometry. +The `dir` keyword argument specifies the direction of the isometry to check. +""" +function isisometry end + """ truncate(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing) -Like [`truncate!`](@ref), but returns a new tensor network instead of modifying the original one. +Like [`truncate!`](@ref), but returns a new Tensor Network instead of modifying the original one. """ -truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) +Base.truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...) """ truncate!(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing) @@ -236,15 +273,21 @@ Truncate the dimension of the virtual `bond` of a [`NonCanonical`](@ref) Tensor - `compute_local_svd`: Whether to compute the local SVD of the bond. If `true`, it will contract the bond and perform a SVD to get the local singular values. Defaults to `true`. """ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true) + virtualind = inds(tn; bond) + if compute_local_svd + tₗ = tensors(tn; at=min(bond...)) + tᵣ = tensors(tn; at=max(bond...)) contract!(tn; bond) - svd!(tn; virtualind=inds(tn; bond)) + + left_inds = filter(!=(virtualind), inds(tₗ)) + right_inds = filter(!=(virtualind), inds(tᵣ)) + svd!(tn; left_inds, right_inds, virtualind=virtualind) end spectrum = parent(tensors(tn; bond)) - vind = inds(tn; bond) - maxdim = isnothing(maxdim) ? size(tn, vind) : maxdim + maxdim = isnothing(maxdim) ? size(tn, virtualind) : maxdim extent = if isnothing(threshold) 1:maxdim @@ -254,7 +297,7 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, end - 1, maxdim) end - slice!(tn, vind, extent) + slice!(tn, virtualind, extent) return tn end @@ -268,7 +311,7 @@ end function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim) truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false) # requires a sweep to recanonize the TN - return canonize!(tn, bond) + return canonize!(tn) end overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)')) @@ -333,7 +376,7 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth @assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites" - return simple_update!(form(ψ), ψ, gate; kwargs...) + return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...) end # TODO a lot of problems with merging... maybe we shouldn't merge manually @@ -376,9 +419,19 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth rinds = filter(!=(vind), inds(tensors(ψ; at=siter))) contract!(ψ; bond) + # TODO replace for `merge!` when #243 is fixed + # reindex contracting indices to temporary names to avoid issues + oinds = Dict(site => inds(ψ; at=site) for site in sites(gate; set=:outputs)) + tmpinds = Dict(site => gensym(:tmp) for site in sites(gate; set=:inputs)) + replace!(gate, [inds(gate; at=site) => i for (site, i) in tmpinds]) + replace!(ψ, [inds(ψ; at=site') => i for (site, i) in tmpinds]) + + # NOTE `replace!` is getting confused when a index is already there even if it would be overriden + # TODO fix this to be handled in one call -> replace when #244 is fixed + replace!(gate, [inds(gate; at=site) => gensym() for (site, i) in oinds]) + replace!(gate, [inds(gate; at=site) => i for (site, i) in oinds]) + # contract physical inds with gate - @reindex! outputs(ψ) => outputs(gate) reset = false - @reindex! inputs(gate) => outputs(ψ) reset = false merge!(ψ, gate; reset=false) contract!(ψ, inds(gate; set=:inputs)) @@ -395,129 +448,8 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth end # TODO remove `renormalize` argument? -# TODO refactor code +# TODO optimize correctly -> avoid recanonization + use lateral Λs function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false) - @assert nlanes(gate) == 2 "Only 2-site gates are supported currently" - @assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites" - - # shallow copy to avoid problems if errors in mid execution - gate = copy(gate) - - bond = sitel, siter = minmax(sites(gate; set=:outputs)...) - left_inds::Vector{Symbol} = !isnothing(leftindex(ψ, sitel)) ? [leftindex(ψ, sitel)] : Symbol[] - right_inds::Vector{Symbol} = !isnothing(rightindex(ψ, siter)) ? [rightindex(ψ, siter)] : Symbol[] - - virtualind::Symbol = inds(ψ; bond=bond) - - contract_2sitewf!(ψ, bond) - - # reindex contracting index - contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)] - replace!( - ψ, - map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) - inds(ψ; at=site') => contracting_index - end, - ) - replace!( - gate, - map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index) - inds(gate; at=site) => contracting_index - end, - ) - - # replace output indices of the gate for gensym indices - output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)] - replace!( - gate, - map(zip(sites(gate; set=:outputs), output_inds)) do (site, out) - inds(gate; at=site) => out - end, - ) - - # reindex output of gate to match TN sitemap - for site in sites(gate; set=:outputs) - if inds(ψ; at=site) != inds(gate; at=site) - replace!(gate, inds(gate; at=site) => inds(ψ; at=site)) - end - end - - # contract physical inds - merge!(ψ, gate) - contract!(ψ, contracting_inds) - - # decompose using SVD - push!(left_inds, inds(ψ; at=sitel)) - push!(right_inds, inds(ψ; at=siter)) - - unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind) - - # truncate virtual index - if any(!isnothing, [threshold, maxdim]) - truncate!(ψ, bond; threshold, maxdim) - renormalize && normalize!(tensors(ψ; between=bond)) - end - - return ψ -end - -# TODO refactor code -""" - contract_2sitewf!(ψ::AbstractAnsatz, bond) - -For a given [`AbstractAnsatz`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁, -where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ. -""" -function contract_2sitewf!(ψ::AbstractAnsatz, bond) - @assert form(ψ) == Canonical() "The tensor network must be in canonical form" - - sitel, siter = bond # TODO Check if bond is valid - (0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) || - throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - - Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) - Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) - - !isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false) - !isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false) - - contract!(ψ, inds(ψ; bond=bond)) - - return ψ -end - -# TODO refactor code -""" - unpack_2sitewf!(ψ::AbstractAnsatz, bond) - -For a given [`AbstractAnsatz`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical -form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`. -""" -function unpack_2sitewf!(ψ::AbstractAnsatz, bond, left_inds, right_inds, virtualind) - @assert form(ψ) == Canonical() "The tensor network must be in canonical form" - - sitel, siter = bond # TODO Check if bond is valid - (0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) || - throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))")) - - Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel)) - Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1))) - - # do svd of the θ tensor - θ = tensors(ψ; at=sitel) - U, s, Vt = svd(θ; left_inds, right_inds, virtualind) - - # contract with the inverse of Λᵢ and Λᵢ₊₂ - Γᵢ₋₁ = - isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=()) - Γᵢ = - isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=()) - - delete!(ψ, θ) - - push!(ψ, Γᵢ₋₁) - push!(ψ, s) - push!(ψ, Γᵢ) - - return ψ + simple_update!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize) + return canonize!(ψ) end diff --git a/src/Lattice.jl b/src/Lattice.jl index 7b471bc82..04c002a74 100644 --- a/src/Lattice.jl +++ b/src/Lattice.jl @@ -29,8 +29,10 @@ end Base.copy(lattice::Lattice) = Lattice(copy(lattice.mapping), copy(lattice.graph)) Base.:(==)(a::Lattice, b::Lattice) = a.mapping == b.mapping && a.graph == b.graph +# TODO these where needed by ChainRulesTestUtils, do we still need them? Base.zero(::Type{Lattice}) = Lattice(BijectiveIdDict{Site,Int}(), zero(Graphs.SimpleGraph{Int})) Base.zero(::Lattice) = zero(Lattice) + Graphs.is_directed(::Type{Lattice}) = false function Graphs.vertices(lattice::Lattice) diff --git a/src/MPS.jl b/src/MPS.jl new file mode 100644 index 000000000..2a6b07992 --- /dev/null +++ b/src/MPS.jl @@ -0,0 +1,476 @@ +using Random +using LinearAlgebra +using Graphs +using BijectiveDicts: BijectiveIdDict + +abstract type AbstractMPO <: AbstractAnsatz end +abstract type AbstractMPS <: AbstractMPO end + +""" + MPS <: AbstractAnsatz + +A Matrix Product State [`Ansatz`](@ref) Tensor Network. +""" +mutable struct MPS <: AbstractMPS + const tn::Ansatz + form::Form +end + +""" + MPO <: AbstractAnsatz + +A Matrix Product Operator (MPO) [`Ansatz`](@ref) Tensor Network. +""" +mutable struct MPO <: AbstractMPO + const tn::Ansatz + form::Form +end + +Ansatz(tn::Union{MPS,MPO}) = tn.tn + +boundary(::Union{MPS,MPO}) = Open() +form(tn::Union{MPS,MPO}) = tn.form + +Base.copy(x::T) where {T<:Union{MPS,MPO}} = T(copy(Ansatz(x)), form(x)) +Base.similar(x::T) where {T<:Union{MPS,MPO}} = T(similar(Ansatz(x)), form(x)) +Base.zero(x::T) where {T<:Union{MPS,MPO}} = T(zero(Ansatz(x)), form(x)) + +defaultorder(::Type{<:AbstractMPS}) = (:o, :l, :r) +defaultorder(::Type{<:AbstractMPO}) = (:o, :i, :l, :r) + +""" + MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) + +Create a [`NonCanonical`](@ref) [`MPS`](@ref) from a vector of arrays. + +# Keyword Arguments + + - `order` The order of the indices in the arrays. Defaults to `(:o, :l, :r)`. +""" +function MPS(arrays::Vector{<:AbstractArray}; order=defaultorder(MPS)) + @assert ndims(arrays[1]) == 2 "First array must have 2 dimensions" + @assert all(==(3) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 3 dimensions" + @assert ndims(arrays[end]) == 2 "Last array must have 2 dimensions" + issetequal(order, defaultorder(MPS)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPS)))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(2n)] + + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :r + symbols[n + mod1(i, n)] + elseif dir == :l + symbols[n + mod1(i - 1, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + qtn = Quantum(tn, sitemap) + graph = path_graph(n) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) + lattice = Lattice(mapping, graph) + ansatz = Ansatz(qtn, lattice) + return MPS(ansatz, NonCanonical()) +end + +""" + MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) + +Create a [`NonCanonical`](@ref) [`MPO`](@ref) from a vector of arrays. + +# Keyword Arguments + + - `order` The order of the indices in the arrays. Defaults to `(:o, :i, :l, :r)`. +""" +function MPO(arrays::Vector{<:AbstractArray}; order=defaultorder(MPO)) + @assert ndims(arrays[1]) == 3 "First array must have 3 dimensions" + @assert all(==(4) ∘ ndims, arrays[2:(end - 1)]) "All arrays must have 4 dimensions" + @assert ndims(arrays[end]) == 3 "Last array must have 3 dimensions" + issetequal(order, defaultorder(MPO)) || + throw(ArgumentError("order must be a permutation of $(String.(defaultorder(MPO)))")) + + n = length(arrays) + gen = IndexCounter() + symbols = [nextindex!(gen) for _ in 1:(3n - 1)] + + tn = TensorNetwork( + map(enumerate(arrays)) do (i, array) + _order = if i == 1 + filter(x -> x != :l, order) + elseif i == n + filter(x -> x != :r, order) + else + order + end + + inds = map(_order) do dir + if dir == :o + symbols[i] + elseif dir == :i + symbols[i + n] + elseif dir == :l + symbols[2n + mod1(i - 1, n)] + elseif dir == :r + symbols[2n + mod1(i, n)] + else + throw(ArgumentError("Invalid direction: $dir")) + end + end + Tensor(array, inds) + end, + ) + + sitemap = Dict(Site(i) => symbols[i] for i in 1:n) + merge!(sitemap, Dict(Site(i; dual=true) => symbols[i + n] for i in 1:n)) + qtn = Quantum(tn, sitemap) + graph = path_graph(n) + mapping = BijectiveIdDict{Site,Int}(Pair{Site,Int}[site => i for (i, site) in enumerate(lanes(qtn))]) + lattice = Lattice(mapping, graph) + ansatz = Ansatz(qtn, lattice) + return MPO(ansatz, NonCanonical()) +end + +""" + MPS(::typeof(identity), n::Integer; physdim=2, maxdim=physdim^(n ÷ 2)) + +Returns an [`MPS`](@ref) of `n` sites whose tensors are initialized to COPY-tensors. + +# Keyword Arguments + + - `physdim` The physical or output dimension of each site. Defaults to 2. + - `maxdim` The maximum bond dimension. Defaults to `physdim^(n ÷ 2)`. +""" +function MPS(::typeof(identity), n::Integer; physdim=2, maxdim=physdim^(n ÷ 2)) + # Create bond dimensions until the middle of the MPS considering maxdim + virtualdims = min.(maxdim, physdim .^ (1:(n ÷ 2))) + + # Complete the bond dimensions of the other half of the MPS + virtualdims = vcat(virtualdims, virtualdims[(isodd(n) ? end : end - 1):-1:1]) + + # Create each site dimensions in default order (:o, :l, :r) + arraysdims = [[physdim, virtualdims[1]]] + append!(arraysdims, [[physdim, virtualdims[i], virtualdims[i + 1]] for i in 1:(length(virtualdims) - 1)]) + push!(arraysdims, [physdim, virtualdims[end]]) + + # Create the MPS with copy-tensors according to the tensors dimensions + return MPS( + map(arraysdims) do arrdims + arr = zeros(ComplexF64, arrdims...) + deltas = [fill(i, length(arrdims)) for i in 1:physdim] + broadcast(delta -> arr[delta...] = 1.0, deltas) + arr + end, + ) +end + +function Base.convert(::Type{T}, tn::Product) where {T<:AbstractMPO} + @assert socket(tn) == State() + + arrs::Vector{Array} = arrays(tn) + arrs[1] = reshape(arrs[1], size(arrs[1])..., 1) + arrs[end] = reshape(arrs[end], size(arrs[end])..., 1) + map!(@view(arrs[2:(end - 1)]), @view(arrs[2:(end - 1)])) do arr + reshape(arr, size(arr)..., 1, 1) + end + + return T(arrs) +end + +# TODO can this be better written? or even generalized to AbstractAnsatz? +Base.adjoint(tn::T) where {T<:AbstractMPO} = T(adjoint(Ansatz(tn)), form(tn)) + +# TODO different input/output physical dims +# TODO let choose the orthogonality center +# TODO add form information +""" + Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2) + +Create a random [`MPS`](@ref) Tensor Network. +In order to avoid norm explosion issues, the tensors are orthogonalized by QR factorization so its normalized and mixed canonized to the last site. + +# Keyword Arguments + + - `n` The number of sites. + - `maxdim` The maximum bond dimension. + - `eltype` The element type of the tensors. Defaults to `Float64`. + - `physdim` The physical or output dimension of each site. Defaults to 2. +""" +function Base.rand(rng::Random.AbstractRNG, ::Type{MPS}; n, maxdim, eltype=Float64, physdim=2) + p = physdim + T = eltype + χ = maxdim + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, p^(i - 1)) + χr = min(χ, p^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + # orthogonalize by QR factorization + F = lq!(rand(rng, T, χl, p * χr)) + + reshape(Matrix(F.Q), χl, p, χr) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], p, p) + arrays[n] = reshape(arrays[n], p, p) + + return MPS(arrays; order=(:l, :o, :r)) +end + +# TODO different input/output physical dims +# TODO let choose the orthogonality center +""" + Base.rand(rng::Random.AbstractRNG, ::Type{MPO}; n, maxdim, eltype=Float64, physdim=2) + +Create a random [`MPO`](@ref) Tensor Network. +In order to avoid norm explosion issues, the tensors are orthogonalized by QR factorization so its normalized and mixed canonized to the last site. + +# Keyword Arguments + + - `n` The number of sites. + - `maxdim` The maximum bond dimension. + - `eltype` The element type of the tensors. Defaults to `Float64`. + - `physdim` The physical or output dimension of each site. Defaults to 2. +""" +function Base.rand(rng::Random.AbstractRNG, ::Type{MPO}; n, maxdim, eltype=Float64, physdim=2) + T = eltype + ip = op = physdim + χ = maxdim + + arrays::Vector{AbstractArray{T,N} where {N}} = map(1:n) do i + χl, χr = let after_mid = i > n ÷ 2, i = (n + 1 - abs(2i - n - 1)) ÷ 2 + χl = min(χ, ip^(i - 1) * op^(i - 1)) + χr = min(χ, ip^i * op^i) + + # swap bond dims after mid and handle midpoint for odd-length MPS + (isodd(n) && i == n ÷ 2 + 1) ? (χl, χl) : (after_mid ? (χr, χl) : (χl, χr)) + end + + # orthogonalize by QR factorization + F = lq!(rand(rng, T, χl, ip * op * χr)) + reshape(Matrix(F.Q), χl, ip, op, χr) + end + + # reshape boundary sites + arrays[1] = reshape(arrays[1], ip, op, min(χ, ip * op)) + arrays[n] = reshape(arrays[n], min(χ, ip * op), ip, op) + + # TODO order might not be the best for performance + return MPO(arrays; order=(:l, :i, :o, :r)) +end + +# TODO deprecate contract(; between) and generalize it to AbstractAnsatz +""" + Tenet.contract!(tn::AbstractMPO; between=(site1, site2), direction::Symbol = :left, delete_Λ = true) + +For a given [`AbstractMPO`](@ref) Tensor Network, contract the singular values Λ between two sites `site1` and `site2`. +The `direction` keyword argument specifies the direction of the contraction, and the `delete_Λ` keyword argument +specifies whether to delete the singular values tensor after the contraction. +""" +@kwmethod contract(tn::AbstractMPO; between, direction, delete_Λ) = contract!(copy(tn); between, direction, delete_Λ) +@kwmethod function contract!(tn::AbstractMPO; between, direction, delete_Λ) + site1, site2 = between + Λᵢ = tensors(tn; between) + Λᵢ === nothing && return tn + + if direction === :right + Γᵢ₊₁ = tensors(tn; at=site2) + replace!(tn, Γᵢ₊₁ => contract(Γᵢ₊₁, Λᵢ; dims=())) + elseif direction === :left + Γᵢ = tensors(tn; at=site1) + replace!(tn, Γᵢ => contract(Λᵢ, Γᵢ; dims=())) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + delete_Λ && delete!(TensorNetwork(tn), Λᵢ) + + return tn +end +@kwmethod contract(tn::AbstractMPO; between) = contract(tn; between, direction=:left, delete_Λ=true) +@kwmethod contract!(tn::AbstractMPO; between) = contract!(tn; between, direction=:left, delete_Λ=true) +@kwmethod contract(tn::AbstractMPO; between, direction) = contract(tn; between, direction, delete_Λ=true) +@kwmethod contract!(tn::AbstractMPO; between, direction) = contract!(tn; between, direction, delete_Λ=true) + +# TODO change it to `lanes`? +# TODO refactor to use `Lattice` +function sites(ψ::T, site::Site; dir) where {T<:AbstractMPO} + if dir === :left + return site <= site"1" ? nothing : Site(id(site) - 1) + elseif dir === :right + return site >= Site(nlanes(ψ)) ? nothing : Site(id(site) + 1) + else + throw(ArgumentError("Unknown direction for $T = :$dir")) + end +end + +# TODO refactor to use `Lattice` +@kwmethod function inds(ψ::T; at, dir) where {T<:AbstractMPO} + if dir === :left && at == site"1" + return nothing + elseif dir === :right && at == Site(nlanes(ψ); dual=isdual(at)) + return nothing + elseif dir ∈ (:left, :right) + return inds(ψ; bond=(at, sites(ψ, at; dir))) + else + throw(ArgumentError("Unknown direction for $T = :$dir")) + end +end + +function isisometry(ψ::AbstractMPO, site; dir, atol::Real=1e-12) + tensor = tensors(ψ; at=site) + dirind = inds(ψ; at=site, dir) + + if isnothing(dirind) + return isapprox(parent(contract(tensor, conj(tensor))), fill(true); atol) + end + + inda, indb = gensym(:a), gensym(:b) + a = replace(tensor, dirind => inda) + b = replace(conj(tensor), dirind => indb) + + n = size(tensor, dirind) + contracted = contract(a, b; out=[inda, indb]) + + return isapprox(contracted, I(n); atol) +end + +@deprecate isleftcanonical(ψ::AbstractMPO, site; atol::Real=1e-12) isisometry(ψ, site; dir=:right, atol) +@deprecate isrightcanonical(ψ::AbstractMPO, site; atol::Real=1e-12) isisometry(ψ, site; dir=:left, atol) + +# TODO generalize to AbstractAnsatz +# NOTE: in method == :svd the spectral weights are stored in a vector connected to the now virtual hyperindex! +function canonize_site!(ψ::MPS, site::Site; direction::Symbol, method=:qr) + left_inds = Symbol[] + right_inds = Symbol[] + + virtualind = if direction === :left + site == Site(1) && throw(ArgumentError("Cannot right-canonize left-most tensor")) + push!(right_inds, inds(ψ; at=site, dir=:left)) + + site == Site(nsites(ψ)) || push!(left_inds, inds(ψ; at=site, dir=:right)) + push!(left_inds, inds(ψ; at=site)) + + only(right_inds) + elseif direction === :right + site == Site(nsites(ψ)) && throw(ArgumentError("Cannot left-canonize right-most tensor")) + push!(right_inds, inds(ψ; at=site, dir=:right)) + + site == Site(1) || push!(left_inds, inds(ψ; at=site, dir=:left)) + push!(left_inds, inds(ψ; at=site)) + + only(right_inds) + else + throw(ArgumentError("Unknown direction=:$direction")) + end + + tmpind = gensym(:tmp) + if method === :svd + svd!(ψ; left_inds, right_inds, virtualind=tmpind) + elseif method === :qr + qr!(ψ; left_inds, right_inds, virtualind=tmpind) + else + throw(ArgumentError("Unknown factorization method=:$method")) + end + + contract!(ψ, virtualind) + replace!(ψ, tmpind => virtualind) + + return ψ +end + +function canonize!(ψ::AbstractMPO) + Λ = Tensor[] + + # right-to-left QR sweep, get right-canonical tensors + for i in nsites(ψ):-1:2 + canonize_site!(ψ, Site(i); direction=:left, method=:qr) + end + + # left-to-right SVD sweep, get left-canonical tensors and singular values without reversing + for i in 1:(nsites(ψ) - 1) + canonize_site!(ψ, Site(i); direction=:right, method=:svd) + + # extract the singular values and contract them with the next tensor + Λᵢ = pop!(ψ, tensors(ψ; between=(Site(i), Site(i + 1)))) + Aᵢ₊₁ = tensors(ψ; at=Site(i + 1)) + replace!(ψ, Aᵢ₊₁ => contract(Aᵢ₊₁, Λᵢ; dims=())) + push!(Λ, Λᵢ) + end + + for i in 2:nsites(ψ) # tensors at i in "A" form, need to contract (Λᵢ)⁻¹ with A to get Γᵢ + Λᵢ = Λ[i - 1] # singular values start between site 1 and 2 + A = tensors(ψ; at=Site(i)) + Γᵢ = contract(A, Tensor(diag(pinv(Diagonal(parent(Λᵢ)); atol=1e-64)), inds(Λᵢ)); dims=()) + replace!(ψ, A => Γᵢ) + push!(ψ, Λᵢ) + end + + ψ.form = Canonical() + + return ψ +end + +# TODO mixed_canonize! at bond +# TODO dispatch on form +# TODO generalize to AbstractAnsatz +function mixed_canonize!(tn::AbstractMPO, orthog_center) + # left-to-right QR sweep (left-canonical tensors) + for i in 1:(id(orthog_center) - 1) + canonize_site!(tn, Site(i); direction=:right, method=:qr) + end + + # right-to-left QR sweep (right-canonical tensors) + for i in nsites(tn):-1:(id(orthog_center) + 1) + canonize_site!(tn, Site(i); direction=:left, method=:qr) + end + + # center SVD sweep to get singular values + # canonize_site!(tn, orthog_center; direction=:left, method=:svd) + + tn.form = MixedCanonical(orthog_center) + + return tn +end + +LinearAlgebra.normalize!(ψ::AbstractMPO; kwargs...) = normalize!(form(ψ), ψ; kwargs...) + +function LinearAlgebra.normalize!(::NonCanonical, ψ::AbstractMPO; at=Site(nsites(ψ) ÷ 2)) + tensor = tensors(ψ; at) + tensor ./= norm(ψ) + return ψ +end + +LinearAlgebra.normalize!(ψ::AbstractMPO, site::Site) = normalize!(mixed_canonize!(ψ, site); at=site) + +function LinearAlgebra.normalize!(config::MixedCanonical, ψ::AbstractMPO; at=config.orthog_center) + mixed_canonize!(ψ, at) + normalize!(tensors(ψ; at), 2) + return ψ +end + +# TODO function LinearAlgebra.normalize!(::Canonical, ψ::AbstractMPO) end diff --git a/src/Quantum.jl b/src/Quantum.jl index e9f98b00e..0fd3d90f8 100644 --- a/src/Quantum.jl +++ b/src/Quantum.jl @@ -73,6 +73,29 @@ end Quantum(tn::Quantum) = tn +""" + Quantum(array, sites) + +Constructs a [`Quantum`](@ref) Tensor Network from an array and a list of sites. Useful for simple operators like gates. +""" +function Quantum(array, sites) + if ndims(array) != length(sites) + throw(ArgumentError("Number of sites must match number of dimensions of array")) + end + + gen = IndexCounter() + symbols = map(_ -> nextindex!(gen), sites) + sitemap = Dict{Site,Symbol}( + map(sites, 1:ndims(array)) do site, i + site => symbols[i] + end, + ) + tensor = Tensor(array, symbols) + tn = TensorNetwork([tensor]) + qtn = Quantum(tn, sitemap) + return qtn +end + """ TensorNetwork(q::AbstractQuantum) @@ -166,23 +189,34 @@ function Base.replace!(tn::AbstractQuantum, old_new::Base.AbstractVecOrTuple{Pai end function reindex!(a::Quantum, ioa, b::Quantum, iob; reset=true) - ioa ∈ [:inputs, :outputs] || error("Invalid argument: :$ioa") - if reset resetindex!(a) resetindex!(b; init=ninds(TensorNetwork(a)) + 1) end + sitesa = if ioa === :inputs + collect(sites(a; set=:inputs)) + elseif ioa === :outputs + collect(sites(a; set=:outputs)) + else + error("Invalid argument: $(Meta.quot(ioa))") + end + sitesb = if iob === :inputs collect(sites(b; set=:inputs)) elseif iob === :outputs collect(sites(b; set=:outputs)) else - error("Invalid argument: :$iob") + error("Invalid argument: :$(Meta.quot(iob))") end - replacements = map(sitesb) do site - inds(b; at=site) => inds(a; at=ioa != iob ? site' : site) + # TODO select sites to reindex + targetsites = (ioa === :inputs ? adjoint.(sitesa) : sitesa) ∩ (iob === :inputs ? adjoint.(sitesb) : sitesb) + + replacements = map(targetsites) do site + siteb = iob === :inputs ? site' : site + sitea = ioa === :inputs ? site' : site + inds(b; at=siteb) => inds(a; at=sitea) end if issetequal(first.(replacements), last.(replacements)) diff --git a/src/Tenet.jl b/src/Tenet.jl index 295fffc33..820be50ad 100644 --- a/src/Tenet.jl +++ b/src/Tenet.jl @@ -29,14 +29,18 @@ include("Ansatz.jl") export Ansatz export socket, Scalar, State, Operator export boundary, Open, Periodic -export form +export form, NonCanonical, MixedCanonical, Canonical -export canonize_site, canonize_site!, canonize, canonize!, mixed_canonize, mixed_canonize!, truncate! -export evolve!, expect, overlap +export canonize, canonize!, mixed_canonize, mixed_canonize!, truncate!, isisometry include("Product.jl") export Product +include("MPS.jl") +export MPS, MPO + +export evolve!, expect, overlap + # reexports from EinExprs export einexpr, inds diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index b2bdfa761..1a215a752 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -746,6 +746,10 @@ function Base.rand(::Type{TensorNetwork}, n::Integer, regularity::Integer; kwarg return rand(Random.default_rng(), TensorNetwork, n, regularity; kwargs...) end +function Base.rand(::Type{T}, args...; kwargs...) where {T<:AbstractTensorNetwork} + return rand(Random.default_rng(), T, args...; kwargs...) +end + function Serialization.serialize(s::AbstractSerializer, obj::TensorNetwork) Serialization.writetag(s.io, Serialization.OBJECT_TAG) serialize(s, TensorNetwork) diff --git a/test/Chain_test.jl b/test/Chain_test.jl deleted file mode 100644 index b76f70016..000000000 --- a/test/Chain_test.jl +++ /dev/null @@ -1,436 +0,0 @@ -@testset_skip "Chain ansatz" begin - @testset "Periodic boundary" begin - @testset "State" begin - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - arrays = [rand(2, 1, 4), rand(2, 4, 3), rand(2, 3, 1)] - qtn = Chain(State(), Periodic(), arrays) # Default order (:o, :l, :r) - - @test size(tensors(qtn; at=Site(1))) == (2, 1, 4) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 3, 1) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - arrays = [permutedims(array, (3, 1, 2)) for array in arrays] # now we have (:r, :o, :l) - qtn = Chain(State(), Periodic(), arrays; order=[:r, :o, :l]) - - @test size(tensors(qtn; at=Site(1))) == (4, 2, 1) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 4) - @test size(tensors(qtn; at=Site(3))) == (1, 2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - for i in 1:nsites(qtn) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - end - end - - @testset "Operator" begin - qtn = Chain(Operator(), Periodic(), [rand(2, 2, 4, 4) for _ in 1:3]) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == 3 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Periodic() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") != nothing - - arrays = [rand(2, 4, 1, 3), rand(2, 4, 3, 6), rand(2, 4, 6, 1)] # Default order (:o, :i, :l, :r) - qtn = Chain(Operator(), Periodic(), arrays) - - @test size(tensors(qtn; at=Site(1))) == (2, 4, 1, 3) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 3, 6) - @test size(tensors(qtn; at=Site(3))) == (2, 4, 6, 1) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - - arrays = [permutedims(array, (4, 1, 3, 2)) for array in arrays] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Periodic(), arrays; order=[:r, :o, :l, :i]) - - @test size(tensors(qtn; at=Site(1))) == (3, 2, 1, 4) - @test size(tensors(qtn; at=Site(2))) == (6, 2, 3, 4) - @test size(tensors(qtn; at=Site(3))) == (1, 2, 6, 4) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) !== nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - end - end - - @testset "Open boundary" begin - @testset "State" begin - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] - qtn = Chain(State(), Open(), arrays) # Default order (:o, :l, :r) - - @test size(tensors(qtn; at=Site(1))) == (2, 1) - @test size(tensors(qtn; at=Site(2))) == (2, 1, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) - - arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) - qtn = Chain(State(), Open(), arrays; order=[:r, :o, :l]) - - @test size(tensors(qtn; at=Site(1))) == (1, 2) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 1) - @test size(tensors(qtn; at=Site(3))) == (2, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:nsites(qtn) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - end - end - @testset "Operator" begin - qtn = Chain(Operator(), Open(), [rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == 3 - @test nsites(qtn; set=:outputs) == 3 - @test issetequal(sites(qtn), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) - @test boundary(qtn) == Open() - @test leftindex(qtn, site"1") == rightindex(qtn, site"3") == nothing - - arrays = [rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)] # Default order (:o :i, :l, :r) - qtn = Chain(Operator(), Open(), arrays) - - @test size(tensors(qtn; at=Site(1))) == (2, 4, 1) - @test size(tensors(qtn; at=Site(2))) == (2, 4, 1, 3) - @test size(tensors(qtn; at=Site(3))) == (2, 4, 3) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - - arrays = [ - permutedims(arrays[1], (3, 1, 2)), - permutedims(arrays[2], (4, 1, 3, 2)), - permutedims(arrays[3], (1, 3, 2)), - ] # now we have (:r, :o, :l, :i) - qtn = Chain(Operator(), Open(), arrays; order=[:r, :o, :l, :i]) - - @test size(tensors(qtn; at=Site(1))) == (1, 2, 4) - @test size(tensors(qtn; at=Site(2))) == (3, 2, 1, 4) - @test size(tensors(qtn; at=Site(3))) == (2, 3, 4) - - @test leftindex(qtn, Site(1)) == rightindex(qtn, Site(3)) === nothing - @test leftindex(qtn, Site(2)) == rightindex(qtn, Site(1)) !== nothing - @test leftindex(qtn, Site(3)) == rightindex(qtn, Site(2)) !== nothing - - for i in 1:length(arrays) - @test size(qtn, inds(qtn; at=Site(i))) == 2 - @test size(qtn, inds(qtn; at=Site(i; dual=true))) == 4 - end - end - end - - @testset "Site" begin - using Tenet: leftsite, rightsite - qtn = Chain(State(), Periodic(), [rand(2, 4, 4) for _ in 1:3]) - - @test leftsite(qtn, Site(1)) == Site(3) - @test leftsite(qtn, Site(2)) == Site(1) - @test leftsite(qtn, Site(3)) == Site(2) - - @test rightsite(qtn, Site(1)) == Site(2) - @test rightsite(qtn, Site(2)) == Site(3) - @test rightsite(qtn, Site(3)) == Site(1) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - - @test isnothing(leftsite(qtn, Site(1))) - @test isnothing(rightsite(qtn, Site(3))) - - @test leftsite(qtn, Site(2)) == Site(1) - @test leftsite(qtn, Site(3)) == Site(2) - - @test rightsite(qtn, Site(2)) == Site(3) - @test rightsite(qtn, Site(1)) == Site(2) - end - - @testset "truncate" begin - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2)]) - canonize_site!(qtn, Site(2); direction=:right, method=:svd) - - @test_throws Tenet.MissingSchmidtCoefficientsException truncate!(qtn, [Site(1), Site(2)]; maxdim=1) - # @test_throws ArgumentError truncate!(qtn, [Site(2), Site(3)]) - - truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; maxdim=1) - @test size(truncated, rightindex(truncated, Site(2))) == 1 - @test size(truncated, leftindex(truncated, Site(3))) == 1 - - singular_values = tensors(qtn; between=(Site(2), Site(3))) - truncated = Tenet.truncate(qtn, [Site(2), Site(3)]; threshold=singular_values[2] + 0.1) - @test size(truncated, rightindex(truncated, Site(2))) == 1 - @test size(truncated, leftindex(truncated, Site(3))) == 1 - end - - @testset "rand" begin - using LinearAlgebra: norm - - @testset "State" begin - n = 8 - χ = 10 - - qtn = rand(Chain, Open, State; n, p=2, χ) - @test socket(qtn) == State() - @test nsites(qtn; set=:inputs) == 0 - @test nsites(qtn; set=:outputs) == n - @test issetequal(sites(qtn), map(Site, 1:n)) - @test boundary(qtn) == Open() - @test isapprox(norm(qtn), 1.0) - @test maximum(last, size(qtn)) <= χ - end - - @testset "Operator" begin - n = 8 - χ = 10 - - qtn = rand(Chain, Open, Operator; n, p=2, χ) - @test socket(qtn) == Operator() - @test nsites(qtn; set=:inputs) == n - @test nsites(qtn; set=:outputs) == n - @test issetequal(sites(qtn), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) - @test boundary(qtn) == Open() - @test isapprox(norm(qtn), 1.0) - @test maximum(last, size(qtn)) <= χ - end - end - - @testset "Canonization" begin - using Tenet - - @testset "contract" begin - qtn = rand(Chain, Open, State; n=5, p=2, χ=20) - let canonized = canonize(qtn) - @test_throws ArgumentError contract!(canonized; between=(Site(1), Site(2)), direction=:dummy) - end - - canonized = canonize(qtn) - - for i in 1:4 - contract_some = contract(canonized; between=(Site(i), Site(i + 1))) - Bᵢ = tensors(contract_some; at=Site(i)) - - @test isapprox(contract(contract_some), contract(qtn)) - @test_throws ArgumentError tensors(contract_some; between=(Site(i), Site(i + 1))) - - @test isrightcanonical(contract_some, Site(i)) - @test isleftcanonical( - contract(canonized; between=(Site(i), Site(i + 1)), direction=:right), Site(i + 1) - ) - - Γᵢ = tensors(canonized; at=Site(i)) - Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) - @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) - end - end - - @testset "canonize_site" begin - qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4)]) - - @test_throws ArgumentError canonize_site!(qtn, Site(1); direction=:left) - @test_throws ArgumentError canonize_site!(qtn, Site(3); direction=:right) - - for method in [:qr, :svd] - canonized = canonize_site(qtn, site"1"; direction=:right, method=method) - @test isleftcanonical(canonized, site"1") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"2"; direction=:right, method=method) - @test isleftcanonical(canonized, site"2") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"2"; direction=:left, method=method) - @test isrightcanonical(canonized, site"2") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - - canonized = canonize_site(qtn, site"3"; direction=:left, method=method) - @test isrightcanonical(canonized, site"3") - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - end - - # Ensure that svd creates a new tensor - @test length(tensors(canonize_site(qtn, Site(2); direction=:left, method=:svd))) == 4 - end - - @testset "canonize" begin - using Tenet: isleftcanonical, isrightcanonical - - qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = canonize(qtn) - - @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - @test isapprox(norm(qtn), norm(canonized)) - - # Extract the singular values between each adjacent pair of sites in the canonized chain - Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] - @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 - - for i in 1:5 - canonized = canonize(qtn) - - if i == 1 - @test isleftcanonical(canonized, Site(i)) - elseif i == 5 # in the limits of the chain, we get the norm of the state - contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) - @test isleftcanonical(canonized, Site(i)) - else - contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) - @test isleftcanonical(canonized, Site(i)) - end - end - - for i in 1:5 - canonized = canonize(qtn) - - if i == 1 # in the limits of the chain, we get the norm of the state - contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - tensor = tensors(canonized; at=Site(i)) - replace!(canonized, tensor => tensor / norm(canonized)) - @test isrightcanonical(canonized, Site(i)) - elseif i == 5 - @test isrightcanonical(canonized, Site(i)) - else - contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) - @test isrightcanonical(canonized, Site(i)) - end - end - end - - @testset "mixed_canonize" begin - qtn = Chain(State(), Open(), [rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - canonized = mixed_canonize(qtn, Site(3)) - - @test length(tensors(canonized)) == length(tensors(qtn)) + 1 - - @test isleftcanonical(canonized, Site(1)) - @test isleftcanonical(canonized, Site(2)) - @test isrightcanonical(canonized, Site(3)) - @test isrightcanonical(canonized, Site(4)) - @test isrightcanonical(canonized, Site(5)) - - @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(qtn)) - end - end - - @test begin - qtn = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) - normalize!(qtn, Site(3)) - isapprox(norm(qtn), 1.0) - end - - @testset "adjoint" begin - qtn = rand(Chain, Open, State; n=5, p=2, χ=10) - adjoint_qtn = adjoint(qtn) - - for i in 1:nsites(qtn) - i < nsites(qtn) && - @test rightindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(rightindex(qtn, Site(i))) * "'") - i > 1 && @test leftindex(adjoint_qtn, Site(i; dual=true)) == Symbol(String(leftindex(qtn, Site(i))) * "'") - end - - @test isapprox(contract(qtn), contract(adjoint_qtn)) - end - - @testset "evolve!" begin - @testset "one site" begin - i = 2 - mat = reshape(LinearAlgebra.I(2), 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(i; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @testset "canonical form" begin - canonized = canonize(qtn) - - evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) - @test isapprox(contract(evolved), contract(canonized)) - @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - - @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) - @test length(tensors(evolved)) == 5 - @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - end - - @testset "two sites" begin - i, j = 2, 3 - mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @testset "canonical form" begin - canonized = canonize(qtn) - - evolved = evolve!(deepcopy(canonized), gate; threshold=1e-14) - @test isapprox(contract(evolved), contract(canonized)) - @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - - @testset "arbitrary chain" begin - evolved = evolve!(deepcopy(qtn), gate; threshold=1e-14, iscanonical=false) - @test length(tensors(evolved)) == 5 - @test issetequal(size.(tensors(evolved)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) - @test isapprox(contract(evolved), contract(qtn)) - end - end - end - - @testset "expect" begin - i, j = 2, 3 - mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) - gate = Dense(Tenet.Operator(), mat; sites=[Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) - - qtn = Chain(State(), Open(), [rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) - - @test isapprox(expect(qtn, [gate]), norm(qtn)^2) - end -end diff --git a/test/MPO_test.jl b/test/MPO_test.jl new file mode 100644 index 000000000..258c44679 --- /dev/null +++ b/test/MPO_test.jl @@ -0,0 +1,77 @@ +@testset "MPO" begin + H = MPO([rand(2, 2, 4), rand(2, 2, 4, 4), rand(2, 2, 4)]) + @test socket(H) == Operator() + @test nsites(H; set=:inputs) == 3 + @test nsites(H; set=:outputs) == 3 + @test issetequal(sites(H), [site"1", site"2", site"3", site"1'", site"2'", site"3'"]) + @test boundary(H) == Open() + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) == nothing + + # Default order (:o :i, :l, :r) + H = MPO([rand(2, 4, 1), rand(2, 4, 1, 3), rand(2, 4, 3)]) + + @test size(tensors(H; at=site"1")) == (2, 4, 1) + @test size(tensors(H; at=site"2")) == (2, 4, 1, 3) + @test size(tensors(H; at=site"3")) == (2, 4, 3) + + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) === nothing + @test inds(H; at=site"2", dir=:left) == inds(H; at=site"1", dir=:right) !== nothing + @test inds(H; at=site"3", dir=:left) == inds(H; at=site"2", dir=:right) !== nothing + + for i in 1:Tenet.ntensors(H) + @test size(H, inds(H; at=Site(i))) == 2 + @test size(H, inds(H; at=Site(i; dual=true))) == 4 + end + + # now we have (:r, :o, :l, :i) + H = MPO( + [ + permutedims(arrays(H)[1], (3, 1, 2)), + permutedims(arrays(H)[2], (4, 1, 3, 2)), + permutedims(arrays(H)[3], (1, 3, 2)), + ]; + order=[:r, :o, :l, :i], + ) + + @test size(tensors(H; at=site"1")) == (1, 2, 4) + @test size(tensors(H; at=site"2")) == (3, 2, 1, 4) + @test size(tensors(H; at=site"3")) == (2, 3, 4) + + @test inds(H; at=site"1", dir=:left) == inds(H; at=site"3", dir=:right) === nothing + @test inds(H; at=site"2", dir=:left) == inds(H; at=site"1", dir=:right) !== nothing + @test inds(H; at=site"3", dir=:left) == inds(H; at=site"2", dir=:right) !== nothing + + for i in 1:Tenet.ntensors(H) + @test size(H, inds(H; at=Site(i))) == 2 + @test size(H, inds(H; at=Site(i; dual=true))) == 4 + end + + @testset "Site" begin + H = MPO([rand(2, 2, 2), rand(2, 2, 2, 2), rand(2, 2, 2)]) + + @test isnothing(sites(H, site"1"; dir=:left)) + @test isnothing(sites(H, site"3"; dir=:right)) + + @test sites(H, site"2"; dir=:left) == site"1" + @test sites(H, site"3"; dir=:left) == site"2" + + @test sites(H, site"2"; dir=:right) == site"3" + @test sites(H, site"1"; dir=:right) == site"2" + end + + @testset "norm" begin + using LinearAlgebra: norm + + n = 8 + χ = 10 + H = rand(MPO; n, maxdim=χ) + + @test socket(H) == Operator() + @test nsites(H; set=:inputs) == n + @test nsites(H; set=:outputs) == n + @test issetequal(sites(H), vcat(map(Site, 1:n), map(adjoint ∘ Site, 1:n))) + @test boundary(H) == Open() + @test isapprox(norm(H), 1.0) + @test maximum(last, size(H)) <= χ + end +end diff --git a/test/MPS_test.jl b/test/MPS_test.jl new file mode 100644 index 000000000..86db308c7 --- /dev/null +++ b/test/MPS_test.jl @@ -0,0 +1,305 @@ +using Tenet: canonize_site, canonize_site! +using LinearAlgebra + +@testset "MPS" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + @test socket(ψ) == State() + @test nsites(ψ; set=:inputs) == 0 + @test nsites(ψ; set=:outputs) == 3 + @test issetequal(sites(ψ), [site"1", site"2", site"3"]) + @test boundary(ψ) == Open() + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) == nothing + + arrays = [rand(2, 1), rand(2, 1, 3), rand(2, 3)] + ψ = MPS(arrays) # Default order (:o, :l, :r) + @test size(tensors(ψ; at=site"1")) == (2, 1) + @test size(tensors(ψ; at=site"2")) == (2, 1, 3) + @test size(tensors(ψ; at=site"3")) == (2, 3) + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) === nothing + @test inds(ψ; at=site"2", dir=:left) == inds(ψ; at=site"1", dir=:right) + @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) + + arrays = [permutedims(arrays[1], (2, 1)), permutedims(arrays[2], (3, 1, 2)), permutedims(arrays[3], (1, 2))] # now we have (:r, :o, :l) + ψ = MPS(arrays; order=[:r, :o, :l]) + @test size(tensors(ψ; at=site"1")) == (1, 2) + @test size(tensors(ψ; at=site"2")) == (3, 2, 1) + @test size(tensors(ψ; at=site"3")) == (2, 3) + @test inds(ψ; at=site"1", dir=:left) == inds(ψ; at=site"3", dir=:right) === nothing + @test inds(ψ; at=site"2", dir=:left) == inds(ψ; at=site"1", dir=:right) !== nothing + @test inds(ψ; at=site"3", dir=:left) == inds(ψ; at=site"2", dir=:right) !== nothing + @test all(i -> size(ψ, inds(ψ; at=Site(i))) == 2, 1:nsites(ψ)) + + @testset "identity constructor" begin + nsites_cases = [6, 7, 6, 7] + physdim_cases = [3, 2, 3, 2] + maxdim_cases = [nothing, nothing, 9, 4] # nothing means default + expected_tensorsizes_cases = [ + [(3, 3), (3, 3, 9), (3, 9, 27), (3, 27, 9), (3, 9, 3), (3, 3)], + [(2, 2), (2, 2, 4), (2, 4, 8), (2, 8, 8), (2, 8, 4), (2, 4, 2), (2, 2)], + [(3, 3), (3, 3, 9), (3, 9, 9), (3, 9, 9), (3, 9, 3), (3, 3)], + [(2, 2), (2, 2, 4), (2, 4, 4), (2, 4, 4), (2, 4, 4), (2, 4, 2), (2, 2)], + ] + + for (nsites, physdim, expected_tensorsizes, maxdim) in + zip(nsites_cases, physdim_cases, expected_tensorsizes_cases, maxdim_cases) + ψ = if isnothing(maxdim) + MPS(identity, nsites; physdim=physdim) + else + MPS(identity, nsites; physdim=physdim, maxdim=maxdim) + end + + # Test the tensor dimensions + obtained_tensorsizes = size.(tensors(ψ)) + @test obtained_tensorsizes == expected_tensorsizes + + # Test whether all tensors are the identity + alltns = tensors(ψ) + + # - Test extreme tensors (2D) equal identity + diagonal_2D = [fill(i, 2) for i in 1:physdim] + @test all(delta -> alltns[1][delta...] == 1, diagonal_2D) + @test sum(alltns[1]) == physdim + @test all(delta -> alltns[end][delta...] == 1, diagonal_2D) + @test sum(alltns[end]) == physdim + + # - Test bulk tensors (3D) equal identity + diagonal_3D = [fill(i, 3) for i in 1:physdim] + @test all(tns -> all(delta -> tns[delta...] == 1, diagonal_3D), alltns[2:(end - 1)]) + @test all(tns -> sum(tns) == physdim, alltns[2:(end - 1)]) + + # Test whether the contraction gives the identity + contracted_ψ = contract(ψ) + diagonal_nsitesD = [fill(i, nsites) for i in 1:physdim] + @test all(delta -> contracted_ψ[delta...] == 1, diagonal_nsitesD) + @test sum(contracted_ψ) == physdim + end + end + + @testset "Site" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isnothing(sites(ψ, site"1"; dir=:left)) + @test isnothing(sites(ψ, site"3"; dir=:right)) + + @test sites(ψ, site"2"; dir=:left) == site"1" + @test sites(ψ, site"3"; dir=:left) == site"2" + + @test sites(ψ, site"2"; dir=:right) == site"3" + @test sites(ψ, site"1"; dir=:right) == site"2" + end + + @testset "adjoint" begin + ψ = rand(MPS; n=3, maxdim=2, eltype=ComplexF64) + @test socket(ψ') == State(; dual=true) + @test isapprox(contract(ψ), conj(contract(ψ'))) + end + + @testset "truncate!" begin + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2)]) + canonize_site!(ψ, Site(2); direction=:right, method=:svd) + + truncated = truncate(ψ, [site"2", site"3"]; maxdim=1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + + singular_values = tensors(ψ; between=(site"2", site"3")) + truncated = truncate(ψ, [site"2", site"3"]; threshold=singular_values[2] + 0.1) + @test size(truncated, inds(truncated; bond=[site"2", site"3"])) == 1 + end + + @testset "norm" begin + using LinearAlgebra: norm + + n = 8 + χ = 10 + ψ = rand(MPS; n, maxdim=χ) + + @test socket(ψ) == State() + @test nsites(ψ; set=:inputs) == 0 + @test nsites(ψ; set=:outputs) == n + @test issetequal(sites(ψ), map(Site, 1:n)) + @test boundary(ψ) == Open() + @test isapprox(norm(ψ), 1.0) + @test maximum(last, size(ψ)) <= χ + end + + @testset "normalize!" begin + using LinearAlgebra: normalize! + + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + normalize!(ψ, Site(3)) + @test isapprox(norm(ψ), 1.0) + end + + @testset "canonize_site!" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4)]) + + @test_throws ArgumentError canonize_site!(ψ, Site(1); direction=:left) + @test_throws ArgumentError canonize_site!(ψ, Site(3); direction=:right) + + for method in [:qr, :svd] + canonized = canonize_site(ψ, site"1"; direction=:right, method=method) + @test isleftcanonical(canonized, site"1") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"2"; direction=:right, method=method) + @test isleftcanonical(canonized, site"2") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"2"; direction=:left, method=method) + @test isrightcanonical(canonized, site"2") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + + canonized = canonize_site(ψ, site"3"; direction=:left, method=method) + @test isrightcanonical(canonized, site"3") + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + end + + # Ensure that svd creates a new tensor + @test length(tensors(canonize_site(ψ, Site(2); direction=:left, method=:svd))) == 4 + end + + @testset "canonize!" begin + using Tenet: isleftcanonical, isrightcanonical + + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = canonize(ψ) + + @test form(canonized) isa Canonical + + @test length(tensors(canonized)) == 9 # 5 tensors + 4 singular values vectors + @test isapprox(contract(transform(TensorNetwork(canonized), Tenet.HyperFlatten())), contract(ψ)) + @test isapprox(norm(ψ), norm(canonized)) + + # Extract the singular values between each adjacent pair of sites in the canonized chain + Λ = [tensors(canonized; between=(Site(i), Site(i + 1))) for i in 1:4] + @test map(λ -> sum(abs2, λ), Λ) ≈ ones(length(Λ)) * norm(canonized)^2 + + for i in 1:5 + canonized = canonize(ψ) + + if i == 1 + @test isleftcanonical(canonized, Site(i)) + elseif i == 5 # in the limits of the chain, we get the norm of the state + normalize!(tensors(canonized; bond=(Site(i - 1), Site(i)))) + contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) + @test isleftcanonical(canonized, Site(i)) + else + contract!(canonized; between=(Site(i - 1), Site(i)), direction=:right) + @test isleftcanonical(canonized, Site(i)) + end + end + + for i in 1:5 + canonized = canonize(ψ) + + if i == 1 # in the limits of the chain, we get the norm of the state + normalize!(tensors(canonized; bond=(Site(i), Site(i + 1)))) + contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) + @test isrightcanonical(canonized, Site(i)) + elseif i == 5 + @test isrightcanonical(canonized, Site(i)) + else + contract!(canonized; between=(Site(i), Site(i + 1)), direction=:left) + @test isrightcanonical(canonized, Site(i)) + end + end + end + + @testset "mixed_canonize!" begin + ψ = MPS([rand(4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4, 4), rand(4, 4)]) + canonized = mixed_canonize(ψ, site"3") + + @test form(canonized) isa MixedCanonical + @test form(canonized).orthog_center == site"3" + + @test isisometry(canonized, site"1"; dir=:right) + @test isisometry(canonized, site"2"; dir=:right) + @test isisometry(canonized, site"4"; dir=:left) + @test isisometry(canonized, site"5"; dir=:left) + + @test contract(canonized) ≈ contract(ψ) + end + + @testset "expect" begin + i, j = 2, 3 + mat = reshape(kron(LinearAlgebra.I(2), LinearAlgebra.I(2)), 2, 2, 2, 2) + gate = Quantum(mat, [Site(i), Site(j), Site(i; dual=true), Site(j; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @test isapprox(expect(ψ, [gate]), norm(ψ)^2) + end + + @testset "evolve!" begin + @testset "one site" begin + i = 2 + mat = reshape(LinearAlgebra.I(2), 2, 2) + gate = Quantum(mat, [Site(i), Site(i; dual=true)]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "NonCanonical" begin + ϕ = deepcopy(ψ) + evolve!(ϕ, gate; threshold=1e-14) + @test length(tensors(ϕ)) == 5 + @test issetequal(size.(tensors(ϕ)), [(2, 2), (2, 2, 2), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(ϕ), contract(ψ)) + end + + @testset "Canonical" begin + ϕ = deepcopy(ψ) + canonize!(ϕ) + evolve!(ϕ, gate; threshold=1e-14) + @test issetequal(size.(tensors(ϕ)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + @test isapprox(contract(ϕ), contract(ψ)) + end + end + + @testset "two sites" begin + mat = reshape(LinearAlgebra.I(4), 2, 2, 2, 2) + gate = Quantum(mat, [site"2", site"3", site"2'", site"3'"]) + ψ = MPS([rand(2, 2), rand(2, 2, 2), rand(2, 2, 2), rand(2, 2)]) + + @testset "NonCanonical" begin + ϕ = deepcopy(ψ) + evolve!(ϕ, gate; threshold=1e-14) + @test length(tensors(ϕ)) == 5 + @test issetequal(size.(tensors(ϕ)), [(2, 2), (2, 2, 2), (2,), (2, 2, 2), (2, 2, 2), (2, 2)]) + @test isapprox(contract(ϕ), contract(ψ)) + end + + @testset "Canonical" begin + ψ = deepcopy(ψ) + canonize!(ψ) + evolved = evolve!(deepcopy(ψ), gate; threshold=1e-14) + @test isapprox(contract(evolved), contract(ψ)) + @test issetequal(size.(tensors(evolved)), [(2, 2), (2,), (2, 2, 2), (2,), (2, 2, 2), (2,), (2, 2)]) + end + end + end + + # TODO rename when method is renamed + @testset "contract between" begin + ψ = rand(MPS; n=5, maxdim=20) + let canonized = canonize(ψ) + @test_throws ArgumentError contract!(canonized; between=(site"1", site"2"), direction=:dummy) + end + + canonized = canonize(ψ) + + for i in 1:4 + contract_some = contract(canonized; between=(Site(i), Site(i + 1))) + Bᵢ = tensors(contract_some; at=Site(i)) + + @test isapprox(contract(contract_some), contract(ψ)) + @test_throws Tenet.MissingSchmidtCoefficientsException tensors( + contract_some; between=(Site(i), Site(i + 1)) + ) + + @test isrightcanonical(contract_some, Site(i)) + @test isleftcanonical(contract(canonized; between=(Site(i), Site(i + 1)), direction=:right), Site(i + 1)) + + Γᵢ = tensors(canonized; at=Site(i)) + Λᵢ₊₁ = tensors(canonized; between=(Site(i), Site(i + 1))) + @test Bᵢ ≈ contract(Γᵢ, Λᵢ₊₁; dims=()) + end + end +end diff --git a/test/integration/ChainRules_test.jl b/test/integration/ChainRules_test.jl index ae5472abb..0ef7d8ac4 100644 --- a/test/integration/ChainRules_test.jl +++ b/test/integration/ChainRules_test.jl @@ -205,4 +205,20 @@ test_frule(Product, Ansatz(tn)) test_rrule(Product, Ansatz(tn)) end + + @testset "MPS" begin + @testset "Open" begin + tn = MPS([ones(2, 2), ones(2, 2, 2), ones(2, 2)]) + # test_frule(MPS, Ansatz(tn), form(tn)) + test_rrule(MPS, Ansatz(tn), form(tn)) + end + end + + @testset "MPO" begin + @testset "Open" begin + tn = MPO([ones(2, 2, 2), ones(2, 2, 2, 2), ones(2, 2, 2)]) + # test_frule(MPO, Ansatz(tn), form(tn)) + test_rrule(MPO, Ansatz(tn), form(tn)) + end + end end diff --git a/test/runtests.jl b/test/runtests.jl index 6a19c156c..6c42ea655 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -15,7 +15,8 @@ include("Utils.jl") include("Lattice_test.jl") include("Ansatz_test.jl") include("Product_test.jl") - include("Chain_test.jl") + include("MPS_test.jl") + include("MPO_test.jl") end @testset "Integration tests" verbose = true begin