Skip to content

Commit dafc863

Browse files
more tests
1 parent 2915481 commit dafc863

File tree

2 files changed

+54
-40
lines changed

2 files changed

+54
-40
lines changed

src/fillalgebra.jl

+38-40
Original file line numberDiff line numberDiff line change
@@ -84,8 +84,10 @@ end
8484

8585
*(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
8686
*(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
87+
*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b)
8788
*(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
8889
*(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
90+
*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b)
8991
*(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
9092
*(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
9193
*(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
@@ -95,65 +97,51 @@ end
9597
*(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
9698
*(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
9799

98-
# Cannot unify following methods for Diagonal
99-
# due to ambiguity with general array mult. with fill
100-
function *(a::Diagonal, b::FillMatrix)
100+
# # Cannot unify following methods for Diagonal
101+
# # due to ambiguity with general array mult. with fill
102+
function *(a::Diagonal, b::AbstractFill{T,2}) where T
101103
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
102104
a.diag .* b # use special broadcast
103105
end
104-
function *(a::FillMatrix, b::Diagonal)
106+
function *(a::AbstractFill{T,2}, b::Diagonal) where T
105107
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
106108
a .* permutedims(b.diag) # use special broadcast
107109
end
108-
function *(a::Diagonal, b::OnesMatrix)
109-
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
110-
a.diag .* b # use special broadcast
111-
end
112-
function *(a::OnesMatrix, b::Diagonal)
113-
size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
114-
a .* permutedims(b.diag) # use special broadcast
115-
end
116-
117-
*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
118-
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
119-
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1))
120110

121-
function *(x::AbstractMatrix, f::FillMatrix)
111+
function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T
122112
axes(x, 2) axes(f, 1) &&
123113
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
124114
m = size(f, 2)
125-
repeat(sum(x, dims=2) * f.value, 1, m)
115+
repeat(sum(x, dims=2) * getindex_value(f), 1, m)
126116
end
127117

128-
function *(f::FillMatrix, x::AbstractMatrix)
118+
function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T
129119
axes(f, 2) axes(x, 1) &&
130120
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
131121
m = size(f, 1)
132-
repeat(sum(x, dims=1) * f.value, m, 1)
122+
repeat(sum(x, dims=1) * getindex_value(f), m, 1)
133123
end
134124

135-
function *(x::AbstractMatrix, f::OnesMatrix)
136-
axes(x, 2) axes(f, 1) &&
137-
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
138-
m = size(f, 2)
139-
repeat(sum(x, dims=2) * one(eltype(f)), 1, m)
140-
end
125+
*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
126+
*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
141127

142-
function *(f::OnesMatrix, x::AbstractMatrix)
143-
axes(f, 2) axes(x, 1) &&
144-
throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
145-
m = size(f, 1)
146-
repeat(sum(x, dims=1) * one(eltype(f)), m, 1)
147-
end
128+
# *(x::Diagonal, y::AbstractFill{<:Any,2}) = mult_sum2(x, y)
129+
# *(x::Transpose{T,AbstractMatrix{T}}, y::AbstractFill{<:Any,2}) where T = mult_sum2(x, y)
130+
*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y)
131+
*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y)
148132

149-
*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y)
150-
*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y)
151-
*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y)
152-
*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y)
153-
*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y)
154-
*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y)
155-
*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y)
156-
*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y)
133+
# *(x::AbstractFill{<:Any,2}, y::Diagonal) = mult_sum1(x, y)
134+
# *(x::AbstractFill{<:Any,2}, y::Transpose{T,AbstractMatrix{T}}) where T = mult_sum1(x, y)
135+
136+
137+
# *(x::FillMatrix, y::FillMatrix) = mult_fill(x, y)
138+
# *(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y)
139+
# *(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y)
140+
# *(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y)
141+
# *(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y)
142+
# *(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y)
143+
# *(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y)
144+
# *(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y)
157145

158146
# function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
159147
# fB = similar(parent(a), size(b, 1), size(b, 2))
@@ -173,6 +161,16 @@ end
173161
# return a*fB
174162
# end
175163

164+
## Matrix-Vector multiplication
165+
166+
*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
167+
reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2))
168+
*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T =
169+
reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2))
170+
*(a::StridedMatrix{T}, b::Fill{T, 1}) where T =
171+
reshape(sum(a; dims=2) .* b.value, size(a, 1))
172+
173+
176174
function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
177175
la, lb = length(a), length(b)
178176
if la lb

test/runtests.jl

+16
Original file line numberDiff line numberDiff line change
@@ -1015,6 +1015,9 @@ end
10151015
@test D*Zeros(1) Zeros(1)
10161016

10171017
D = Diagonal(Fill(2,10))
1018+
# @show D * Ones(10)
1019+
# @show D * Ones(10,5)
1020+
# @show Ones(5,10) * D
10181021
@test D * Ones(10) Fill(2.0,10)
10191022
@test D * Ones(10,5) Fill(2.0,10,5)
10201023
@test Ones(5,10) * D Fill(2.0,5,10)
@@ -1028,6 +1031,19 @@ end
10281031
@test E*(1:5) 1.0:5.0
10291032
@test (1:5)'E == (1.0:5)'
10301033
@test E*E E
1034+
1035+
# Adjoint / Transpose / Triangular / Symmetric
1036+
for x in [transpose(rand(2, 2)),
1037+
adjoint(rand(2,2)),
1038+
UpperTriangular(rand(2,2)),
1039+
Symmetric(rand(2,2))]
1040+
@test x * Ones(2, 2) isa Matrix
1041+
@test Ones(2, 2) * x isa Matrix
1042+
@test x * Zeros(2, 2) isa Zeros
1043+
@test Zeros(2, 2) * x isa Zeros
1044+
@test x * Fill(1., 2, 2) isa Matrix
1045+
@test Fill(1., 2, 2) * x isa Matrix
1046+
end
10311047
end
10321048

10331049
@testset "count" begin

0 commit comments

Comments
 (0)