|
84 | 84 |
|
85 | 85 | *(a::Zeros{<:Any,1}, b::AbstractMatrix) = mult_zeros(a, b)
|
86 | 86 | *(a::Zeros{<:Any,2}, b::AbstractMatrix) = mult_zeros(a, b)
|
| 87 | +*(a::Zeros{<:Any,2}, b::AbstractTriangular) = mult_zeros(a, b) |
87 | 88 | *(a::AbstractMatrix, b::Zeros{<:Any,1}) = mult_zeros(a, b)
|
88 | 89 | *(a::AbstractMatrix, b::Zeros{<:Any,2}) = mult_zeros(a, b)
|
| 90 | +*(a::AbstractTriangular, b::Zeros{<:Any,2}) = mult_zeros(a, b) |
89 | 91 | *(a::Zeros{<:Any,1}, b::AbstractVector) = mult_zeros(a, b)
|
90 | 92 | *(a::Zeros{<:Any,2}, b::AbstractVector) = mult_zeros(a, b)
|
91 | 93 | *(a::AbstractVector, b::Zeros{<:Any,2}) = mult_zeros(a, b)
|
|
95 | 97 | *(a::Diagonal, b::Zeros{<:Any,1}) = mult_zeros(a, b)
|
96 | 98 | *(a::Diagonal, b::Zeros{<:Any,2}) = mult_zeros(a, b)
|
97 | 99 |
|
98 |
| -# Cannot unify following methods for Diagonal |
99 |
| -# due to ambiguity with general array mult. with fill |
100 |
| -function *(a::Diagonal, b::FillMatrix) |
| 100 | +# # Cannot unify following methods for Diagonal |
| 101 | +# # due to ambiguity with general array mult. with fill |
| 102 | +function *(a::Diagonal, b::AbstractFill{T,2}) where T |
101 | 103 | size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
|
102 | 104 | a.diag .* b # use special broadcast
|
103 | 105 | end
|
104 |
| -function *(a::FillMatrix, b::Diagonal) |
| 106 | +function *(a::AbstractFill{T,2}, b::Diagonal) where T |
105 | 107 | size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))"))
|
106 | 108 | a .* permutedims(b.diag) # use special broadcast
|
107 | 109 | end
|
108 |
| -function *(a::Diagonal, b::OnesMatrix) |
109 |
| - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) |
110 |
| - a.diag .* b # use special broadcast |
111 |
| -end |
112 |
| -function *(a::OnesMatrix, b::Diagonal) |
113 |
| - size(a,2) == size(b,1) || throw(DimensionMismatch("A has dimensions $(size(a)) but B has dimensions $(size(b))")) |
114 |
| - a .* permutedims(b.diag) # use special broadcast |
115 |
| -end |
116 |
| - |
117 |
| -*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) |
118 |
| -*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) |
119 |
| -*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = reshape(sum(a; dims=2) .* b.value, size(a, 1)) |
120 | 110 |
|
121 |
| -function *(x::AbstractMatrix, f::FillMatrix) |
| 111 | +function mult_sum2(x::AbstractMatrix, f::AbstractFill{T,2}) where T |
122 | 112 | axes(x, 2) ≠ axes(f, 1) &&
|
123 | 113 | throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
|
124 | 114 | m = size(f, 2)
|
125 |
| - repeat(sum(x, dims=2) * f.value, 1, m) |
| 115 | + repeat(sum(x, dims=2) * getindex_value(f), 1, m) |
126 | 116 | end
|
127 | 117 |
|
128 |
| -function *(f::FillMatrix, x::AbstractMatrix) |
| 118 | +function mult_sum1(f::AbstractFill{T,2}, x::AbstractMatrix) where T |
129 | 119 | axes(f, 2) ≠ axes(x, 1) &&
|
130 | 120 | throw(DimensionMismatch("Incompatible matrix multiplication dimensions"))
|
131 | 121 | m = size(f, 1)
|
132 |
| - repeat(sum(x, dims=1) * f.value, m, 1) |
| 122 | + repeat(sum(x, dims=1) * getindex_value(f), m, 1) |
133 | 123 | end
|
134 | 124 |
|
135 |
| -function *(x::AbstractMatrix, f::OnesMatrix) |
136 |
| - axes(x, 2) ≠ axes(f, 1) && |
137 |
| - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) |
138 |
| - m = size(f, 2) |
139 |
| - repeat(sum(x, dims=2) * one(eltype(f)), 1, m) |
140 |
| -end |
| 125 | +*(x::AbstractMatrix, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) |
| 126 | +*(x::AbstractTriangular, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) |
141 | 127 |
|
142 |
| -function *(f::OnesMatrix, x::AbstractMatrix) |
143 |
| - axes(f, 2) ≠ axes(x, 1) && |
144 |
| - throw(DimensionMismatch("Incompatible matrix multiplication dimensions")) |
145 |
| - m = size(f, 1) |
146 |
| - repeat(sum(x, dims=1) * one(eltype(f)), m, 1) |
147 |
| -end |
| 128 | +# *(x::Diagonal, y::AbstractFill{<:Any,2}) = mult_sum2(x, y) |
| 129 | +# *(x::Transpose{T,AbstractMatrix{T}}, y::AbstractFill{<:Any,2}) where T = mult_sum2(x, y) |
| 130 | +*(x::AbstractFill{<:Any,2}, y::AbstractMatrix) = mult_sum1(x, y) |
| 131 | +*(x::AbstractFill{<:Any,2}, y::AbstractTriangular) = mult_sum1(x, y) |
148 | 132 |
|
149 |
| -*(x::FillMatrix, y::FillMatrix) = mult_fill(x, y) |
150 |
| -*(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y) |
151 |
| -*(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y) |
152 |
| -*(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y) |
153 |
| -*(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y) |
154 |
| -*(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y) |
155 |
| -*(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y) |
156 |
| -*(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y) |
| 133 | +# *(x::AbstractFill{<:Any,2}, y::Diagonal) = mult_sum1(x, y) |
| 134 | +# *(x::AbstractFill{<:Any,2}, y::Transpose{T,AbstractMatrix{T}}) where T = mult_sum1(x, y) |
| 135 | + |
| 136 | + |
| 137 | +# *(x::FillMatrix, y::FillMatrix) = mult_fill(x, y) |
| 138 | +# *(x::FillMatrix, y::OnesMatrix) = mult_fill(x, y) |
| 139 | +# *(x::OnesMatrix, y::FillMatrix) = mult_fill(x, y) |
| 140 | +# *(x::OnesMatrix, y::OnesMatrix) = mult_fill(x, y) |
| 141 | +# *(x::ZerosMatrix, y::OnesMatrix) = mult_zeros(x, y) |
| 142 | +# *(x::ZerosMatrix, y::FillMatrix) = mult_zeros(x, y) |
| 143 | +# *(x::FillMatrix, y::ZerosMatrix) = mult_zeros(x, y) |
| 144 | +# *(x::OnesMatrix, y::ZerosMatrix) = mult_zeros(x, y) |
157 | 145 |
|
158 | 146 | # function *(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 2}) where T
|
159 | 147 | # fB = similar(parent(a), size(b, 1), size(b, 2))
|
|
173 | 161 | # return a*fB
|
174 | 162 | # end
|
175 | 163 |
|
| 164 | +## Matrix-Vector multiplication |
| 165 | + |
| 166 | +*(a::Adjoint{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = |
| 167 | + reshape(sum(conj.(parent(a)); dims=1) .* b.value, size(parent(a), 2)) |
| 168 | +*(a::Transpose{T, <:StridedMatrix{T}}, b::Fill{T, 1}) where T = |
| 169 | + reshape(sum(parent(a); dims=1) .* b.value, size(parent(a), 2)) |
| 170 | +*(a::StridedMatrix{T}, b::Fill{T, 1}) where T = |
| 171 | + reshape(sum(a; dims=2) .* b.value, size(a, 1)) |
| 172 | + |
| 173 | + |
176 | 174 | function _adjvec_mul_zeros(a::Adjoint{T}, b::Zeros{S, 1}) where {T, S}
|
177 | 175 | la, lb = length(a), length(b)
|
178 | 176 | if la ≠ lb
|
|
0 commit comments