Skip to content

Commit

Permalink
Enhance qr function with mode argument (#114)
Browse files Browse the repository at this point in the history
* Add mode argument into qr function

* Remove mode argument, default to reduced QR

* Fix test

* Remove unnecessary argument
  • Loading branch information
jofrevalles authored Oct 31, 2023
1 parent e7e5fae commit 2089a76
Show file tree
Hide file tree
Showing 2 changed files with 14 additions and 4 deletions.
16 changes: 13 additions & 3 deletions src/Numerics.jl
Original file line number Diff line number Diff line change
Expand Up @@ -119,8 +119,17 @@ end

LinearAlgebra.qr(t::Tensor{<:Any,2}; kwargs...) = Base.@invoke qr(t::Tensor; left_inds = (first(inds(t)),), kwargs...)

function LinearAlgebra.qr(t::Tensor; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...)
isdisjoint(left_inds, right_inds) ||
"""
LinearAlgebra.qr(t::Tensor, mode::Symbol = :reduced; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...
Perform QR factorization on a tensor.
# Arguments
- `t::Tensor`: tensor to be factorized
# Keyword Arguments
- `left_inds`: left indices to be used in the QR factorization. Defaults to all indices of `t` except `right_inds`.
- `right_inds`: right indices to be used in the QR factorization. Defaults to all indices of `t` except `left_inds`.
- `virtualind`: name of the virtual bond. Defaults to a random `Symbol`.
"""
function LinearAlgebra.qr(t::Tensor; left_inds = (), right_inds = (), virtualind::Symbol = Symbol(uuid4()), kwargs...) isdisjoint(left_inds, right_inds) ||
throw(ArgumentError("left ($left_inds) and right $(right_inds) indices must be disjoint"))

left_inds, right_inds =
Expand All @@ -138,7 +147,8 @@ function LinearAlgebra.qr(t::Tensor; left_inds = (), right_inds = (), virtualind
data = reshape(parent(tensor), prod(i -> size(t, i), left_inds), prod(i -> size(t, i), right_inds))

# compute QR
Q, R = qr(data; kwargs...)
F = qr(data; kwargs...)
Q, R = Matrix(F.Q), Matrix(F.R)

# tensorify results
Q = reshape(Q, ([size(t, ind) for ind in left_inds]..., size(Q, 2)))
Expand Down
2 changes: 1 addition & 1 deletion test/Numerics_test.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@
@testset "size" begin
Q, R = qr(tensor, left_inds = (:i, :j))
# Q's new index size = min(prod(left_inds), prod(right_inds)).
@test size(Q) == (2, 2, 4)
@test size(Q) == (2, 2, 2)
@test size(R) == (2, 2)

# Additional test with different dimensions
Expand Down

0 comments on commit 2089a76

Please sign in to comment.