diff --git a/src/Helpers.jl b/src/Helpers.jl index 05be5be6f..2b4fb1e5a 100644 --- a/src/Helpers.jl +++ b/src/Helpers.jl @@ -49,3 +49,11 @@ resetindex!(gen::IndexCounter) = letter(Threads.atomic_xchg!(gen.counter, 1)) # if is Complex, extract the parametric type and get the eps of that wrap_eps(x) = eps(x) wrap_eps(::Type{Complex{T}}) where {T} = eps(T) + +struct UnsafeScope + refs::Vector{WeakRef} + + UnsafeScope() = new(Vector{WeakRef}()) +end + +Base.values(uc::UnsafeScope) = map(x -> x.value, uc.refs) diff --git a/src/TensorNetwork.jl b/src/TensorNetwork.jl index 1a215a752..68f7c27bd 100644 --- a/src/TensorNetwork.jl +++ b/src/TensorNetwork.jl @@ -41,8 +41,10 @@ struct TensorNetwork <: AbstractTensorNetwork tensormap::IdDict{Tensor,Vector{Symbol}} sorted_tensors::CachedField{Vector{Tensor}} + unsafe::Ref{Union{Nothing,UnsafeScope}} - function TensorNetwork(tensors) + # TODO: Find a way to remove the `unsafe` keyword argument from the constructor + function TensorNetwork(tensors; unsafe::Union{Nothing,UnsafeScope}=nothing) tensormap = IdDict{Tensor,Vector{Symbol}}(tensor => inds(tensor) for tensor in tensors) indexmap = reduce(tensors; init=Dict{Symbol,Vector{Tensor}}()) do dict, tensor @@ -53,25 +55,42 @@ struct TensorNetwork <: AbstractTensorNetwork dict end - # Check for inconsistent index dimensions - for ind in keys(indexmap) - dims = map(tensor -> size(tensor, ind), indexmap[ind]) - length(unique(dims)) == 1 || throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)")) + # Check index size consistency if not inside an `UnsafeScope` + if isnothing(unsafe) + for ind in keys(indexmap) + dims = map(tensor -> size(tensor, ind), indexmap[ind]) + length(unique(dims)) == 1 || + throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)")) + end end - return new(indexmap, tensormap, CachedField{Vector{Tensor}}()) + return new(indexmap, tensormap, CachedField{Vector{Tensor}}(), Ref{Union{Nothing,UnsafeScope}}(unsafe)) end end TensorNetwork() = TensorNetwork(Tensor[]) TensorNetwork(tn::TensorNetwork) = tn +get_unsafe_scope(tn::AbstractTensorNetwork) = TensorNetwork(tn).unsafe[] +function set_unsafe_scope!(tn::AbstractTensorNetwork, uc::Union{Nothing,UnsafeScope}) + TensorNetwork(tn).unsafe[] = uc + return tn +end + """ copy(tn::TensorNetwork) Return a shallow copy of a [`TensorNetwork`](@ref). """ -Base.copy(tn::TensorNetwork) = TensorNetwork(tensors(tn)) +function Base.copy(tn::TensorNetwork) + new_tn = TensorNetwork(tensors(tn); unsafe=get_unsafe_scope(tn)) + + if !isnothing(get_unsafe_scope(tn)) + push!(get_unsafe_scope(tn).refs, WeakRef(new_tn)) # Register the new copy to the proper UnsafeScope + end + + return new_tn +end Base.similar(tn::TensorNetwork) = TensorNetwork(similar.(tensors(tn))) Base.zero(tn::TensorNetwork) = TensorNetwork(zero.(tensors(tn))) @@ -261,20 +280,42 @@ function __check_index_sizes(tn) return true end -const is_unsafe_region = ScopedValue(false) # global ScopedValue for the unsafe region +Base.in(tn::TensorNetwork, uc::UnsafeScope) = tn ∈ values(uc) -macro unsafe_region(tn, block) +macro unsafe_region(tn_sym, block) return esc( quote - local old = copy($tn) + local old = copy($tn_sym) + + # Create a new UnsafeScope and set it to the current tn + local _uc = Tenet.UnsafeScope() + Tenet.set_unsafe_scope!($tn_sym, _uc) + + # Register the tensor network in the UnsafeScope + push!(Tenet.get_unsafe_scope($tn_sym).refs, WeakRef($tn_sym)) + + e = nothing try - $with($is_unsafe_region => true) do - $block - end + $(block) # Execute the user-provided block + catch e + $(tn_sym) = old # Restore the original tensor network in case of an exception + rethrow(e) finally - if !Tenet.__check_index_sizes($tn) - tn = old - throw(DimensionMismatch("Inconsistent size of indices")) + if e === nothing + # Perform checks of registered tensor networks + for ref in Tenet.get_unsafe_scope($tn_sym).refs + tn = ref.value + if tn !== nothing && tn ∈ values(Tenet.get_unsafe_scope($tn_sym)) + if !Tenet.__check_index_sizes(tn) + $(tn_sym) = old + + # Set `unsafe` field to `nothing` + Tenet.set_unsafe_scope!($tn_sym, nothing) + + throw(DimensionMismatch("Inconsistent size of indices")) + end + end + end end end end, @@ -292,8 +333,8 @@ function Base.push!(tn::AbstractTensorNetwork, tensor::Tensor) tn = TensorNetwork(tn) tensor ∈ keys(tn.tensormap) && return tn - # Only check index sizes if we are not in an unsafe region - if !is_unsafe_region[] + # Check index sizes if there isn't an active `UnsafeScope` in the Tensor Network + if isnothing(get_unsafe_scope(tn)) for i in Iterators.filter(i -> size(tn, i) != size(tensor, i), inds(tensor) ∩ inds(tn)) throw( DimensionMismatch("size(tensor,$i)=$(size(tensor,i)) but should be equal to size(tn,$i)=$(size(tn,i))") diff --git a/test/TensorNetwork_test.jl b/test/TensorNetwork_test.jl index 73fbea7ac..2d8026782 100644 --- a/test/TensorNetwork_test.jl +++ b/test/TensorNetwork_test.jl @@ -708,6 +708,50 @@ @test tensors(tn)[1] === a @test tensors(tn)[2] === b end + + @testset "copy inside unsafe region" begin + tn = TensorNetwork([Tensor(ones(2, 2), [:a, :b]), Tensor(ones(2, 2), [:b, :c])]) + + @test_throws DimensionMismatch Tenet.@unsafe_region tn begin + tensor = Tensor(ones(3, 2), [:c, :d]) + push!(tn, tensor) + tn2 = TensorNetwork([Tensor(ones(2, 2), [:a, :b]), Tensor(ones(2, 2), [:b, :c])]) + push!(tn2, tensor) # tn2 is not specified in @unsafe_region argument + @test length(tensors(tn)) == 3 + pop!(tn, tensor) + end + + # Here still errors since at the end `tn2` is inconsistent: + @test_throws DimensionMismatch Tenet.@unsafe_region tn begin + tensor = Tensor(ones(3, 2), [:c, :d]) + push!(tn, tensor) + tn2 = copy(tn) + push!(tn2, tensor) + @test length(tensors(tn)) == 3 + pop!(tn, tensor) + end + + # Double copy should also throw an error: + @test_throws DimensionMismatch Tenet.@unsafe_region tn begin + tensor = Tensor(ones(3, 2), [:c, :d]) + push!(tn, tensor) + tn2 = copy(tn) + tn3 = copy(tn2) + push!(tn3, tensor) + @test length(tensors(tn)) == 3 + pop!(tn, tensor) + end + + Tenet.@unsafe_region tn begin # This should not throw an error + tensor = Tensor(ones(3, 2), [:c, :d]) + push!(tn, tensor) + tn2 = copy(tn) + push!(tn2, tensor) # tn2 is not specified in @unsafe_region + @test length(tensors(tn)) == 3 + pop!(tn, tensor) + pop!(tn2, tensor) + end + end end end