Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Refactor @unsafe_region macro #233

Merged
merged 14 commits into from
Nov 12, 2024
108 changes: 95 additions & 13 deletions src/TensorNetwork.jl
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,7 @@ struct TensorNetwork <: AbstractTensorNetwork
tensormap::IdDict{Tensor,Vector{Symbol}}

sorted_tensors::CachedField{Vector{Tensor}}
check_index_sizes::Ref{Bool}
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

function TensorNetwork(tensors)
tensormap = IdDict{Tensor,Vector{Symbol}}(tensor => inds(tensor) for tensor in tensors)
Expand All @@ -59,7 +60,29 @@ struct TensorNetwork <: AbstractTensorNetwork
length(unique(dims)) == 1 || throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)"))
end

return new(indexmap, tensormap, CachedField{Vector{Tensor}}())
return new(indexmap, tensormap, CachedField{Vector{Tensor}}(), Ref(true))
end

function TensorNetwork(tensors, check_index_sizes::Bool)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
tensormap = IdDict{Tensor,Vector{Symbol}}(tensor => inds(tensor) for tensor in tensors)

indexmap = reduce(tensors; init=Dict{Symbol,Vector{Tensor}}()) do dict, tensor
for index in inds(tensor)
# TODO use lambda? `Tensor[]` might be reused
push!(get!(dict, index, Tensor[]), tensor)
end
dict
end

if check_index_sizes
for ind in keys(indexmap)
dims = map(tensor -> size(tensor, ind), indexmap[ind])
length(unique(dims)) == 1 ||
throw(DimensionMismatch("Index $(ind) has inconsistent dimension: $(dims)"))
end
end

return new(indexmap, tensormap, CachedField{Vector{Tensor}}(), Ref(check_index_sizes))
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
end
end

Expand All @@ -71,7 +94,17 @@ TensorNetwork(tn::TensorNetwork) = tn

Return a shallow copy of a [`TensorNetwork`](@ref).
"""
Base.copy(tn::TensorNetwork) = TensorNetwork(tensors(tn))
function Base.copy(tn::TensorNetwork)
new_tn = TensorNetwork(tensors(tn), tn.check_index_sizes[])
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved

# Check if there's an active UnsafeContext with tn in it
uc = current_unsafe_context()
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
if uc !== nothing && tn ∈ uc
push!(uc.refs, WeakRef(new_tn)) # Register the new copy in the UnsafeContext
end

return new_tn
end

Base.similar(tn::TensorNetwork) = TensorNetwork(similar.(tensors(tn)))
Base.zero(tn::TensorNetwork) = TensorNetwork(zero.(tensors(tn)))
Expand Down Expand Up @@ -261,20 +294,68 @@ function __check_index_sizes(tn)
return true
end

const is_unsafe_region = ScopedValue(false) # global ScopedValue for the unsafe region
struct UnsafeContext
refs::Vector{WeakRef} # List of weak references

UnsafeContext() = new(Vector{WeakRef}())
end

Base.values(uc::UnsafeContext) = map(x -> x.value, uc.refs)

# Global stack to manage nested unsafe contexts
const _unsafe_context_stack = Ref{Vector{UnsafeContext}}(Vector{UnsafeContext}())
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mmm you don't need this? the idea was that @unsafe_region first creates a UnsafeContext and passes that the TensorNetwork

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

You need this to check if there is an active UnsafeContext when you call copy and related functions.


Base.in(tn::TensorNetwork, tnstack::Vector{TensorNetwork}) = any(tn_ -> tn === tn_, tnstack)
jofrevalles marked this conversation as resolved.
Show resolved Hide resolved
Base.in(tn::TensorNetwork, ucstack::Vector{UnsafeContext}) = any(uc -> tn ∈ values(uc), ucstack)
Base.in(tn::TensorNetwork, uc::UnsafeContext) = tn ∈ values(uc)

# Function to get the current UnsafeContext
function current_unsafe_context()
if isempty(Tenet._unsafe_context_stack[])
return nothing
else
return Tenet._unsafe_context_stack[][end]
end
end

macro unsafe_region(tn, block)
# Define the @unsafe_region macro
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this comment is a lil bit auto-descriptive? maybe we can just remove it or add a docstring

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Maybe its better to remove it. I prefer not to have a docstring here since this is an internal operation, for now.

macro unsafe_region(tn_sym, block)
return esc(
quote
local old = copy($tn)
local old = copy($tn_sym)

# Create a new UnsafeContext and push it onto the stack
local _uc = Tenet.UnsafeContext()
push!(Tenet._unsafe_context_stack[], _uc)

# Set check_index_sizes to false for the passed tensor network
$tn_sym.check_index_sizes[] = false

# Register the tensor network in the context
push!(_uc.refs, WeakRef($tn_sym))

e = nothing
try
$with($is_unsafe_region => true) do
$block
end
$(block) # Execute the user-provided block
catch e
$(tn_sym) = old # Restore the original tensor network in case of an exception
rethrow(e)
finally
if !Tenet.__check_index_sizes($tn)
tn = old
throw(DimensionMismatch("Inconsistent size of indices"))
if e === nothing
# Perform checks of registered tensor networks
for ref in _uc.refs
tn = ref.value
if tn !== nothing
if !Tenet.__check_index_sizes(tn)
$(tn_sym) = old
pop!(Tenet._unsafe_context_stack[])
throw(DimensionMismatch("Inconsistent size of indices"))
end
end
end

# Pop the UnsafeContext from the stack
pop!(Tenet._unsafe_context_stack[])
end
end
end,
Expand All @@ -292,8 +373,9 @@ function Base.push!(tn::AbstractTensorNetwork, tensor::Tensor)
tn = TensorNetwork(tn)
tensor ∈ keys(tn.tensormap) && return tn

# Only check index sizes if we are not in an unsafe region
if !is_unsafe_region[]
# Check if there's an active UnsafeContext with tn in it
uc = current_unsafe_context()
if uc === nothing || tn ∉ uc # Only check index sizes if we are not in an unsafe region
for i in Iterators.filter(i -> size(tn, i) != size(tensor, i), inds(tensor) ∩ inds(tn))
throw(
DimensionMismatch("size(tensor,$i)=$(size(tensor,i)) but should be equal to size(tn,$i)=$(size(tn,i))")
Expand Down
Loading