Skip to content

Commit 26dde0b

Browse files
authored
Enforce usage of similar_type for matrix multiplication (#143)
* Enforce usage of `similar_type` for matrix multiplication * Fix related tests
1 parent e6d7387 commit 26dde0b

File tree

2 files changed

+11
-11
lines changed

2 files changed

+11
-11
lines changed

src/matrix_multiply.jl

+7-7
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,7 @@ end
102102

103103
if s == sb
104104
if T == Tb
105-
newtype = b
105+
newtype = similar_type(b)
106106
else
107107
newtype = similar_type(b, T)
108108
end
@@ -141,7 +141,7 @@ end
141141

142142
if s == sb
143143
if T == Tb
144-
newtype = b
144+
newtype = similar_type(b)
145145
else
146146
newtype = similar_type(b, T)
147147
end
@@ -237,7 +237,7 @@ end
237237

238238
if s == sB
239239
if T == TB
240-
newtype = B
240+
newtype = similar_type(B)
241241
else
242242
newtype = similar_type(B, T)
243243
end
@@ -334,7 +334,7 @@ end
334334
# TODO think about which to be similar to
335335
if s == sB
336336
if T == TB
337-
newtype = B
337+
newtype = similar_type(B)
338338
else
339339
newtype = similar_type(B, T)
340340
end
@@ -375,7 +375,7 @@ end
375375
# TODO think about which to be similar to
376376
if s == sB
377377
if T == TB
378-
newtype = B
378+
newtype = similar_type(B)
379379
else
380380
newtype = similar_type(B, T)
381381
end
@@ -420,7 +420,7 @@ end
420420
# TODO think about which to be similar to
421421
if s == sB
422422
if T == TB
423-
newtype = B
423+
newtype = similar_type(B)
424424
else
425425
newtype = similar_type(B, T)
426426
end
@@ -463,7 +463,7 @@ end
463463

464464
if s == sb
465465
if T == Tb
466-
newtype = b
466+
newtype = similar_type(b)
467467
else
468468
newtype = similar_type(b, T)
469469
end

test/matrix_multiply.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,11 @@
1818

1919
m3 = @SArray [1 2; 3 4]
2020
v5 = @SArray [1, 2]
21-
@test m3*v5 === @SArray [5, 11]
21+
@test m3*v5 === @SVector [5, 11]
2222

2323
m4 = @MArray [1 2; 3 4]
2424
v6 = @MArray [1, 2]
25-
@test (m4*v6)::MArray == @MArray [5, 11]
25+
@test (m4*v6)::MVector == @MVector [5, 11]
2626

2727
m5 = @SMatrix [1.0 2.0; 3.0 4.0]
2828
v7 = [1.0, 2.0]
@@ -62,11 +62,11 @@
6262

6363
m = @SArray [1 2; 3 4]
6464
n = @SArray [2 3; 4 5]
65-
@test m*n === @SArray [10 13; 22 29]
65+
@test m*n === @SMatrix [10 13; 22 29]
6666

6767
m = @MArray [1 2; 3 4]
6868
n = @MArray [2 3; 4 5]
69-
@test (m*n)::MArray == @MArray [10 13; 22 29]
69+
@test (m*n)::MMatrix == @MMatrix [10 13; 22 29]
7070

7171
# Alternative methods used between 8 < n <= 14 and n > 14
7272
m_array = rand(1:10, 10, 10)

0 commit comments

Comments
 (0)