Skip to content

Commit 5a9c7f7

Browse files
authored
scalar indexing broadcast fix (#442)
* scalar indexing broadcast fix * add tests for scalar broadcast indexing * add test for ref
1 parent f560b63 commit 5a9c7f7

File tree

2 files changed

+17
-1
lines changed

2 files changed

+17
-1
lines changed

src/broadcast.jl

+5-1
Original file line numberDiff line numberDiff line change
@@ -135,6 +135,10 @@ end
135135
end
136136
end
137137

138+
scalar_getindex(x) = x
139+
scalar_getindex(x::Ref) = x[]
140+
scalar_getindex(x::Tuple{<: Any}) = x[1]
141+
138142
@generated function _broadcast(f, ::Size{newsize}, s::Tuple{Vararg{Size}}, a...) where newsize
139143
first_staticarray = 0
140144
for i = 1:length(a)
@@ -150,7 +154,7 @@ end
150154
sizes = [sz.parameters[1] for sz s.parameters]
151155

152156
while more
153-
exprs_vals = [(!(a[i] <: AbstractArray) ? :(a[$i][]) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
157+
exprs_vals = [(!(a[i] <: AbstractArray) ? :(scalar_getindex(a[$i])) : :(a[$i][$(broadcasted_index(sizes[i], current_ind))])) for i = 1:length(sizes)]
154158
exprs[current_ind...] = :(f($(exprs_vals...)))
155159

156160
# increment current_ind (maybe use CartesianIndices?)

test/broadcast.jl

+12
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,15 @@
1+
using StaticArrays, Base.Test
2+
struct ScalarTest end
3+
Base.:(+)(x::Number, y::ScalarTest) = x
4+
5+
@testset "Scalar Broadcast" begin
6+
for t in (SVector{2}, MVector{2}, SMatrix{2, 2}, MMatrix{2, 2})
7+
x = rand(t)
8+
@test x == @inferred(x .+ ScalarTest())
9+
@test x .+ 1 == @inferred(x .+ Ref(1))
10+
end
11+
end
12+
113
@testset "Broadcast sizes" begin
214
@test @inferred(StaticArrays.broadcast_sizes(1, 1, 1)) === (Size(), Size(), Size())
315
for t in (SVector{2}, MVector{2}, SMatrix{2, 2}, MMatrix{2, 2})

0 commit comments

Comments
 (0)