Skip to content

Commit

Permalink
Reimplement MPS and MPO (#232)
Browse files Browse the repository at this point in the history
* Prototype `MPS`, `MPO`

* Implement `rand`, `adjoint`, `defaultorder`, `boundary`, `form` for `MPS`, `MPO`

* Implement conversion from `Product` to `MPS`, `MPO`

* Refactor `MPS`, `MPO` on top of new `Ansatz` type

* Move `Chain` code to `AbstractAnsatz` and `MPS`

* Fix `sites` method for `MPS`

* Fix `inds` method for `MPS`

* Refactor `adapt_structure` method to support additional types

* Refactor `Reactant.make_tracer`, `Reactant.create_result` methods on top of recent changes

* Refactor `ChainRules` methods on top of new types

* Refactor `rand` for `MPS`, `MPO`

* Refactor `Chain` tests on top of `MPS`, `MPO`

* Try using more `@site_str` instead of `Site` in MPS tests

* Implement some `sites`, `inds` methods for `MPO`

* Try using more `@site_str` in MPO tests

* Fix typo in `mixed_canonize!`

* Fix `truncate` tests on `MPS`

* Refactor some tests of `MPS` to simplify

* Fix typo in `normalize!` on `MPS` method

* Fix typo

* Deprecate `isleftcanonical`, `isrightcanonical` in favor of `isisometry`

* Fix `isleftcanonical`, `isrightcanonical` tests on boundary sites

* Fix `evolve!` calls in tests

* Refactor MPO tests

* Stop orthogonalization to index on `mixed_canonize!`

* Aesthetic name fix

* Stop using `IdDict` on Reactant extension

* Fix `create_result` on `MPS`, `MPO`

* Refactor lattice generation in constructors of `Dense`, `Product`, `MPS`, `MPO`, `PEPS`

* Implement an MPS method initializing the tensors to identity (copy-tensors) (#218)

* Format code

* Implement MPS identity initialization

* Add tests for all dispatches of MPS identity init

* Format julia code

* Rename function header & add docstring

* Fix test set for identity MPS

* Format code

* Rewrite MPS identity init function to nsites instead of arrays' dimensions

* Format julia code

* Update docstring of identity

* Clean code in test (suggested by Jofre)

Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>

* Format julia code

* Refactor virtualdims in identity (suggested by Sergio)

* Update src/Ansatz/MPS.jl

* Restrict to default order in identity MPS

* Update src/Ansatz/MPS.jl

* Remove order parameter in identity

---------

Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>
Co-authored-by: Sergio Sánchez Ramírez <15837247+mofeing@users.noreply.github.com>

* move files

* fix constructors

* Document types

* Remove unimplemented `evolve!` method

* fix mutability of `MPO`

* Move `MPO` code to "MPS.jl" and refactor common code

* document `MPS`, `MPO` constructors

* document `rand` on `MPS`, `MPO`

* move some docstrings to `AbstractAnsatz`

* Refactor `normalize!`

* Fix `defaultorder`

* fix `normalize!`

* apply `isisometry` docstring suggestion by @starsfordummies

* add shortcut for `normalize!` with mixed canonization

* fix MPS identity constructor test

* implement shortcut `Quantum` constructor for simple gates

* fix test

* refactor exported names

* fix `mixed_canonize!` tests

* fix `canonize!`, `mixed_canonize!`

* import missing symbols to tests

* fix field name of `MixedCanonical`

* fix namespace clash with `truncate`

* fix `truncate!`

* fix tests

* try fix `mixed_canonize!`, `normalize!`

* fix keyword args of `simple_update!` call

* comment

* fix `MPO` test

* more fixes

* fix test

* refactor legacy `simple_update!` on `Canonical` form

* fix symbol in test

* rename testset

* fix `reindex!`

* Remove legacy `@show`

* format code

* refactor `evolve!` tests

* refactor `evolve!` tests again

* fix indexing in `simple_update!`

* fix wrong call to `canonize!`

* try fix forward-mode diff of `MPS`, `MPO` constructors

---------

Co-authored-by: Todorbsc <145352308+Todorbsc@users.noreply.github.com>
Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 11, 2024
1 parent 8a576c5 commit 4d57ecb
Show file tree
Hide file tree
Showing 17 changed files with 1,030 additions and 586 deletions.
2 changes: 2 additions & 0 deletions ext/TenetAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tens
Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites)
Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x))
Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Ansatz(x)))
Adapt.adapt_structure(to, x::MPS) = MPS(adapt(to, Ansatz(x)), form(x))
Adapt.adapt_structure(to, x::MPO) = MPO(adapt(to, Ansatz(x)), form(x))

end
2 changes: 2 additions & 0 deletions ext/TenetChainRulesCoreExt/frules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@ end

# `AbstractAnsatz`-subtype constructors
ChainRulesCore.frule((_, ẋ), ::Type{Product}, x::Ansatz) = Product(x), Tangent{Product}(; tn=ẋ)
ChainRulesCore.frule((_, ẋ), ::Type{MPS}, x::Ansatz, form) = MPS(x, form), Tangent{MPS}(; tn=ẋ, form=NoTangent())
ChainRulesCore.frule((_, ẋ), ::Type{MPO}, x::Ansatz, form) = MPO(x, form), Tangent{MPO}(; tn=ẋ, form=NoTangent())

# `Base.conj` methods
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ)
Expand Down
8 changes: 8 additions & 0 deletions ext/TenetChainRulesCoreExt/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@ Product_pullback(ȳ) = (NoTangent(), ȳ.tn)
Product_pullback(ȳ::AbstractThunk) = Product_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Product}, x::Ansatz) = Product(x), Product_pullback

MPS_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPS_pullback(ȳ::AbstractThunk) = MPS_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPS}, x::Ansatz, form) = MPS(x, form), MPS_pullback

MPO_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
MPO_pullback(ȳ::AbstractThunk) = MPO_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{MPO}, x::Ansatz, form) = MPO(x, form), MPO_pullback

# `Base.conj` methods
conj_pullback::Tensor) = (NoTangent(), conj(Δ))
conj_pullback::Tangent{Tensor}) = (NoTangent(), conj(Δ))
Expand Down
5 changes: 5 additions & 0 deletions ext/TenetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,11 @@ function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz)
end

ChainRulesTestUtils.rand_tangent(::AbstractRNG, lattice::Tenet.Lattice) = NoTangent()
ChainRulesTestUtils.test_approx(::AbstractZero, form::Tenet.Lattice, msg=""; kwargs...) = true
ChainRulesTestUtils.test_approx(actual::Tenet.Lattice, expected::Tenet.Lattice, msg; kwargs...) = actual == expected

ChainRulesTestUtils.rand_tangent(::AbstractRNG, form::Tenet.Form) = NoTangent()
ChainRulesTestUtils.test_approx(::AbstractZero, form::Tenet.Form, msg=""; kwargs...) = true
ChainRulesTestUtils.test_approx(actual::Tenet.Form, expected::Tenet.Form, msg; kwargs...) = actual == expected

end
10 changes: 4 additions & 6 deletions ext/TenetQuacExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,11 @@ module TenetQuacExt
using Tenet
using Quac: Gate, Circuit, lanes, arraytype, Swap

# function Tenet.Dense(gate::Gate)
# return Tenet.Dense(
# Operator(), arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...]
# )
# end
function Tenet.Quantum(gate::Gate)
return Tenet.Quantum(arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...])
end

# Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Tenet.Dense(gate); kwargs...)
Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Quantum(gate); kwargs...)

function Tenet.Quantum(circuit::Circuit)
n = lanes(circuit)
Expand Down
14 changes: 14 additions & 0 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,13 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac
return Tenet.Product(tracetn)
end

for A in (MPS, MPO)
@eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracetn, form(prev))
end
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores)
return :($Tensor($data, $(inds(tocopy))))
Expand Down Expand Up @@ -70,6 +77,13 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu
return :($(Tenet.Product)($tn))
end

for A in (MPS, MPO)
@eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end
end

# TODO try rely on generic fallback for ansatzes
# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
Expand Down
202 changes: 67 additions & 135 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ struct NonCanonical <: Form end
[`Form`](@ref) trait representing a [`AbstractAnsatz`](@ref) Tensor Network in mixed-canonical form.
"""
struct MixedCanonical <: Form
orthogonality_center::Union{Site,Vector{Site}}
orthog_center::Union{Site,Vector{Site}}
end

"""
Expand Down Expand Up @@ -192,15 +192,52 @@ Contract the virtual bond between two [`Site`](@ref)s in a [`AbstractAnsatz`](@r
"""
@kwmethod contract!(tn::AbstractAnsatz; bond) = contract!(tn, inds(tn; bond))

"""
canonize!(tn::AbstractAnsatz)
Transform an [`AbstractAnsatz`](@ref) Tensor Network into the canonical form (aka Vidal gauge); i.e. the singular values matrix Λᵢ between each tensor Γᵢ₋₁ and Γᵢ.
"""
function canonize! end

"""
canonize(tn::AbstractAnsatz)
Like [`canonize!`](@ref), but returns a new Tensor Network instead of modifying the original one.
"""
canonize(tn::AbstractAnsatz, args...; kwargs...) = canonize!(deepcopy(tn), args...; kwargs...)

"""
mixed_canonize!(tn::AbstractAnsatz, orthog_center)
Transform an [`AbstractAnsatz`](@ref) Tensor Network into the mixed-canonical form, that is,
for `i < orthog_center` the tensors are left-canonical and for `i >= orthog_center` the tensors are right-canonical,
and in the `orthog_center` there is a tensor with the Schmidt coefficients in it.
"""
function mixed_canonize! end

"""
mixed_canonize(tn::AbstractAnsatz, orthog_center)
Like [`mixed_canonize!`](@ref), but returns a new Tensor Network instead of modifying the original one.
"""
mixed_canonize(tn::AbstractAnsatz, args...; kwargs...) = mixed_canonize!(deepcopy(tn), args...; kwargs...)

canonize_site(tn::AbstractAnsatz, args...; kwargs...) = canonize_site!(deepcopy(tn), args...; kwargs...)

"""
isisometry(tn::AbstractAnsatz, site; dir, kwargs...)
Check if the tensor at a given [`Site`](@ref) in a [`AbstractAnsatz`](@ref) Tensor Network is an isometry.
The `dir` keyword argument specifies the direction of the isometry to check.
"""
function isisometry end

"""
truncate(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing)
Like [`truncate!`](@ref), but returns a new tensor network instead of modifying the original one.
Like [`truncate!`](@ref), but returns a new Tensor Network instead of modifying the original one.
"""
truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)
Base.truncate(tn::AbstractAnsatz, args...; kwargs...) = truncate!(deepcopy(tn), args...; kwargs...)

"""
truncate!(tn::AbstractAnsatz, bond; threshold = nothing, maxdim = nothing)
Expand Down Expand Up @@ -236,15 +273,21 @@ Truncate the dimension of the virtual `bond` of a [`NonCanonical`](@ref) Tensor
- `compute_local_svd`: Whether to compute the local SVD of the bond. If `true`, it will contract the bond and perform a SVD to get the local singular values. Defaults to `true`.
"""
function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim, compute_local_svd=true)
virtualind = inds(tn; bond)

if compute_local_svd
tₗ = tensors(tn; at=min(bond...))
tᵣ = tensors(tn; at=max(bond...))
contract!(tn; bond)
svd!(tn; virtualind=inds(tn; bond))

left_inds = filter(!=(virtualind), inds(tₗ))
right_inds = filter(!=(virtualind), inds(tᵣ))
svd!(tn; left_inds, right_inds, virtualind=virtualind)
end

spectrum = parent(tensors(tn; bond))
vind = inds(tn; bond)

maxdim = isnothing(maxdim) ? size(tn, vind) : maxdim
maxdim = isnothing(maxdim) ? size(tn, virtualind) : maxdim

extent = if isnothing(threshold)
1:maxdim
Expand All @@ -254,7 +297,7 @@ function truncate!(::NonCanonical, tn::AbstractAnsatz, bond; threshold, maxdim,
end - 1, maxdim)
end

slice!(tn, vind, extent)
slice!(tn, virtualind, extent)

return tn
end
Expand All @@ -268,7 +311,7 @@ end
function truncate!(::Canonical, tn::AbstractAnsatz, bond; threshold, maxdim)
truncate!(NonCanonical(), tn, bond; threshold, maxdim, compute_local_svd=false)
# requires a sweep to recanonize the TN
return canonize!(tn, bond)
return canonize!(tn)
end

overlap(a::AbstractAnsatz, b::AbstractAnsatz) = contract(merge(a, copy(b)'))
Expand Down Expand Up @@ -333,7 +376,7 @@ function simple_update!(ψ::AbstractAnsatz, gate; threshold=nothing, maxdim=noth

@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

return simple_update!(form(ψ), ψ, gate; kwargs...)
return simple_update!(form(ψ), ψ, gate; threshold, maxdim, kwargs...)
end

# TODO a lot of problems with merging... maybe we shouldn't merge manually
Expand Down Expand Up @@ -376,9 +419,19 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth
rinds = filter(!=(vind), inds(tensors(ψ; at=siter)))
contract!(ψ; bond)

# TODO replace for `merge!` when #243 is fixed
# reindex contracting indices to temporary names to avoid issues
oinds = Dict(site => inds(ψ; at=site) for site in sites(gate; set=:outputs))
tmpinds = Dict(site => gensym(:tmp) for site in sites(gate; set=:inputs))
replace!(gate, [inds(gate; at=site) => i for (site, i) in tmpinds])
replace!(ψ, [inds(ψ; at=site') => i for (site, i) in tmpinds])

# NOTE `replace!` is getting confused when a index is already there even if it would be overriden
# TODO fix this to be handled in one call -> replace when #244 is fixed
replace!(gate, [inds(gate; at=site) => gensym() for (site, i) in oinds])
replace!(gate, [inds(gate; at=site) => i for (site, i) in oinds])

# contract physical inds with gate
@reindex! outputs(ψ) => outputs(gate) reset = false
@reindex! inputs(gate) => outputs(ψ) reset = false
merge!(ψ, gate; reset=false)
contract!(ψ, inds(gate; set=:inputs))

Expand All @@ -395,129 +448,8 @@ function simple_update!(::NonCanonical, ψ::AbstractAnsatz, gate; threshold=noth
end

# TODO remove `renormalize` argument?
# TODO refactor code
# TODO optimize correctly -> avoid recanonization + use lateral Λs
function simple_update!(::Canonical, ψ::AbstractAnsatz, gate; threshold, maxdim, renormalize=false)
@assert nlanes(gate) == 2 "Only 2-site gates are supported currently"
@assert has_edge(ψ, lanes(gate)...) "Gate must act on neighboring sites"

# shallow copy to avoid problems if errors in mid execution
gate = copy(gate)

bond = sitel, siter = minmax(sites(gate; set=:outputs)...)
left_inds::Vector{Symbol} = !isnothing(leftindex(ψ, sitel)) ? [leftindex(ψ, sitel)] : Symbol[]
right_inds::Vector{Symbol} = !isnothing(rightindex(ψ, siter)) ? [rightindex(ψ, siter)] : Symbol[]

virtualind::Symbol = inds(ψ; bond=bond)

contract_2sitewf!(ψ, bond)

# reindex contracting index
contracting_inds = [gensym(:tmp) for _ in sites(gate; set=:inputs)]
replace!(
ψ,
map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index)
inds(ψ; at=site') => contracting_index
end,
)
replace!(
gate,
map(zip(sites(gate; set=:inputs), contracting_inds)) do (site, contracting_index)
inds(gate; at=site) => contracting_index
end,
)

# replace output indices of the gate for gensym indices
output_inds = [gensym(:out) for _ in sites(gate; set=:outputs)]
replace!(
gate,
map(zip(sites(gate; set=:outputs), output_inds)) do (site, out)
inds(gate; at=site) => out
end,
)

# reindex output of gate to match TN sitemap
for site in sites(gate; set=:outputs)
if inds(ψ; at=site) != inds(gate; at=site)
replace!(gate, inds(gate; at=site) => inds(ψ; at=site))
end
end

# contract physical inds
merge!(ψ, gate)
contract!(ψ, contracting_inds)

# decompose using SVD
push!(left_inds, inds(ψ; at=sitel))
push!(right_inds, inds(ψ; at=siter))

unpack_2sitewf!(ψ, bond, left_inds, right_inds, virtualind)

# truncate virtual index
if any(!isnothing, [threshold, maxdim])
truncate!(ψ, bond; threshold, maxdim)
renormalize && normalize!(tensors(ψ; between=bond))
end

return ψ
end

# TODO refactor code
"""
contract_2sitewf!(ψ::AbstractAnsatz, bond)
For a given [`AbstractAnsatz`](@ref) in the canonical form, creates the two-site wave function θ with Λᵢ₋₁Γᵢ₋₁ΛᵢΓᵢΛᵢ₊₁,
where i is the `bond`, and replaces the Γᵢ₋₁ΛᵢΓᵢ tensors with θ.
"""
function contract_2sitewf!::AbstractAnsatz, bond)
@assert form(ψ) == Canonical() "The tensor network must be in canonical form"

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(siter) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(sitel) == nsites(ψ) - 1 ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

!isnothing(Λᵢ₋₁) && contract!(ψ; between=(Site(id(sitel) - 1), sitel), direction=:right, delete_Λ=false)
!isnothing(Λᵢ₊₁) && contract!(ψ; between=(siter, Site(id(siter) + 1)), direction=:left, delete_Λ=false)

contract!(ψ, inds(ψ; bond=bond))

return ψ
end

# TODO refactor code
"""
unpack_2sitewf!(ψ::AbstractAnsatz, bond)
For a given [`AbstractAnsatz`](@ref) that contains a two-site wave function θ in a bond, it decomposes θ into the canonical
form: Γᵢ₋₁ΛᵢΓᵢ, where i is the `bond`.
"""
function unpack_2sitewf!::AbstractAnsatz, bond, left_inds, right_inds, virtualind)
@assert form(ψ) == Canonical() "The tensor network must be in canonical form"

sitel, siter = bond # TODO Check if bond is valid
(0 < id(sitel) < nsites(ψ) || 0 < id(site_r) < nsites(ψ)) ||
throw(ArgumentError("The sites in the bond must be between 1 and $(nsites(ψ))"))

Λᵢ₋₁ = id(sitel) == 1 ? nothing : tensors(ψ; between=(Site(id(sitel) - 1), sitel))
Λᵢ₊₁ = id(siter) == nsites(ψ) ? nothing : tensors(ψ; between=(siter, Site(id(siter) + 1)))

# do svd of the θ tensor
θ = tensors(ψ; at=sitel)
U, s, Vt = svd(θ; left_inds, right_inds, virtualind)

# contract with the inverse of Λᵢ and Λᵢ₊₂
Γᵢ₋₁ =
isnothing(Λᵢ₋₁) ? U : contract(U, Tensor(diag(pinv(Diagonal(parent(Λᵢ₋₁)); atol=1e-32)), inds(Λᵢ₋₁)); dims=())
Γᵢ =
isnothing(Λᵢ₊₁) ? Vt : contract(Tensor(diag(pinv(Diagonal(parent(Λᵢ₊₁)); atol=1e-32)), inds(Λᵢ₊₁)), Vt; dims=())

delete!(ψ, θ)

push!(ψ, Γᵢ₋₁)
push!(ψ, s)
push!(ψ, Γᵢ)

return ψ
simple_update!(NonCanonical(), ψ, gate; threshold, maxdim, renormalize)
return canonize!(ψ)
end
2 changes: 2 additions & 0 deletions src/Lattice.jl
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,10 @@ end
Base.copy(lattice::Lattice) = Lattice(copy(lattice.mapping), copy(lattice.graph))
Base.:(==)(a::Lattice, b::Lattice) = a.mapping == b.mapping && a.graph == b.graph

# TODO these where needed by ChainRulesTestUtils, do we still need them?
Base.zero(::Type{Lattice}) = Lattice(BijectiveIdDict{Site,Int}(), zero(Graphs.SimpleGraph{Int}))
Base.zero(::Lattice) = zero(Lattice)

Graphs.is_directed(::Type{Lattice}) = false

function Graphs.vertices(lattice::Lattice)
Expand Down
Loading

0 comments on commit 4d57ecb

Please sign in to comment.