From b50d0f0361214117c58fb43f566b878e9cf45b89 Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Wed, 13 Nov 2024 16:13:40 +0100 Subject: [PATCH 1/3] Fix slice finder so it now uses atol properly --- src/Transformations.jl | 9 ++++++--- 1 file changed, 6 insertions(+), 3 deletions(-) diff --git a/src/Transformations.jl b/src/Transformations.jl index 779a7eefa..f49c7fb88 100644 --- a/src/Transformations.jl +++ b/src/Transformations.jl @@ -179,10 +179,13 @@ function transform!(tn::TensorNetwork, config::Truncate) for (dim, index) in enumerate(inds(tensor)) index ∈ skip_inds && continue - zeroslices = iszero.(eachslice(tensor; dims=dim)) - any(zeroslices) || continue + # Use atol to determine small slices + small_slices = [maximum(abs.(s)) < config.atol for s in eachslice(tensor; dims=dim)] + any(small_slices) || continue - slice!(tn, index, count(!, zeroslices) == 1 ? findfirst(!, zeroslices) : findall(!, zeroslices)) + # Keep slices where the maximum absolute value is greater than or equal to atol + slices_to_keep = count(!, small_slices) == 1 ? findfirst(!, small_slices) : findall(!, small_slices) + slice!(tn, index, slices_to_keep) end end From aa2e47bc1544660617b0733ca2457b69b163145d Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Wed, 13 Nov 2024 16:14:08 +0100 Subject: [PATCH 2/3] Enhance tests to cover the proper behavior of atol --- test/Transformations_test.jl | 30 ++++++++++++++++++++++++++---- 1 file changed, 26 insertions(+), 4 deletions(-) diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index 7e299ace3..adec32972 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -232,33 +232,55 @@ using Tenet: Truncate @testset "range" begin + # Create tensor data with small values less than the default atol (1e-12) data = rand(3, 3, 3) - data[:, 1:2, :] .= 0 + data[:, 1:2, :] .= 1e-13 A = Tensor(data, (:i, :j, :k)) B = Tensor(rand(3, 3), (:j, :l)) C = Tensor(rand(3, 3), (:j, :m)) tn = TensorNetwork([A, B, C]) - reduced = transform(tn, Truncate) + # Apply Truncate transformation with default atol (1e-12) + reduced = transform(tn, Truncate()) + + # Test that index :j is removed because all its slices are below atol @test :j ∉ inds(reduced) @test contract(reduced) ≈ contract(tn) + + # Now, apply Truncate with a smaller atol (1e-14) so slices are not truncated + reduced_no_trunc = transform(tn, Truncate(atol=1e-14)) + + # Test that index :j is still present + @test :j ∈ inds(reduced_no_trunc) + @test contract(reduced_no_trunc) ≈ contract(tn) end @testset "int" begin + # Create tensor data with one slice having small values less than default atol data = rand(3, 3, 3) - data[:, 2, :] .= 0 + data[:, 2, :] .= 1e-13 A = Tensor(data, (:i, :j, :k)) B = Tensor(rand(3, 3), (:j, :l)) C = Tensor(rand(3, 3), (:j, :m)) tn = TensorNetwork([A, B, C]) - reduced = transform(tn, Truncate) + # Apply Truncate transformation with default atol (1e-12) + reduced = transform(tn, Truncate()) + + # Test that size of index :j is reduced by 1 @test size(reduced, :j) == 2 @test contract(reduced) ≈ contract(tn) + + # Now, apply Truncate with a smaller atol (1e-14) so the slice is not truncated + reduced_no_trunc = transform(tn, Truncate(atol=1e-14)) + + # Test that size of index :j remains the same + @test size(reduced_no_trunc, :j) == 3 + @test contract(reduced_no_trunc) ≈ contract(tn) end end From 4630c5e474788b28c25c7eafad2ef9dbd13398ca Mon Sep 17 00:00:00 2001 From: jofrevalles Date: Wed, 13 Nov 2024 16:22:55 +0100 Subject: [PATCH 3/3] Fix format --- test/Transformations_test.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/test/Transformations_test.jl b/test/Transformations_test.jl index adec32972..0864c4188 100644 --- a/test/Transformations_test.jl +++ b/test/Transformations_test.jl @@ -250,7 +250,7 @@ @test contract(reduced) ≈ contract(tn) # Now, apply Truncate with a smaller atol (1e-14) so slices are not truncated - reduced_no_trunc = transform(tn, Truncate(atol=1e-14)) + reduced_no_trunc = transform(tn, Truncate(; atol=1e-14)) # Test that index :j is still present @test :j ∈ inds(reduced_no_trunc) @@ -276,7 +276,7 @@ @test contract(reduced) ≈ contract(tn) # Now, apply Truncate with a smaller atol (1e-14) so the slice is not truncated - reduced_no_trunc = transform(tn, Truncate(atol=1e-14)) + reduced_no_trunc = transform(tn, Truncate(; atol=1e-14)) # Test that size of index :j remains the same @test size(reduced_no_trunc, :j) == 3