Skip to content

Commit

Permalink
Replace legacy labels for inds function
Browse files Browse the repository at this point in the history
  • Loading branch information
jofrevalles committed Sep 18, 2023
1 parent bb5845e commit 4a2a6f2
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 15 deletions.
4 changes: 2 additions & 2 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -158,9 +158,9 @@ LinearAlgebra.lu(t::Tensor; left_inds=(), kwargs...) = lu(t, left_inds; kwargs..
function LinearAlgebra.lu(t::Tensor, left_inds; kwargs...)
# TODO better error exception and checks
isempty(left_inds) && throw(ErrorException("no left-indices in LU factorization"))
left_inds labels(t) || throw(ErrorException("all left-indices must be in $(labels(t))"))
left_inds inds(t) || throw(ErrorException("all left-indices must be in $(inds(t))"))

right_inds = setdiff(labels(t), left_inds)
right_inds = setdiff(inds(t), left_inds)
isempty(right_inds) && throw(ErrorException("no right-indices in LU factorization"))

# permute array
Expand Down
28 changes: 15 additions & 13 deletions test/Numerics_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -222,50 +222,52 @@
end

@testset "lu" begin
using LinearAlgebra: lu

data = rand(2, 2, 2)
tensor = Tensor(data, (:i, :j, :k))

@testset "[exceptions]" begin
# Throw exception if left_inds is not provided
@test_throws ErrorException lu(tensor)
# Throw exception if left_inds ∉ labels(tensor)
# Throw exception if left_inds ∉ inds(tensor)
@test_throws ErrorException lu(tensor, (:l,))
# throw exception if no right-inds
@test_throws ErrorException lu(tensor, (:i,:j,:k))
end

@testset "labels" begin
P, L, U = lu(tensor, labels(tensor)[1:2])
@test labels(P)[1:2] == labels(tensor)[1:2]
@test labels(P)[3:4] == labels(L)[1:2]
@test labels(L)[3] == labels(U)[1]
@test labels(U)[2] == labels(tensor)[3]
@testset "inds" begin
P, L, U = lu(tensor, inds(tensor)[1:2])
@test inds(P)[1:2] == inds(tensor)[1:2]
@test inds(P)[3:4] == inds(L)[1:2]
@test inds(L)[3] == inds(U)[1]
@test inds(U)[2] == inds(tensor)[3]
end

@testset "size" begin
P, L, U = lu(tensor, labels(tensor)[1:2])
P, L, U = lu(tensor, inds(tensor)[1:2])
@test size(P) == (2, 2, 2, 2)
@test size(L) == (2, 2, 2)
@test size(U) == (2, 2)

# Additional test with different dimensions
data2 = rand(2, 4, 6, 8)
tensor2 = Tensor(data2, (:i, :j, :k, :l))
P2, L2, U2 = lu(tensor2, labels(tensor2)[1:2])
P2, L2, U2 = lu(tensor2, inds(tensor2)[1:2])
@test size(P2) == (2, 4, 2, 4)
@test size(L2) == (2, 4, 8)
@test size(U2) == (8, 6, 8)
end

@testset "[accuracy]" begin
P, L, U = lu(tensor, labels(tensor)[1:2])
tensor_recovered = contract(contract(P, L), U)
P, L, U = lu(tensor, inds(tensor)[1:2])
tensor_recovered = contract(P, L, U)
@test tensor_recovered tensor

data2 = rand(2, 4, 6, 8)
tensor2 = Tensor(data2, (:i, :j, :k, :l))
P2, L2, U2 = lu(tensor2, labels(tensor2)[1:2])
tensor2_recovered = contract(contract(P2, L2), U2)
P2, L2, U2 = lu(tensor2, inds(tensor2)[1:2])
tensor2_recovered = contract(P2, L2, U2)
@test tensor2_recovered tensor2
end
end
Expand Down

0 comments on commit 4a2a6f2

Please sign in to comment.