Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix Array-TracedRArray contraction and force dense representation of Yao gates #258

Merged
merged 8 commits into from
Nov 27, 2024
44 changes: 18 additions & 26 deletions ext/TenetReactantExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ function Reactant.make_tracer(
seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {RT<:Tensor}
tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...)
return Tensor(tracedata, inds(prev))
return Tensor(tracedata, copy(inds(prev)))
end

function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...)
Expand Down Expand Up @@ -42,16 +42,16 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac
return Tenet.Product(tracetn)
end

for A in (MPS, MPO)
@eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...)
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return $A(tracetn, form(prev))
end
function Reactant.make_tracer(
seen, prev::A, path::Tuple, mode::Reactant.TraceMode; kwargs...
) where {A<:Tenet.AbstractMPO}
tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...)
return A(tracetn, copy(form(prev)))
end

function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores)
data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores)
return :($Tensor($data, $(inds(tocopy))))
return :($Tensor($data, $(copy(inds(tocopy)))))
end

function Reactant.create_result(tocopy::TensorNetwork, @nospecialize(path), result_stores)
Expand All @@ -77,26 +77,11 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu
return :($(Tenet.Product)($tn))
end

for A in (MPS, MPO)
@eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end
function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:Tenet.AbstractMPO}
tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
return :($A($tn, $(Tenet.form(tocopy))))
end

# TODO try rely on generic fallback for ansatzes
# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
# return :($(Tenet.Product)($tn))
# end

# for A in (MPS, MPO)
# @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores)
# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores)
# return :($A($tn, form(tocopy)))
# end
# end

function Reactant.push_val!(ad_inputs, x::TensorNetwork, path)
@assert length(path) == 2
@assert path[2] === :data
Expand Down Expand Up @@ -216,7 +201,14 @@ end

Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...)
function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb}
return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...)
# TODO change to `Ops.constant` when Ops PR lands in Reactant
# apparently `promote_to` doesn't do the transpostion for converting from column-major (Julia) to row-major layout (MLIR)
# currently, we call permutedims manually
return contract(
a,
Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, permutedims(parent(b), collect(Nb:-1:1))), inds(b));
kwargs...,
)
end

end
7 changes: 5 additions & 2 deletions ext/TenetYaoBlocksExt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -25,13 +25,16 @@ function Tenet.Quantum(circuit::AbstractBlock)
end

# NOTE `YaoBlocks.mat` on m-site qubits still returns the operator on the full Hilbert space
m = length(occupied_locs(gate))
operator = if gate isa YaoBlocks.ControlBlock
m = length(occupied_locs(gate))
control((1:(m - 1))..., m => content(gate))(m)
else
content(gate)
end
array = reshape(mat(operator), fill(nlevel(operator), 2 * nqubits(operator))...)

# NOTE dim permutation fixes array layout of Yao
perm = collect(Iterators.flatten([m:-1:1, (2m):-1:(m + 1)]))
array = reshape(collect(mat(operator)), fill(nlevel(operator), 2 * nqubits(operator))...)

inds = (x -> collect(Iterators.flatten(zip(x...))))(
map(occupied_locs(gate)) do l
Expand Down
4 changes: 4 additions & 0 deletions src/Ansatz.jl
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,8 @@ Abstract type representing the canonical form trait of a [`AbstractAnsatz`](@ref
"""
abstract type Form end

Base.copy(x::Form) = x

"""
NonCanonical

Expand All @@ -52,6 +54,8 @@ struct MixedCanonical <: Form
orthog_center::Union{Site,Vector{<:Site}}
end

Base.copy(x::MixedCanonical) = MixedCanonical(copy(x.orthog_center))

"""
Canonical

Expand Down
2 changes: 2 additions & 0 deletions src/Site.jl
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@ end
Site(id::Int; kwargs...) = Site((id,); kwargs...)
Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...)

Base.copy(x::Site) = x

id(site::Site{1}) = only(site.id)
id(site::Site) = site.id

Expand Down
Loading