Skip to content

Commit

Permalink
Current development version 23/08
Browse files Browse the repository at this point in the history
  • Loading branch information
dwest77a committed Aug 23, 2024
1 parent ddc0e55 commit 4f1a045
Show file tree
Hide file tree
Showing 2 changed files with 190 additions and 35 deletions.
56 changes: 56 additions & 0 deletions CFAPyX/subactive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
from XarrayActive import DaskActiveArray
import numpy as np

class SubDaskActiveArray(DaskActiveArray):
    """
    Lower-level nested dask array, used for chunks assembled from multiple
    fragments. Overrides reduction behaviour so that ``active_mean`` returns
    partial results (count and weighted total) which an enclosing array can
    combine, instead of a finished mean.
    """

    description = 'Lower level nested dask array - requires alterations to methods.'

    def copy(self):
        """
        Create a new SubDaskActiveArray instance with all the same parameters
        as the current instance.
        """
        # NOTE(review): ``meta=self`` passes the whole array as the meta
        # object; the usual dask convention is ``meta=self._meta`` — confirm
        # this is intentional before relying on ``copy``.
        copy_arr = SubDaskActiveArray(self.dask, self.name, self.chunks, meta=self)
        return copy_arr

    def __getitem__(self, index):
        """
        Perform indexing for this ActiveArray. May need to overwrite further if it turns out
        the indexing is performed **after** the dask `getter` method (i.e if retrieval and indexing
        are separate items on the dask graph). If this is the case, will need another `from_delayed`
        and `concatenation` method as used in ``active_mean``.
        """
        arr = super().__getitem__(index)
        # Re-wrap so slices of a sub-array stay sub-arrays.
        return SubDaskActiveArray(arr.dask, arr.name, arr.chunks, meta=arr)

    def _numel(self, axes=None):
        """
        Number of elements reduced over ``axes``.

        :param axes: (int or sequence of ints) Axes being reduced. ``None``
            (or an empty sequence) means the whole array.

        :returns: The total element count (int) when no axes are given,
            otherwise a numpy array of the per-position element count, shaped
            so it broadcasts against the reduced result (reduced axes
            collapsed to length 1).
        """
        # FIX: original used ``if not axes``, which wrongly treated axis 0
        # (falsy) as "no axes" and returned the full size; it also raised a
        # TypeError when ``axes`` was a bare int.
        if axes is None:
            return self.size
        if isinstance(axes, int):
            axes = (axes,)
        if not axes:
            # Empty sequence: preserve the original "whole array" result.
            return self.size

        size = 1
        for ax in axes:
            size *= self.shape[ax]

        # Collapse the reduced axes to length 1 so the counts broadcast
        # against the reduced totals.
        newshape = list(self.shape)
        for ax in axes:
            newshape[ax] = 1

        return np.full(newshape, size)

    def active_mean(self, axis=None, skipna=None):
        """
        Perform ``dask delayed`` active mean for each ``dask block`` which corresponds to a single ``chunk``.
        Combines the results of the dask delayed ``active_mean`` operations on each block into a single dask Array,
        which is then mapped to a new DaskActiveArray object.

        :param axis: (int) The index of the axis on which to perform the active mean.

        :param skipna: (bool) Skip NaN values when calculating the mean.

        :returns: A dict of partial results — ``n`` (element counts) and
            ``total`` (mean weighted back up by ``n``) — so the enclosing
            array can combine sub-chunk results into a final mean.
        """
        n = self._numel(axes=axis)
        total = super().active_mean(axis=axis, skipna=skipna)

        return {
            'n': n,
            'total': total * n
        }
169 changes: 134 additions & 35 deletions CFAPyX/wrappers.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
combine_slices
)

from .subactive import SubDaskActiveArray

import dask.array as da
from dask.array.core import getter
from dask.base import tokenize
Expand Down Expand Up @@ -51,7 +53,8 @@ def cfa_options(self):
'substitutions': self._substitutions,
'decode_cfa': self._decode_cfa,
'chunks': self.chunks,
'chunk_limits':self._chunk_limits
'chunk_limits':self._chunk_limits,
'use_active':self.use_active
}

@cfa_options.setter
Expand All @@ -77,16 +80,20 @@ def _set_cfa_options(
self.chunks = chunks
self.use_active = use_active

def _assemble_array(self, dsk, array_name, dask_chunks):
def _assemble_array(self, dsk, array_name, dask_chunks, subarray=False):

meta = np.empty(self.shape, dtype=self.dtype)
meta = da.empty(shape=self.shape, dtype=self.dtype)
if not hasattr(self, 'use_active'):
darr = da.Array(dsk, array_name, chunks=dask_chunks, dtype=self.dtype, meta=meta)
return darr

if not self.use_active:
darr = da.Array(dsk, array_name, chunks=dask_chunks, dtype=self.dtype, meta=meta)
return darr

if subarray:
darr = SubDaskActiveArray(dsk, array_name, chunks=dask_chunks, dtype=self.dtype, meta=meta)
return darr
try:
from XarrayActive import DaskActiveArray

Expand All @@ -98,7 +105,7 @@ def _assemble_array(self, dsk, array_name, dask_chunks):
)
return darr

class FragmentArrayWrapper(ArrayLike, CFAArrayWrapper, ActiveOptionsContainer):
class FragmentArrayWrapper(ArrayLike, CFAArrayWrapper):
"""
FragmentArrayWrapper behaves like an Array that can be indexed or referenced to
return a Dask-like array object. This class is essentially a constructor for the
Expand Down Expand Up @@ -254,18 +261,28 @@ def __array__(self):
self.shape
)

positions = get_chunk_positions(chunk_space)
partition_space = get_partition_space(
self.fragment_space,
chunk_space
)

#positions = get_chunk_positions(chunk_space)

extents = {}
for p in positions:
#extents = {}
#for p in positions:
# Each chunk fits into the whole fragment array
extents[p] = get_chunk_extent(p, self.shape, chunk_space)
#extents[p] = get_chunk_extent(p, self.shape, chunk_space)

dsk = self._chunk_oversample(fragments, extents, array_name)
dsk, extents = self._chunk_oversample(
fragments,
array_name,
chunk_shape,
chunk_space
)

dask_chunks = get_dask_chunks(
self.shape,
chunk_space,
partition_space,
extent=extents, # List of extents
dtype=self.dtype
)
Expand Down Expand Up @@ -303,7 +320,13 @@ def _chunk_by_fragment(self, fragments, array_name):
)
return dsk

def _chunk_oversample(self, fragments, extents, array_name):
def _chunk_oversample(
self,
fragments,
#extents,
array_name,
cs,
chunk_space):
"""
Assemble the base ``dsk`` task dependency graph which includes the chunk
objects plus the method to index each chunk object (with locking). In this
Expand All @@ -324,22 +347,22 @@ def _chunk_oversample(self, fragments, extents, array_name):
constructing the dask array.
"""

cs = get_chunk_shape(
self.chunks,
self.shape,
self.named_dims,
chunk_limits=self._chunk_limits
)
chunk_space = get_chunk_space(
cs,
self.shape
)

origin = get_chunk_extent(
tuple([0 for i in range(self.ndim)]),
self.shape,
chunk_space
)
#cs = get_chunk_shape(
# self.chunks,
# self.shape,
# self.named_dims,
# chunk_limits=self._chunk_limits
#)
#chunk_space = get_chunk_space(
# cs,
# self.shape
#)

#origin = get_chunk_extent(
# tuple([0 for i in range(self.ndim)]),
# self.shape,
# chunk_space
#)
mfwrapper = {}

for fragment_coord in fragments.keys():
Expand All @@ -348,8 +371,9 @@ def _chunk_oversample(self, fragments, extents, array_name):

initial, final = [],[]
for dim in range(len(fragment_coord)):

conversion = chunk_space[dim]/self.fragment_space[dim]

# Divide specific chunk sizes
conversion = fs[dim]/cs[dim]

initial.append(
int(fragment_coord[dim] * conversion)
Expand Down Expand Up @@ -385,7 +409,35 @@ def _chunk_oversample(self, fragments, extents, array_name):
mfwrapper[c] = [newfragment]

dsk = {}
extents = {}

for chunk in mfwrapper.keys():
for fragment in mfwrapper[chunk]:

extent = fragment.global_extent
pposition = get_partition_coord(
self.fragment_space,
chunk_space,
self.shape,
extent)
if pposition in extents:
pass
extents[pposition] = fragment.get_extent()
fragment.position = pposition

mf_identifier = f"{fragment.__class__.__name__}-{tokenize(fragment)}"
dsk[mf_identifier] = fragment
dsk[array_name + pposition] = (
getter, # From dask docs - replaces fragment_getter
mf_identifier,
fragment.get_extent(),
False,
False
)

return dsk, extents

"""
for chunk in mfwrapper.keys():
fragments = mfwrapper[chunk]
mfwrap = CFAChunkWrapper(
Expand All @@ -410,6 +462,7 @@ def _chunk_oversample(self, fragments, extents, array_name):
False,
False
)
"""
return dsk

def _apply_substitutions(self):
Expand All @@ -424,7 +477,7 @@ def _apply_substitutions(self):
for f in self.fragment_info.keys():
self.fragment_info[f]['location'] = self.fragment_info[f]['location'].replace(base, substitution)

class CFAChunkWrapper(ArrayLike, CFAArrayWrapper, ActiveOptionsContainer):
class CFAChunkWrapper(ArrayLike, CFAArrayWrapper):
description = 'Brand new array class for handling any-size dask chunks.'

"""
Expand Down Expand Up @@ -486,7 +539,7 @@ def __array__(self):
# f_indices is the initial_extent for the ArrayPartition
extents[fragment_position] = fragment.get_extent()

origin = tuple([slice(0,i) for i in self.shape])
origin = tuple([slice(0,i) for i in fragment.shape])

f_identifier = f"{fragment.__class__.__name__}-{tokenize(fragment)}"
dsk[f_identifier] = fragment
Expand All @@ -505,7 +558,7 @@ def __array__(self):
self.dtype
)

return self._assemble_array(dsk, array_name[0], dask_chunks)
return self._assemble_array(dsk, array_name[0], dask_chunks, subarray=True)

def _organise_fragments(self, fragments):

Expand Down Expand Up @@ -655,7 +708,53 @@ def _overlap_in_1d(chunk, chunk_size, fragment, fragment_size):

# From start and end subtract
translation = fragment[0]*fragment_size
start = max(chunk[0]*chunk_size, fragment[0]*fragment_size) - translation
end = min(chunk[1]*chunk_size, fragment[1]*fragment_size) - translation
start = max(chunk[0]*chunk_size, fragment[0]*fragment_size) + translation
end = min(chunk[1]*chunk_size, fragment[1]*fragment_size) + translation

return slice(start, end) # And possibly more
if end < 0 or end < start:
pass

return slice(start, end) # And possibly more

def get_partition_space(fragment_space, chunk_space):
    """
    Compute the partition space: for each dimension, the number of division
    points produced by partitioning the fragment and chunk spaces together.

    :param fragment_space: (sequence) Per-dimension fragment counts.

    :param chunk_space: (sequence) Per-dimension chunk counts.

    :returns: A list with one partition count per dimension.
    """
    return [
        len(partition_1d(fragment_space[dim], chunk_space[dim]))
        for dim in range(len(fragment_space))
    ]

def get_partition_coord(fragment_space, chunk_space, array_space, extent):
    """
    Map an extent (a tuple of slices in array index space) to its coordinate
    in the partition space defined by ``partition_1d``.

    :param fragment_space: (sequence) Per-dimension fragment counts.

    :param chunk_space: (sequence) Per-dimension chunk counts.

    :param array_space: (sequence) Per-dimension array sizes.

    :param extent: (tuple of slices) Global extent whose start positions are
        located within the partition divisions.

    :returns: A tuple with one partition index per dimension. A dimension is
        left as ``None`` when the start position does not land on a division
        point (the original code silently ignored this case too).
    """
    pcoord = [None] * len(fragment_space)
    for dim in range(len(fragment_space)):

        # Sorted division points for this dimension, with 0 as the origin.
        divset = sorted({0} | set(partition_1d(
            fragment_space[dim],
            chunk_space[dim]
        )))

        # Scale the extent start from array index space into division units;
        # divset[-1] is the total span of the partitioning.
        position = int(extent[dim].start * divset[-1] / array_space[dim])

        # FIX: replaced full linear scan plus dead ``== None: pass`` check
        # with a direct membership test; divset is sorted and unique so the
        # first match is the only match.
        if position in divset:
            pcoord[dim] = divset.index(position)
        # NOTE(review): a non-matching position leaves None in the coordinate
        # — probably an error condition that deserves an exception upstream.

    return tuple(pcoord)

def partition_1d(space_1, space_2, lim=None):
    """
    Compute the 1D partition points for two interleaved spacings: the union
    of all multiples of ``space_1`` and ``space_2`` up to (and including)
    their product.

    :param space_1: (int) First spacing.

    :param space_2: (int) Second spacing.

    :param lim: (int) Optional exclusive upper bound; when given, only
        partition points strictly below ``lim`` are returned.

    :returns: A set of partition points.
    """
    total = space_1 * space_2

    # FIX: build the sets directly instead of via intermediate lists, and use
    # ``is None`` rather than ``== None`` (PEP 8 / E711).
    points = (
        set(range(space_1, total + 1, space_1)) |
        set(range(space_2, total + 1, space_2))
    )

    if lim is None:
        return points

    return {p for p in points if p < lim}

0 comments on commit 4f1a045

Please sign in to comment.