Skip to content

Commit

Permalink
Refactor Ansatz to be a concrete type with lattice information (#223)
Browse files Browse the repository at this point in the history
* Introduce Canonical `Form` trait

* Refactor `Ansatz` into a concrete type

* Fix typos in `Ansatz`

* Use `Graphs.neighbors`

* Implement `copy`, `similar`, `zero` for `Ansatz`

* Relax `sites` condition on `Ansatz` construction for `lanes`

* Remove some exports

* Move `Chain` code to `AbstractAnsatz` and `MPS`

* Force Graphs to be a strong dependency

* Implement `Graphs.neighbors`, `isneighbor` methods

* Refactor `adapt_structure` method to support additional types

* Refactor `Reactant.make_tracer`, `Reactant.create_result` methods on top of recent changes

* Refactor `ChainRules` methods on top of new types

* Refactor `ProjectTo` for `Ansatz`

* Export `canonize_site`, `canonize_site!` methods

* Add `Graphs`, `MetaGraphsNext` as test dependencies

* Fix `tensors(; bond)`

* Fix `truncate!`

* Fix `truncate!` extension when using `threshold`

* Format code

* Comment `renormalize` kwarg of `evolve!`

* Fix `simple_update!` on single site gates

* Fix indexing problems in `simple_update_1site!`

* Some fixes on `simple_update!`

* Prototype tests for `Ansatz`

* Fix reference to lattice in `adapt_structure` for `Ansatz`

* Stop using `IdDict` on Reactant extension

* Set temporarily a more concrete type of `lattice` in graph to circunvent a Julia bug

* Delete `Chain`, `Dense`, `Grid`, `Product` types

* Revert "Set temporarily a more concrete type of `lattice` in graph to circunvent a Julia bug"

This reverts commit 7c67f58.

* Partially revert 0bef4a9

* Fix `expect` for single and multiple observables

* Replace `MetaGraph` for new `Lattice` graph type

* fix ambiguous definition of `expect` methods

* test `Ansatz`

* fix typo

* format code

* fix `@testset` with `let`

* fix `Ansatz` constructor

* fix calls to `neighbors` on `Ansatz` tests

* fix typo

* Implement `Base.copy` for `Lattice`

* fix tests

* Reenable `ChainRules` tests on `Ansatz`

* format code

* implement default method for `form`

* finish `Ansatz` tests

* refactor `expect` to avoid calling `evolve!`

* import `Graphs.contract` to avoid namespace conflicts

* delete "Ansatz" folder

* document stuff

* fix `include`

* fix tests

* fix `simple_update!` for 2-local operators

* reenable tests for `simple_update!`, `evolve!` on 2-local operators

* fix tests

* fix `ChainRulesTestUtils.rand_tangent` for `Lattice`

* fix typo

* Implement `ChainRulesTestUtils.test_approx` for `Lattice`

* clean comment

* fix typo again

* Stop installing `ITensorNetworks` on test

* Reinclude skipped test sets of `Product` and `Chain`

Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>

* Refactor `truncate!` to account for different methods based on canonical forms

---------

Co-authored-by: Sergio Sánchez Ramírez <sergio.sancnchez.ramirez+git@bsc.es>
Co-authored-by: Jofre Vallès Muns <61060572+jofrevalles@users.noreply.github.com>
  • Loading branch information
3 people authored Nov 4, 2024
1 parent a241967 commit a432fa3
Show file tree
Hide file tree
Showing 29 changed files with 936 additions and 1,264 deletions.
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,11 @@ version = "0.8.0-DEV"

[deps]
AbstractTrees = "1520ce14-60c1-5f80-bbc7-55ef81b5835c"
BijectiveDicts = "d089a002-103b-496c-a4e3-58f90cf955b0"
Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
DeltaArrays = "10b0fc19-5ccc-4427-889b-d75dd6306188"
EinExprs = "b1794770-133b-4de1-afb4-526377e9f4c5"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
KeywordDispatch = "5888135b-5456-5c80-a1b6-c91ef8180460"
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
OMEinsum = "ebe7aa44-baf0-506c-a96f-8464559b3922"
Expand All @@ -25,7 +27,6 @@ ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a"
Dagger = "d58978e5-989f-55fb-8d15-ea34adc7bf54"
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
GraphMakie = "1ecd5474-83a3-4783-bb4f-06765db800d2"
Graphs = "86223c79-3864-5bf0-83f7-82e725a168b6"
ITensorNetworks = "2919e153-833c-4bdc-8836-1ea460a35fc7"
ITensors = "9136182c-28ba-11e9-034c-db9fb085ebd5"
KrylovKit = "0b1a1467-8014-51b9-945f-bf0ae24f4b77"
Expand Down Expand Up @@ -54,6 +55,7 @@ TenetYaoBlocksExt = "YaoBlocks"
[compat]
AbstractTrees = "0.4"
Adapt = "4"
BijectiveDicts = "0.1"
ChainRules = "1.0"
ChainRulesCore = "1.0"
ChainRulesTestUtils = "1"
Expand Down
3 changes: 1 addition & 2 deletions ext/TenetAdaptExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@ Adapt.adapt_structure(to, x::Tensor) = Tensor(adapt(to, parent(x)), inds(x))
Adapt.adapt_structure(to, x::TensorNetwork) = TensorNetwork(adapt.(Ref(to), tensors(x)))

Adapt.adapt_structure(to, x::Quantum) = Quantum(adapt(to, TensorNetwork(x)), x.sites)
Adapt.adapt_structure(to, x::Product) = Product(adapt(to, Quantum(x)))
Adapt.adapt_structure(to, x::Chain) = Chain(adapt(to, Quantum(x)), boundary(x))
Adapt.adapt_structure(to, x::Ansatz) = Ansatz(adapt(to, Quantum(x)), Tenet.lattice(x))

end
30 changes: 16 additions & 14 deletions ext/TenetChainRulesCoreExt/frules.jl
Original file line number Diff line number Diff line change
@@ -1,16 +1,30 @@
using Tenet: AbstractTensorNetwork, AbstractQuantum

# `Tensor` constructor
ChainRulesCore.frule((_, Δ, _), T::Type{<:Tensor}, data, inds) = T(data, inds), T(Δ, inds)

# `TensorNetwork` constructor
ChainRulesCore.frule((_, Δ), ::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetworkTangent(Δ)

# `Quantum` constructor
function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
return Quantum(x, sites), Tangent{Quantum}(; tn=ẋ, sites=NoTangent())
end

# `Ansatz` constructor
function ChainRulesCore.frule((_, ẋ), ::Type{Ansatz}, x::Quantum, lattice)
return Ansatz(x, lattice), Tangent{Ansatz}(; tn=ẋ, lattice=NoTangent())
end

# `Base.conj` methods
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::Tensor) = conj(tn), conj(Δ)

ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::TensorNetwork) = conj(tn), conj(Δ)
ChainRulesCore.frule((_, Δ), ::typeof(Base.conj), tn::AbstractTensorNetwork) = conj(tn), conj(Δ)

# `Base.merge` methods
ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::TensorNetwork, b::TensorNetwork) = merge(a, b), merge(ȧ, ḃ)
function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(Base.merge), a::AbstractTensorNetwork, b::AbstractTensorNetwork)
return merge(a, b), merge(ȧ, ḃ)
end

# `contract` methods
function ChainRulesCore.frule((_, ẋ), ::typeof(contract), x::Tensor; kwargs...)
Expand All @@ -22,15 +36,3 @@ function ChainRulesCore.frule((_, ȧ, ḃ), ::typeof(contract), a::Tensor, b::T
= contract(ȧ, b; kwargs...) + contract(a, ḃ; kwargs...)
return c, ċ
end

function ChainRulesCore.frule((_, ẋ, _), ::Type{Quantum}, x::TensorNetwork, sites)
y = Quantum(x, sites)
= Tangent{Quantum}(; tn=ẋ)
return y, ẏ
end

ChainRulesCore.frule((_, ẋ), ::Type{T}, x::Quantum) where {T<:Ansatz} = T(x), Tangent{T}(; super=ẋ)

function ChainRulesCore.frule((_, ẋ, _), ::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
return T(x, boundary), Tangent{T}(; super=ẋ, boundary=NoTangent())
end
7 changes: 2 additions & 5 deletions ext/TenetChainRulesCoreExt/projectors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,5 @@ end
ChainRulesCore.ProjectTo(x::Quantum) = ProjectTo{Quantum}(; tn=ProjectTo(TensorNetwork(x)), sites=x.sites)
(projector::ProjectTo{Quantum})(Δ) = Quantum(projector.tn(Δ), projector.sites)

ChainRulesCore.ProjectTo(x::T) where {T<:Ansatz} = ProjectTo{T}(; super=ProjectTo(Quantum(x)))
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Ansatz} = T(projector.super.super), Δ.boundary)

# NOTE edge case: `Product` has no `boundary`. should it?
(projector::ProjectTo{T})(Δ::Union{T,Tangent{T}}) where {T<:Product} = T(projector.super.super))
ChainRulesCore.ProjectTo(x::Ansatz) = ProjectTo{Ansatz}(; tn=ProjectTo(Quantum(x)), lattice=x.lattice)
(projector::ProjectTo{Ansatz})(Δ) = Ansatz(projector.tn(Δ), Δ.lattice)
38 changes: 11 additions & 27 deletions ext/TenetChainRulesCoreExt/rrules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,17 @@ TensorNetwork_pullback(Δ::TensorNetworkTangent) = (NoTangent(), tensors(Δ))
TensorNetwork_pullback::AbstractThunk) = TensorNetwork_pullback(unthunk(Δ))
ChainRulesCore.rrule(::Type{TensorNetwork}, tensors) = TensorNetwork(tensors), TensorNetwork_pullback

# `Quantum` constructor
Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

# `Ansatz` constructor
Ansatz_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Ansatz}, x::Quantum, lattice) = Ansatz(x, lattice), Ansatz_pullback

# `Base.conj` methods
conj_pullback::Tensor) = (NoTangent(), conj(Δ))
conj_pullback::Tangent{Tensor}) = (NoTangent(), conj(Δ))
Expand Down Expand Up @@ -93,33 +104,6 @@ function ChainRulesCore.rrule(::typeof(contract), a::Tensor, b::Tensor; kwargs..
return c, contract_pullback
end

Quantum_pullback(ȳ) = (NoTangent(), ȳ.tn, NoTangent())
Quantum_pullback(ȳ::AbstractArray) = (NoTangent(), ȳ, NoTangent())
Quantum_pullback(ȳ::AbstractThunk) = Quantum_pullback(unthunk(ȳ))
ChainRulesCore.rrule(::Type{Quantum}, x::TensorNetwork, sites) = Quantum(x, sites), Quantum_pullback

Ansatz_pullback(ȳ) = (NoTangent(), ȳ.super)
Ansatz_pullback(ȳ::AbstractThunk) = Ansatz_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum) where {T<:Ansatz}
y = T(x)
return y, Ansatz_pullback
end

Ansatz_boundary_pullback(ȳ) = (NoTangent(), ȳ.super, NoTangent())
Ansatz_boundary_pullback(ȳ::AbstractThunk) = Ansatz_boundary_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(::Type{T}, x::Quantum, boundary) where {T<:Ansatz}
return T(x, boundary), Ansatz_boundary_pullback
end

Ansatz_from_arrays_pullback(ȳ) = (NoTangent(), NoTangent(), NoTangent(), parent.(tensors(ȳ.super.tn)))
Ansatz_from_arrays_pullback(ȳ::AbstractThunk) = Ansatz_from_arrays_pullback(unthunk(ȳ))
function ChainRulesCore.rrule(
::Type{T}, socket::Tenet.Socket, boundary::Tenet.Boundary, arrays; kwargs...
) where {T<:Ansatz}
y = T(socket, boundary, arrays; kwargs...)
return y, Ansatz_from_arrays_pullback
end

copy_pullback(ȳ) = (NoTangent(), ȳ)
copy_pullback(ȳ::AbstractThunk) = unthunk(ȳ)
function ChainRulesCore.rrule(::typeof(copy), x::Quantum)
Expand Down
22 changes: 14 additions & 8 deletions ext/TenetChainRulesTestUtilsExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6,27 +6,33 @@ using Tenet
using ChainRulesCore
using ChainRulesTestUtils
using Random
using Graphs

const TensorNetworkTangent = Base.get_extension(Tenet, :TenetChainRulesCoreExt).TensorNetworkTangent

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Vector{T}) where {T<:Tensor}
if isempty(x)
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Vector{T}) where {T<:Tensor}
if isempty(tn)
return Vector{T}()
else
@invoke rand_tangent(rng::AbstractRNG, x::AbstractArray)
@invoke rand_tangent(rng::AbstractRNG, tn::AbstractArray)
end
end

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::TensorNetwork)
return TensorNetworkTangent(Tensor[ProjectTo(tensor)(rand_tangent.(Ref(rng), tensor)) for tensor in tensors(x)])
end

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, x::Quantum)
return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(x)), sites=NoTangent())
function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Quantum)
return Tangent{Quantum}(; tn=rand_tangent(rng, TensorNetwork(tn)), sites=NoTangent())
end

# WARN type-piracy
# NOTE used in `Quantum` constructor
ChainRulesTestUtils.rand_tangent(::AbstractRNG, x::Dict{<:Site,Symbol}) = NoTangent()
ChainRulesTestUtils.rand_tangent(::AbstractRNG, tn::Dict{<:Site,Symbol}) = NoTangent()

function ChainRulesTestUtils.rand_tangent(rng::AbstractRNG, tn::Ansatz)
return Tangent{Ansatz}(; tn=rand_tangent(rng, Quantum(tn)), lattice=NoTangent())
end

ChainRulesTestUtils.rand_tangent(::AbstractRNG, lattice::Tenet.Lattice) = NoTangent()
ChainRulesTestUtils.test_approx(actual::Tenet.Lattice, expected::Tenet.Lattice, msg; kwargs...) = actual == expected

end
14 changes: 14 additions & 0 deletions ext/TenetFiniteDifferencesExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,4 +20,18 @@ function FiniteDifferences.to_vec(x::Dict{Vector{Symbol},Tensor})
return x_vec, Dict_from_vec
end

function FiniteDifferences.to_vec(x::Quantum)
x_vec, back = to_vec(TensorNetwork(x))
Quantum_from_vec(v) = Quantum(back(v), copy(x.sites))

return x_vec, Quantum_from_vec
end

function FiniteDifferences.to_vec(x::Ansatz)
x_vec, back = to_vec(Quantum(x))
Ansatz_from_vec(v) = Ansatz(back(v), copy(x.lattice))

return x_vec, Ansatz_from_vec
end

end
4 changes: 2 additions & 2 deletions ext/TenetGraphMakieExt.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
module TenetGraphMakieExt

using Tenet
using GraphMakie
using Graphs
using Makie
const Graphs = GraphMakie.Graphs
using Tenet
using Combinatorics: combinations
const NetworkLayout = GraphMakie.NetworkLayout

Expand Down
12 changes: 6 additions & 6 deletions ext/TenetQuacExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ module TenetQuacExt
using Tenet
using Quac: Gate, Circuit, lanes, arraytype, Swap

function Tenet.Dense(gate::Gate)
return Tenet.Dense(
Operator(), arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...]
)
end
# function Tenet.Dense(gate::Gate)
# return Tenet.Dense(
# Operator(), arraytype(gate)(gate); sites=Site[Site.(lanes(gate))..., Site.(lanes(gate); dual=true)...]
# )
# end

Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Tenet.Dense(gate); kwargs...)
# Tenet.evolve!(qtn::Ansatz, gate::Gate; kwargs...) = evolve!(qtn, Tenet.Dense(gate); kwargs...)

function Tenet.Quantum(circuit::Circuit)
n = lanes(circuit)
Expand Down
32 changes: 19 additions & 13 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -31,15 +31,9 @@ function Reactant.make_tracer(seen, prev::Quantum, path::Tuple, mode::Reactant.T
return Quantum(tracetn, copy(prev.sites))
end

function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Product(tracequantum)
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.make_tracer(seen, prev::Tenet.Chain, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracequantum = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :super), mode; kwargs...)
return Tenet.Chain(tracequantum, boundary(prev))
function Reactant.make_tracer(seen, prev::Ansatz, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Quantum(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return Ansatz(tracetn, copy(Tenet.lattice(prev)))
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
Expand All @@ -59,12 +53,24 @@ function Reactant.create_result(tocopy::Quantum, @nospecialize(path), result_sto
return :($Quantum($tn, $(copy(tocopy.sites))))
end

# TODO try rely on generic fallback for ansatzes -> do it when refactoring to MPS/MPO
function Reactant.create_result(tocopy::Tenet.Chain, @nospecialize(path), result_stores)
qtn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :super), result_stores)
return :($(Tenet.Chain)($qtn, $(boundary(tocopy))))
function Reactant.create_result(tocopy::Ansatz, @nospecialize(path), result_stores)
tn = Reactant.create_result(Quantum(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($Ansatz($tn, $(copy(Tenet.lattice(tocopy)))))
end

# TODO try rely on generic fallback for ansatzes
# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
# return :($(Tenet.Product)($tn))
# end

# for A in (MPS, MPO)
# @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
# return :($A($tn, form(tocopy)))
# end
# end

function Reactant.push_val!(ad_inputs, x::TensorNetwork, path)
@assert length(path) == 2
@assert path[2] === :data
Expand Down
Loading

0 comments on commit a432fa3

Please sign in to comment.