Skip to content

Commit 102ba6f

Browse files
authored
Merge pull request #30 from MotionbyLearning/17_sparse_selection
17 sparse selection
2 parents 4dfd08f + de4b146 commit 102ba6f

File tree

3 files changed

+310
-10
lines changed

3 files changed

+310
-10
lines changed

examples/scripts/script_ps_selection.py examples/scripts/script_depsi_processing.py

+41-9
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,8 @@
1717
import sarxarray
1818
import stmtools
1919

20-
from pydepsi.classification import ps_selection
20+
from pydepsi.io import read_metadata
21+
from pydepsi.classification import ps_selection, network_stm_selection
2122

2223
# Make a logger to log the stages of processing
2324
logger = logging.getLogger(__name__)
@@ -37,14 +38,19 @@ def get_free_port():
3738

3839

3940
# ---- Config 1: Human Input ----
40-
41-
42-
# Parameters
43-
method = 'nmad' # Method for selection
44-
threshold = 0.45 # Threshold for selection
45-
4641
# Input data paths
4742
path_slc_zarr = Path("/project/caroline/slc_file.zarr") # Zarr file of all SLCs
43+
path_metadata = Path("/project/caroline/metadata.res") # Metadata file
44+
45+
# Parameters PS selection
46+
ps_selection_method = 'nmad' # Method for PS selection
47+
ps_selection_threshold = 0.45 # Threshold for PS selection
48+
49+
# Parameters network selection
50+
network_stm_quality_metric = 'nmad' # Quality metric for network selection
51+
network_stm_quality_threshold = 0.45 # Quality threshold for network selection
52+
min_dist = 200 # Distance threshold for network selection, in meters
53+
include_index = [57, 101, 189] # Force including the points with index 57, 101, and 189, use None if no point need to be included
4854

4955
# Output config
5056
overwrite_zarr = False # Flag for zarr overwrite
@@ -87,8 +93,10 @@ def get_free_port():
8793
)
8894

8995
if __name__ == "__main__":
96+
# ---- Processing Stage 0: Initialization ----
9097
logger.info("Initializing ...")
9198

99+
# Initiate a Dask client
92100
if cluster is None:
93101
# Use existing cluster
94102
client = Client(ADDRESS)
@@ -98,8 +106,13 @@ def get_free_port():
98106
cluster.scale(jobs=N_WORKERS)
99107
client = Client(cluster)
100108

109+
# Load metadata
110+
metadata = read_metadata(path_metadata)
111+
112+
# ---- Processing Stage 1: Pixel Classification ----
101113
# Load the SLC data
102-
logger.info("Loading data ...")
114+
logger.info("Processing Stage 1: Pixel Classification")
115+
logger.info("Loading SLC data ...")
103116
ds = xr.open_zarr(path_slc_zarr) # Load the zarr file as a xr.Dataset
104117
# Construct SLCs from xr.Dataset
105118
# this construct three datavariables: complex, amplitude, and phase
@@ -110,16 +123,35 @@ def get_free_port():
110123
# slcs = slcs.chunk({"azimuth":1000, "range":1000, "time":-1})
111124

112125
# Select PS
113-
stm_ps = ps_selection(method, threshold, method='nmad', output_chunks=chunk_space)
126+
logger.info("PS Selection ...")
127+
stm_ps = ps_selection(method, threshold, method=ps_selection_method, output_chunks=chunk_space)
114128

115129
# Re-order the PS to make the spatially adjacent PS in the same chunk
130+
logger.info("Reorder selected scatterers ...")
116131
stm_ps_reordered = stm_ps.stm.reorder(xlabel='lon', ylabel='lat')
117132

118133
# Save the PS to zarr
134+
logger.info("Writting selected pixels to Zarr ...")
119135
if overwrite_zarr:
120136
stm_ps_reordered.to_zarr(path_ps_zarr, mode="w")
121137
else:
122138
stm_ps_reordered.to_zarr(path_ps_zarr)
123139

140+
# ---- Processing Stage 2: Network Processing ----
141+
# Uncomment the following line to load the PS data from zarr
142+
# stm_ps_reordered = xr.open_zarr(path_ps_zarr)
143+
144+
# Select network points
145+
logger.info("Select network scatterers ...")
146+
# Apply a pre-filter
147+
stm_network_candidates = xr.where(stm_ps_reordered[network_stm_quality_metric]<network_stm_quality_threshold)
148+
# Select based on sparsity and quality
149+
stm_network = network_stm_selection(stm_network_candidates,
150+
min_dist,
151+
include_index=include_index,
152+
sortby_var=network_stm_quality_metric,
153+
azimuth_spacing=metadata['azimuth_spacing'],
154+
range_spacing=metadata['range_spacing'])
155+
124156
# Close the client when finishing
125157
client.close()

pydepsi/classification.py

+148
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44

55
import numpy as np
66
import xarray as xr
7+
from scipy.spatial import KDTree
78

89

910
def ps_selection(
@@ -126,6 +127,126 @@ def ps_selection(
126127
return stm_masked
127128

128129

130+
def network_stm_selection(
131+
stm: xr.Dataset,
132+
min_dist: int | float,
133+
include_index: list[int] = None,
134+
sortby_var: str = "pnt_nmad",
135+
crs: int | str = "radar",
136+
x_var: str = "azimuth",
137+
y_var: str = "range",
138+
azimuth_spacing: float = None,
139+
range_spacing: float = None,
140+
):
141+
"""Select a Space-Time Matrix (STM) from a candidate STM for network processing.
142+
143+
The selection is based on two criteria:
144+
1. A minimum distance between selected points.
145+
2. A sorting metric to select better points.
146+
147+
The candidate STM will be sorted by the sorting metric.
148+
The selection will be performed iteratively, starting from the best point.
149+
In each iteration, the best point will be selected, and points within the minimum distance will be removed.
150+
The process will continue until no points are left in the candidate STM.
151+
152+
Parameters
153+
----------
154+
stm : xr.Dataset
155+
candidate Space-Time Matrix (STM).
156+
min_dist : int | float
157+
Minimum distance between selected points.
158+
include_index : list[int], optional
159+
Index of points in the candidate STM that must be included in the selection, by default None
160+
sortby_var : str, optional
161+
Sorting metric for selecting points, by default "pnt_nmad"
162+
crs : int | str, optional
163+
EPSG code of Coordinate Reference System of `x_var` and `y_var`, by default "radar".
164+
If crs is "radar", the distance will be calculated based on radar coordinates, and
165+
azimuth_spacing and range_spacing must be provided.
166+
x_var : str, optional
167+
Data variable name for x coordinate, by default "azimuth"
168+
y_var : str, optional
169+
Data variable name for y coordinate, by default "range"
170+
azimuth_spacing : float, optional
171+
Azimuth spacing, by default None. Required if crs is "radar".
172+
range_spacing : float, optional
173+
Range spacing, by default None. Required if crs is "radar".
174+
175+
Returns
176+
-------
177+
xr.Dataset
178+
Selected network Space-Time Matrix (STM).
179+
180+
Raises
181+
------
182+
ValueError
183+
Raised when `azimuth_spacing` or `range_spacing` is not provided for radar coordinates.
184+
NotImplementedError
185+
Raised when an unsupported Coordinate Reference System is provided.
186+
"""
187+
match crs:
188+
case "radar":
189+
if (azimuth_spacing is None) or (range_spacing is None):
190+
raise ValueError("Azimuth and range spacing must be provided for radar coordinates.")
191+
case _:
192+
raise NotImplementedError
193+
194+
# Get coordinates and sorting metric, load them into memory
195+
stm_select = None
196+
stm_remain = stm[[x_var, y_var, sortby_var]].compute()
197+
198+
# Select the include_index if provided
199+
if include_index is not None:
200+
stm_select = stm_remain.isel(space=include_index)
201+
202+
# Remove points within min_dist of the included points
203+
coords_include = np.column_stack(
204+
[stm_select["azimuth"].values * azimuth_spacing, stm_select["range"].values * range_spacing]
205+
)
206+
coords_remain = np.column_stack(
207+
[stm_remain["azimuth"].values * azimuth_spacing, stm_remain["range"].values * range_spacing]
208+
)
209+
idx_drop = _idx_within_distance(coords_include, coords_remain, min_dist)
210+
if idx_drop is not None:
211+
stm_remain = stm_remain.where(~(stm_remain["space"].isin(idx_drop)), drop=True)
212+
213+
# Reorder the remaining points by the sorting metric
214+
stm_remain = stm_remain.sortby(sortby_var)
215+
216+
# Build a list of the index of selected points
217+
if stm_select is None:
218+
space_idx_sel = []
219+
else:
220+
space_idx_sel = stm_select["space"].values.tolist()
221+
222+
while stm_remain.sizes["space"] > 0:
223+
# Select one point with best sorting metric
224+
stm_now = stm_remain.isel(space=0)
225+
226+
# Append the selected point index
227+
space_idx_sel.append(stm_now["space"].values.tolist())
228+
229+
# Remove the selected point from the remaining points
230+
stm_remain = stm_remain.isel(space=slice(1, None)).copy()
231+
232+
# Remove points in stm_remain within min_dist of stm_now
233+
coords_remain = np.column_stack(
234+
[stm_remain["azimuth"].values * azimuth_spacing, stm_remain["range"].values * range_spacing]
235+
)
236+
coords_stmnow = np.column_stack(
237+
[stm_now["azimuth"].values * azimuth_spacing, stm_now["range"].values * range_spacing]
238+
)
239+
idx_drop = _idx_within_distance(coords_stmnow, coords_remain, min_dist)
240+
if idx_drop is not None:
241+
stm_drop = stm_remain.isel(space=idx_drop)
242+
stm_remain = stm_remain.where(~(stm_remain["space"].isin(stm_drop["space"])), drop=True)
243+
244+
# Get the selected points by space index from the original stm
245+
stm_out = stm.sel(space=space_idx_sel)
246+
247+
return stm_out
248+
249+
129250
def _nad_block(amp: xr.DataArray) -> xr.DataArray:
130251
"""Compute Normalized Amplitude Dispersion (NAD) for a block of amplitude data.
131252
@@ -170,3 +291,30 @@ def _nmad_block(amp: xr.DataArray) -> xr.DataArray:
170291
nmad = mad / (median_amplitude + np.finfo(amp.dtype).eps) # Normalized Median Absolute Deviation
171292

172293
return nmad
294+
295+
296+
def _idx_within_distance(coords_ref, coords_others, min_dist):
297+
"""Get the index of points in coords_others that are within min_dist of coords_ref.
298+
299+
Parameters
300+
----------
301+
coords_ref : np.ndarray
302+
Coordinates of reference points. Shape (n, 2).
303+
coords_others : np.ndarray
304+
Coordinates of other points. Shape (m, 2).
305+
min_dist : int, float
306+
distance threshold.
307+
308+
Returns
309+
-------
310+
np.ndarray
311+
Index of points in coords_others that are within `min_dist` of `coords_ref`.
312+
"""
313+
kd_ref = KDTree(coords_ref)
314+
kd_others = KDTree(coords_others)
315+
sdm = kd_ref.sparse_distance_matrix(kd_others, min_dist)
316+
if len(sdm) > 0:
317+
idx = np.array(list(sdm.keys()))[:, 1]
318+
return idx
319+
else:
320+
return None

0 commit comments

Comments
 (0)