Skip to content

Commit

Permalink
Refactor @unsafe_region macro (#233)
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles authored Nov 12, 2024
1 parent 6e3e875 commit 6acfe1c
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 18 deletions.
8 changes: 8 additions & 0 deletions src/Helpers.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,3 +49,11 @@ resetindex!(gen::IndexCounter) = letter(Threads.atomic_xchg!(gen.counter, 1))
# if is Complex, extract the parametric type and get the eps of that
wrap_eps(x) = eps(x)
wrap_eps(::Type{Complex{T}}) where {T} = eps(T)

struct UnsafeScope
refs::Vector{WeakRef}

UnsafeScope() = new(Vector{WeakRef}())
end

Base.values(uc::UnsafeScope) = map(x -> x.value, uc.refs)
77 changes: 59 additions & 18 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,8 +41,10 @@ struct TensorNetwork <: AbstractTensorNetwork
tensormap::IdDict{Tensor,Vector{Symbol}}

sorted_tensors::CachedField{Vector{Tensor}}
unsafe::Ref{Union{Nothing,UnsafeScope}}

function TensorNetwork(tensors)
# TODO: Find a way to remove the `unsafe` keyword argument from the constructor
function TensorNetwork(tensors; unsafe::Union{Nothing,UnsafeScope}=nothing)
tensormap = IdDict{Tensor,Vector{Symbol}}(tensor => inds(tensor) for tensor in tensors)

indexmap = reduce(tensors; init=Dict{Symbol,Vector{Tensor}}()) do dict, tensor
Expand All @@ -53,25 +55,42 @@ struct TensorNetwork <: AbstractTensorNetwork
dict
end

# Check for inconsistent index dimensions
for ind in keys(indexmap)
dims = map(tensor -> size(tensor, ind), indexmap[ind])
length(unique(dims)) == 1 || throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)"))
# Check index size consistency if not inside an `UnsafeScope`
if isnothing(unsafe)
for ind in keys(indexmap)
dims = map(tensor -> size(tensor, ind), indexmap[ind])
length(unique(dims)) == 1 ||
throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)"))
end
end

return new(indexmap, tensormap, CachedField{Vector{Tensor}}())
return new(indexmap, tensormap, CachedField{Vector{Tensor}}(), Ref{Union{Nothing,UnsafeScope}}(unsafe))
end
end

TensorNetwork() = TensorNetwork(Tensor[])
TensorNetwork(tn::TensorNetwork) = tn

get_unsafe_scope(tn::AbstractTensorNetwork) = TensorNetwork(tn).unsafe[]
function set_unsafe_scope!(tn::AbstractTensorNetwork, uc::Union{Nothing,UnsafeScope})
TensorNetwork(tn).unsafe[] = uc
return tn
end

"""
copy(tn::TensorNetwork)
Return a shallow copy of a [`TensorNetwork`](@ref).
"""
Base.copy(tn::TensorNetwork) = TensorNetwork(tensors(tn))
function Base.copy(tn::TensorNetwork)
new_tn = TensorNetwork(tensors(tn); unsafe=get_unsafe_scope(tn))

if !isnothing(get_unsafe_scope(tn))
push!(get_unsafe_scope(tn).refs, WeakRef(new_tn)) # Register the new copy to the proper UnsafeScope
end

return new_tn
end

Base.similar(tn::TensorNetwork) = TensorNetwork(similar.(tensors(tn)))
Base.zero(tn::TensorNetwork) = TensorNetwork(zero.(tensors(tn)))
Expand Down Expand Up @@ -261,20 +280,42 @@ function __check_index_sizes(tn)
return true
end

const is_unsafe_region = ScopedValue(false) # global ScopedValue for the unsafe region
Base.in(tn::TensorNetwork, uc::UnsafeScope) = tn values(uc)

macro unsafe_region(tn, block)
macro unsafe_region(tn_sym, block)
return esc(
quote
local old = copy($tn)
local old = copy($tn_sym)

# Create a new UnsafeScope and set it to the current tn
local _uc = Tenet.UnsafeScope()
Tenet.set_unsafe_scope!($tn_sym, _uc)

# Register the tensor network in the UnsafeScope
push!(Tenet.get_unsafe_scope($tn_sym).refs, WeakRef($tn_sym))

e = nothing
try
$with($is_unsafe_region => true) do
$block
end
$(block) # Execute the user-provided block
catch e
$(tn_sym) = old # Restore the original tensor network in case of an exception
rethrow(e)
finally
if !Tenet.__check_index_sizes($tn)
tn = old
throw(DimensionMismatch("Inconsistent size of indices"))
if e === nothing
# Perform checks of registered tensor networks
for ref in Tenet.get_unsafe_scope($tn_sym).refs
tn = ref.value
if tn !== nothing && tn values(Tenet.get_unsafe_scope($tn_sym))
if !Tenet.__check_index_sizes(tn)
$(tn_sym) = old

# Set `unsafe` field to `nothing`
Tenet.set_unsafe_scope!($tn_sym, nothing)

throw(DimensionMismatch("Inconsistent size of indices"))
end
end
end
end
end
end,
Expand All @@ -292,8 +333,8 @@ function Base.push!(tn::AbstractTensorNetwork, tensor::Tensor)
tn = TensorNetwork(tn)
tensor keys(tn.tensormap) && return tn

# Only check index sizes if we are not in an unsafe region
if !is_unsafe_region[]
# Check index sizes if there isn't an active `UnsafeScope` in the Tensor Network
if isnothing(get_unsafe_scope(tn))
for i in Iterators.filter(i -> size(tn, i) != size(tensor, i), inds(tensor) inds(tn))
throw(
DimensionMismatch("size(tensor,$i)=$(size(tensor,i)) but should be equal to size(tn,$i)=$(size(tn,i))")
Expand Down
44 changes: 44 additions & 0 deletions test/TensorNetwork_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -708,6 +708,50 @@
@test tensors(tn)[1] === a
@test tensors(tn)[2] === b
end

@testset "copy inside unsafe region" begin
tn = TensorNetwork([Tensor(ones(2, 2), [:a, :b]), Tensor(ones(2, 2), [:b, :c])])

@test_throws DimensionMismatch Tenet.@unsafe_region tn begin
tensor = Tensor(ones(3, 2), [:c, :d])
push!(tn, tensor)
tn2 = TensorNetwork([Tensor(ones(2, 2), [:a, :b]), Tensor(ones(2, 2), [:b, :c])])
push!(tn2, tensor) # tn2 is not specified in @unsafe_region argument
@test length(tensors(tn)) == 3
pop!(tn, tensor)
end

# Here still errors since at the end `tn2` is inconsistent:
@test_throws DimensionMismatch Tenet.@unsafe_region tn begin
tensor = Tensor(ones(3, 2), [:c, :d])
push!(tn, tensor)
tn2 = copy(tn)
push!(tn2, tensor)
@test length(tensors(tn)) == 3
pop!(tn, tensor)
end

# Double copy should also throw an error:
@test_throws DimensionMismatch Tenet.@unsafe_region tn begin
tensor = Tensor(ones(3, 2), [:c, :d])
push!(tn, tensor)
tn2 = copy(tn)
tn3 = copy(tn2)
push!(tn3, tensor)
@test length(tensors(tn)) == 3
pop!(tn, tensor)
end

Tenet.@unsafe_region tn begin # This should not throw an error
tensor = Tensor(ones(3, 2), [:c, :d])
push!(tn, tensor)
tn2 = copy(tn)
push!(tn2, tensor) # tn2 is not specified in @unsafe_region
@test length(tensors(tn)) == 3
pop!(tn, tensor)
pop!(tn2, tensor)
end
end
end
end

Expand Down

0 comments on commit 6acfe1c

Please sign in to comment.