Skip to content

Commit

Permalink
Add apply_sam_spectral function
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Oct 13, 2024
1 parent ce49c9c commit db562fe
Showing 1 changed file with 298 additions and 9 deletions.
307 changes: 298 additions & 9 deletions hypercoast/pace.py
Original file line number Diff line number Diff line change
Expand Up @@ -976,12 +976,13 @@ def apply_sam(
n_components: int = 3,
n_clusters: int = 6,
random_state: int = 0,
spectral_library: Union[str, list[str]] = None,
filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None,
plot: bool = True,
figsize: tuple[int, int] = (8, 6),
extent: list[float] = None,
colors: list[str] = None,
title: str = "Spectral Angle Mapper",
title: str = None,
**kwargs,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Expand All @@ -992,6 +993,7 @@ def apply_sam(
n_components (int, optional): Number of principal components to compute. Defaults to 3.
n_clusters (int, optional): Number of clusters for K-means. Defaults to 6.
random_state (int, optional): Random state for K-means. Defaults to 0.
spectral_library (Union[str, list[str]], optional): Path to the spectral library or a list of paths. Defaults to None.
filter_condition (Callable[[xr.DataArray], xr.DataArray], optional): A function to filter the data. Defaults to None.
plot (bool, optional): Whether to plot the data. Defaults to True.
figsize (Tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
Expand All @@ -1003,11 +1005,14 @@ def apply_sam(
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The best match classification, latitudes, and longitudes.
"""
import glob
import pandas as pd
from sklearn.cluster import KMeans
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from sklearn.decomposition import PCA
from scipy.interpolate import interp1d

if isinstance(dataset, str):
dataset = read_pace(dataset)
Expand All @@ -1017,23 +1022,61 @@ def apply_sam(
raise ValueError("dataset must be an xarray Dataset")

da = dataset["Rrs"]
pace_wavelengths = da["wavelength"].values

# Reshape data to (n_pixels, n_bands)
reshaped_data = da.values.reshape(-1, da.shape[-1])

# Handle NaNs by removing them
reshaped_data_no_nan = reshaped_data[~np.isnan(reshaped_data).any(axis=1)]

# Apply PCA to reduce dimensionality
pca = PCA(n_components=n_components)
pca_data = pca.fit_transform(reshaped_data_no_nan)
if isinstance(spectral_library, str):
endmember_paths = sorted(glob.glob(spectral_library))
elif isinstance(spectral_library, list):
endmember_paths = spectral_library
else:
endmember_paths = None

# Function to load and resample a single CSV spectral library file
def load_and_resample_spectral_library(csv_path, target_wavelengths):
df = pd.read_csv(csv_path)
original_wavelengths = df.iloc[:, 0].values # First column is wavelength
spectra_values = df.iloc[:, 1].values # Second column is spectral values

# Interpolation function
interp_func = interp1d(
original_wavelengths,
spectra_values,
kind="linear",
fill_value="extrapolate",
)

# Resample to the target (PACE) wavelengths
resampled_spectra = interp_func(target_wavelengths)

return resampled_spectra

if endmember_paths is not None:
endmembers = np.array(
[
load_and_resample_spectral_library(path, pace_wavelengths)
for path in endmember_paths
]
)
else:
endmembers = None

if endmembers is None:
# Apply PCA to reduce dimensionality
pca = PCA(n_components=n_components)
pca_data = pca.fit_transform(reshaped_data_no_nan)

# Apply K-means to find clusters representing endmembers
kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
kmeans.fit(pca_data)
# Apply K-means to find clusters representing endmembers
kmeans = KMeans(n_clusters=n_clusters, random_state=random_state)
kmeans.fit(pca_data)

# The cluster centers in the original spectral space are your endmembers
endmembers = pca.inverse_transform(kmeans.cluster_centers_)
# The cluster centers in the original spectral space are your endmembers
endmembers = pca.inverse_transform(kmeans.cluster_centers_)

def spectral_angle_mapper(pixel, reference):
norm_pixel = np.linalg.norm(pixel)
Expand Down Expand Up @@ -1066,12 +1109,50 @@ def spectral_angle_mapper(pixel, reference):
latitudes = da.coords["latitude"].values
longitudes = da.coords["longitude"].values

# Plot sample spectra from the CSV files and their resampled versions
def plot_sample_spectra(csv_paths, pace_wavelengths):
plt.figure(figsize=figsize)

for i, csv_path in enumerate(csv_paths):
df = pd.read_csv(csv_path)
original_wavelengths = df.iloc[:, 0].values
spectra_values = df.iloc[:, 1].values
resampled_spectra = load_and_resample_spectral_library(
csv_path, pace_wavelengths
)

plt.plot(
original_wavelengths,
spectra_values,
label=f"Original Spectra {i+1}",
linestyle="--",
)
plt.plot(
pace_wavelengths, resampled_spectra, label=f"Resampled Spectra {i+1}"
)

plt.xlabel("Wavelength (nm)")
plt.ylabel("Spectral Reflectance")
plt.title("Comparison of Original and Resampled Spectra")
plt.legend()
plt.grid(True)
plt.show()

if plot:

if endmember_paths is not None:

plot_sample_spectra(endmember_paths, pace_wavelengths)

if colors is None:
colors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]

if title is None:
title = "Spectral Angle Mapper Water Type Classification"
# Create a custom discrete color map
cmap = mcolors.ListedColormap(colors)
if spectral_library is not None:
n_clusters = len(endmember_paths)
bounds = np.arange(-0.5, n_clusters, 1)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

Expand Down Expand Up @@ -1121,3 +1202,211 @@ def spectral_angle_mapper(pixel, reference):
plt.show()

return best_match_full, latitudes, longitudes


def apply_sam_spectral(
dataset: Union[xr.Dataset, str],
spectral_library: Union[str, list[str]] = None,
filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None,
plot: bool = True,
figsize: tuple[int, int] = (8, 6),
extent: list[float] = None,
colors: list[str] = None,
title: str = None,
**kwargs,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
"""
Applies Spectral Angle Mapper (SAM) to the dataset and optionally plots the results.
Args:
dataset (Union[xr.Dataset, str]): The dataset containing the PACE data or the file path to the dataset.
spectral_library (Union[str, list[str]]): The spectral library file path or list of file paths.
filter_condition (Callable[[xr.DataArray], xr.DataArray], optional): A function to filter the data. Defaults to None.
plot (bool, optional): Whether to plot the data. Defaults to True.
figsize (Tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
extent (List[float], optional): The extent to zoom in to the specified region. Defaults to None.
colors (List[str], optional): Colors for the clusters. Defaults to None.
title (str, optional): Title for the plot. Defaults to "Spectral Angle Mapper Water Type Classification".
**kwargs: Additional keyword arguments to pass to the `plt.subplots` function.
Returns:
Tuple[np.ndarray, np.ndarray, np.ndarray]: The best match classification, latitudes, and longitudes.
"""
import glob
import pandas as pd
import matplotlib.colors as mcolors
import cartopy.crs as ccrs
import cartopy.feature as cfeature
from scipy.interpolate import interp1d

if isinstance(dataset, str):
dataset = read_pace(dataset)
elif isinstance(dataset, xr.DataArray):
dataset = dataset.to_dataset()
elif not isinstance(dataset, xr.Dataset):
raise ValueError("dataset must be an xarray Dataset")

da = dataset["Rrs"]
pace_wavelengths = da["wavelength"].values

if isinstance(spectral_library, str):
endmember_paths = sorted(glob.glob(spectral_library))
elif isinstance(spectral_library, list):
endmember_paths = spectral_library
else:
endmember_paths = None

# Function to load and resample a single CSV spectral library file
def load_and_resample_spectral_library(csv_path, target_wavelengths):
df = pd.read_csv(csv_path)
original_wavelengths = df.iloc[:, 0].values # First column is wavelength
spectra_values = df.iloc[:, 1].values # Second column is spectral values

# Interpolation function
interp_func = interp1d(
original_wavelengths,
spectra_values,
kind="linear",
fill_value="extrapolate",
)

# Resample to the target (PACE) wavelengths
resampled_spectra = interp_func(target_wavelengths)

return resampled_spectra

if endmember_paths is not None:
endmembers = np.array(
[
load_and_resample_spectral_library(path, pace_wavelengths)
for path in endmember_paths
]
)
else:
endmembers = None

# Function to calculate spectral angle
def spectral_angle_mapper(pixel, reference):
norm_pixel = np.linalg.norm(pixel)
norm_reference = np.linalg.norm(reference)
cos_theta = np.dot(pixel, reference) / (norm_pixel * norm_reference)
angle = np.arccos(np.clip(cos_theta, -1, 1))
return angle

# Reshape data to (n_pixels, n_bands)
reshaped_data = da.values.reshape(-1, da.shape[-1])

# Apply SAM for each pixel and each endmember
angles = np.zeros((reshaped_data.shape[0], endmembers.shape[0]))

for i in range(reshaped_data.shape[0]):
for j in range(endmembers.shape[0]):
angles[i, j] = spectral_angle_mapper(reshaped_data[i, :], endmembers[j, :])

# Find the minimum angle (best match) for each pixel
best_match = np.argmin(angles, axis=1)

# Reshape best_match back to the original spatial dimensions
best_match = best_match.reshape(da.shape[:-1])

if filter_condition is not None:
best_match = np.where(filter_condition, best_match, np.nan)

latitudes = da.coords["latitude"].values
longitudes = da.coords["longitude"].values

# Plot sample spectra from the CSV files and their resampled versions
def plot_sample_spectra(csv_paths, pace_wavelengths):
plt.figure(figsize=figsize)

for i, csv_path in enumerate(csv_paths):
df = pd.read_csv(csv_path)
original_wavelengths = df.iloc[:, 0].values
spectra_values = df.iloc[:, 1].values
resampled_spectra = load_and_resample_spectral_library(
csv_path, pace_wavelengths
)

plt.plot(
original_wavelengths,
spectra_values,
label=f"Original Spectra {i+1}",
linestyle="--",
)
plt.plot(
pace_wavelengths, resampled_spectra, label=f"Resampled Spectra {i+1}"
)

plt.xlabel("Wavelength (nm)")
plt.ylabel("Spectral Reflectance")
plt.title("Comparison of Original and Resampled Spectra")
plt.legend()
plt.grid(True)
plt.show()

if plot:

if endmember_paths is not None:

plot_sample_spectra(endmember_paths, pace_wavelengths)

if colors is None:
colors = ["#377eb8", "#e41a1c", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]

if title is None:
title = "Spectral Angle Mapper Water Type Classification"
# Create a custom discrete color map
cmap = mcolors.ListedColormap(colors)
bounds = np.arange(-0.5, len(endmembers), 1)
norm = mcolors.BoundaryNorm(bounds, cmap.N)

# Create a figure and axis with the correct map projection
_, ax = plt.subplots(
figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, **kwargs
)

# Plot the SAM classification results
im = ax.pcolormesh(
longitudes,
latitudes,
best_match,
cmap=cmap,
norm=norm,
transform=ccrs.PlateCarree(),
)

# Add geographic features for context
ax.add_feature(cfeature.COASTLINE)
ax.add_feature(cfeature.BORDERS, linestyle=":")
ax.add_feature(cfeature.STATES, linestyle="--")

# Adding axis labels
ax.set_xlabel("Longitude")
ax.set_ylabel("Latitude")

# Adding a title
ax.set_title(title, fontsize=14)

# Adding a color bar with discrete values
cbar = plt.colorbar(
im,
ax=ax,
orientation="vertical",
# pad=0.02,
fraction=0.05,
ticks=np.arange(len(endmembers)),
)
cbar.ax.set_yticklabels([f"Class {i+1}" for i in range(len(endmembers))])
cbar.set_label("Water Types", rotation=270, labelpad=20)

# Adding gridlines
ax.gridlines(draw_labels=True, linestyle="--", linewidth=0.5)

# Set the extent to zoom in to the specified region (adjust as needed)
if extent is not None:
ax.set_extent(extent, crs=ccrs.PlateCarree())

# Show the plot
plt.show()

return best_match, latitudes, longitudes

0 comments on commit db562fe

Please sign in to comment.