Skip to content

WIP: enhance parent_call #248

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions docs/src/reference.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,8 @@ OffsetMatrix
OffsetArrays.Origin
OffsetArrays.IdOffsetRange
OffsetArrays.no_offset_view
OffsetArrays.centered
OffsetArrays.center
OffsetArrays.no_offset_view_apply
OffsetArrays.AxisConversionStyle
```
81 changes: 67 additions & 14 deletions src/OffsetArrays.jl
Original file line number Diff line number Diff line change
Expand Up @@ -298,14 +298,6 @@ if VERSION < v"1.6"
end
end

# Utils to translate a function to the parent while preserving offsets
unwrap(x) = x, identity
unwrap(x::OffsetArray) = parent(x), data -> OffsetArray(data, x.offsets, checkoverflow = false)
function parent_call(f, x)
parent, wrap_offset = unwrap(x)
wrap_offset(f(parent))
end

Base.similar(A::OffsetArray, ::Type{T}, dims::Dims) where T =
similar(parent(A), T, dims)
function Base.similar(A::AbstractArray, ::Type{T}, shape::Tuple{OffsetAxisKnownLength,Vararg{OffsetAxisKnownLength}}) where T
Expand Down Expand Up @@ -372,8 +364,8 @@ Base.trues(inds::NTuple{N, Union{Integer, AbstractUnitRange}}) where {N} =
Base.falses(inds::NTuple{N, Union{Integer, AbstractUnitRange}}) where {N} =
fill!(similar(BitArray, inds), false)

Base.zero(A::OffsetArray) = parent_call(zero, A)
Base.fill!(A::OffsetArray, x) = parent_call(Ap -> fill!(Ap, x), A)
Base.zero(A::OffsetArray) = no_offset_view_apply(zero, A)
Base.fill!(A::OffsetArray, x) = no_offset_view_apply(Ap -> fill!(Ap, x), A)

## Indexing

Expand All @@ -393,7 +385,7 @@ parentindex(r::IdOffsetRange, i) = i - r.offset
end

@propagate_inbounds Base.getindex(A::OffsetArray{<:Any,N}, c::Vararg{Colon,N}) where N =
parent_call(x -> getindex(x, c...), A)
no_offset_view_apply(x -> getindex(x, c...), A)

# With one Colon we use linear indexing.
# In this case we may forward the index to the parent, as the information about the axes is lost
Expand Down Expand Up @@ -425,7 +417,7 @@ end
end

Base.in(x, A::OffsetArray) = in(x, parent(A))
Base.copy(A::OffsetArray) = parent_call(copy, A)
Base.copy(A::OffsetArray) = no_offset_view_apply(copy, A)

Base.strides(A::OffsetArray) = strides(parent(A))
Base.elsize(::Type{OffsetArray{T,N,A}}) where {T,N,A} = Base.elsize(A)
Expand Down Expand Up @@ -530,7 +522,7 @@ end

# eltype conversion
# This may use specialized map methods for the parent
Base.map(::Type{T}, O::OffsetArray) where {T} = parent_call(x -> map(T, x), O)
Base.map(::Type{T}, O::OffsetArray) where {T} = no_offset_view_apply(x -> map(T, x), O)
Base.map(::Type{T}, r::IdOffsetRange) where {T<:Real} = _indexedby(map(T, UnitRange(r)), axes(r))
if eltype(IIUR) === Int
# This is type-piracy, but there is no way to convert an IdentityUnitRange to a non-Int type in Base
Expand Down Expand Up @@ -634,6 +626,65 @@ _no_offset_view(::Tuple{<:Base.OneTo,Vararg{<:Base.OneTo}}, A::AbstractUnitRange
_no_offset_view(::Any, A::AbstractArray) = OffsetArray(A, Origin(1))
_no_offset_view(::Any, A::AbstractUnitRange) = UnitRange(A)

# Utils to translate a function to the parent while preserving offsets
function unwrap(x, xs...)
AT1 = _basetype(typeof(x))
all(x->AT1==_basetype(typeof(x)), xs) || throw(ArgumentError("All arrays should be homogeneous, i.e., have the same container type."))
_unwrap(x, xs...)
end
function _unwrap(x::AT, xs...) where AT<:OffsetArray
o = x.offsets
all(x->o==x.offsets, xs) || throw(DimensionMismatch("All arrays should have the same offsets."))
wrap_offset(data) = OffsetArray(data, o, checkoverflow=false)
(parent(x), map(parent, xs)...), wrap_offset
end
_unwrap(xs...) = xs, identity
"""
no_offset_view_apply(f, As...) -> C

Apply one-output function `f` to [`no_offset_view`](@ref OffsetArrays.no_offset_view)s of `As` while
preserve the offsets for its output.

For `OffsetArray`s input, it's almost equivalent to `OffsetArray(f(map(parent, As)...), As[1].offsets)`

# Examples

One can use this to convert the internal array type:

```jldoctest; setup=:(using OffsetArrays)
Ao = OffsetArray(CartesianIndices((4, 4)), -1, -1);
parent(Ao) isa Array # false

# collect(Ao): strips the offsets
# Array(Ao): DimensionMismatch
dense_Ao = OffsetArrays.no_offset_view_apply(Array, Ao);
parent(dense_Ao) isa Array

# output
true
```

Also, one can do some 1-based mathematics on it:

```jldoctest; setup=:(using OffsetArrays)
A = OffsetArray(rand(4, 4), -1, -1)
B = OffsetArray(rand(4, 4), -1, -1)

C = OffsetArrays.no_offset_view_apply(A, B) do x, y
x * y
end

parent(C) == parent(A) * parent(B)

# output
true
```
"""
function no_offset_view_apply(f, xs...)
parent, wrap_offset = unwrap(xs...)
wrap_offset(f(parent...))
end

#####
# center/centered
# These two helpers are deliberately not exported; their meaning can be very different in
Expand Down Expand Up @@ -702,6 +753,8 @@ instead use [`center`](@ref OffsetArrays.center).
"""
centered(A::AbstractArray, r::RoundingMode=RoundDown) = OffsetArray(A, .-center(A, r))

# might be available in Base: https://github.com/JuliaLang/julia/issues/35543
_basetype(::Type{T}) where T = Base.typename(T).wrapper

####
# work around for segfault in searchsorted*
Expand Down Expand Up @@ -786,7 +839,7 @@ end
# Adapt allows for automatic conversion of CPU OffsetArrays to GPU OffsetArrays
##
import Adapt
Adapt.adapt_structure(to, O::OffsetArray) = parent_call(x -> Adapt.adapt(to, x), O)
Adapt.adapt_structure(to, O::OffsetArray) = no_offset_view_apply(x -> Adapt.adapt(to, x), O)

if Base.VERSION >= v"1.4.2"
include("precompile.jl")
Expand Down
2 changes: 1 addition & 1 deletion test/runtests.jl
Original file line number Diff line number Diff line change
Expand Up @@ -864,7 +864,7 @@ end
@testset "unwrap" begin
for A in [ones(2, 2), ones(2:3, 2:3), ZeroBasedRange(1:4)]
p, f = OffsetArrays.unwrap(A)
@test f(map(y -> y^2, p)) == A.^2
@test f(map(y -> y^2, p...)) == A.^2
end
end

Expand Down