diff --git a/ext/TenetReactantExt.jl b/ext/TenetReactantExt.jl index 8cdca599e..a972eaf55 100644 --- a/ext/TenetReactantExt.jl +++ b/ext/TenetReactantExt.jl @@ -13,7 +13,7 @@ function Reactant.make_tracer( seen, @nospecialize(prev::RT), path::Tuple, mode::Reactant.TraceMode; kwargs... ) where {RT<:Tensor} tracedata = Reactant.make_tracer(seen, parent(prev), Reactant.append_path(path, :data), mode; kwargs...) - return Tensor(tracedata, inds(prev)) + return Tensor(tracedata, copy(inds(prev))) end function Reactant.make_tracer(seen, prev::TensorNetwork, path::Tuple, mode::Reactant.TraceMode; kwargs...) @@ -42,16 +42,16 @@ function Reactant.make_tracer(seen, prev::Tenet.Product, path::Tuple, mode::Reac return Tenet.Product(tracetn) end -for A in (MPS, MPO) - @eval function Reactant.make_tracer(seen, prev::$A, path::Tuple, mode::Reactant.TraceMode; kwargs...) - tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) - return $A(tracetn, form(prev)) - end +function Reactant.make_tracer( + seen, prev::A, path::Tuple, mode::Reactant.TraceMode; kwargs... +) where {A<:Tenet.AbstractMPO} + tracetn = Reactant.make_tracer(seen, Ansatz(prev), Reactant.append_path(path, :tn), mode; kwargs...) + return A(tracetn, copy(form(prev))) end function Reactant.create_result(@nospecialize(tocopy::Tensor), @nospecialize(path), result_stores) data = Reactant.create_result(parent(tocopy), Reactant.append_path(path, :data), result_stores) - return :($Tensor($data, $(inds(tocopy)))) + return :($Tensor($data, $(copy(inds(tocopy))))) end function Reactant.create_result(tocopy::TensorNetwork, @nospecialize(path), result_stores) @@ -77,26 +77,11 @@ function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), resu return :($(Tenet.Product)($tn)) end -for A in (MPS, MPO) - @eval function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:$A} - tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) - return :($A($tn, $(Tenet.form(tocopy)))) - end +function Reactant.create_result(tocopy::A, @nospecialize(path), result_stores) where {A<:Tenet.AbstractMPO} + tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) + return :($A($tn, $(Tenet.form(tocopy)))) end -# TODO try rely on generic fallback for ansatzes -# function Reactant.create_result(tocopy::Tenet.Product, @nospecialize(path), result_stores) -# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) -# return :($(Tenet.Product)($tn)) -# end - -# for A in (MPS, MPO) -# @eval function Reactant.create_result(tocopy::$A, @nospecialize(path), result_stores) -# tn = Reactant.create_result(Ansatz(tocopy), Reactant.append_path(path, :tn), result_stores) -# return :($A($tn, form(tocopy))) -# end -# end - function Reactant.push_val!(ad_inputs, x::TensorNetwork, path) @assert length(path) == 2 @assert path[2] === :data @@ -216,7 +201,14 @@ end Tenet.contract(a::Tensor, b::Tensor{T,N,TracedRArray{T,N}}; kwargs...) where {T,N} = contract(b, a; kwargs...) function Tenet.contract(a::Tensor{Ta,Na,TracedRArray{Ta,Na}}, b::Tensor{Tb,Nb}; kwargs...) where {Ta,Na,Tb,Nb} - return contract(a, Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, parent(b)), inds(b)); kwargs...) + # TODO change to `Ops.constant` when Ops PR lands in Reactant + # apparently `promote_to` doesn't do the transpostion for converting from column-major (Julia) to row-major layout (MLIR) + # currently, we call permutedims manually + return contract( + a, + Tensor(Reactant.promote_to(TracedRArray{Tb,Nb}, permutedims(parent(b), collect(Nb:-1:1))), inds(b)); + kwargs..., + ) end end diff --git a/ext/TenetYaoBlocksExt.jl b/ext/TenetYaoBlocksExt.jl index 4f902f9b6..4033d532c 100644 --- a/ext/TenetYaoBlocksExt.jl +++ b/ext/TenetYaoBlocksExt.jl @@ -25,13 +25,13 @@ function Tenet.Quantum(circuit::AbstractBlock) end # NOTE `YaoBlocks.mat` on m-site qubits still returns the operator on the full Hilbert space + m = length(occupied_locs(gate)) operator = if gate isa YaoBlocks.ControlBlock - m = length(occupied_locs(gate)) control((1:(m - 1))..., m => content(gate))(m) else content(gate) end - array = reshape(mat(operator), fill(nlevel(operator), 2 * nqubits(operator))...) + array = reshape(collect(mat(operator)), fill(nlevel(operator), 2 * nqubits(operator))...) inds = (x -> collect(Iterators.flatten(zip(x...))))( map(occupied_locs(gate)) do l diff --git a/src/Ansatz.jl b/src/Ansatz.jl index fac055e02..76e4dfb61 100644 --- a/src/Ansatz.jl +++ b/src/Ansatz.jl @@ -33,6 +33,8 @@ Abstract type representing the canonical form trait of a [`AbstractAnsatz`](@ref """ abstract type Form end +Base.copy(x::Form) = x + """ NonCanonical @@ -52,6 +54,8 @@ struct MixedCanonical <: Form orthog_center::Union{Site,Vector{<:Site}} end +Base.copy(x::MixedCanonical) = MixedCanonical(copy(x.orthog_center)) + """ Canonical diff --git a/src/Site.jl b/src/Site.jl index 4e865c93f..a4c842943 100644 --- a/src/Site.jl +++ b/src/Site.jl @@ -15,6 +15,8 @@ end Site(id::Int; kwargs...) = Site((id,); kwargs...) Site(id::Vararg{Int,N}; kwargs...) where {N} = Site(id; kwargs...) +Base.copy(x::Site) = x + id(site::Site{1}) = only(site.id) id(site::Site) = site.id