diff --git a/hypercoast/pace.py b/hypercoast/pace.py index 28b530bb..74e5c7bc 100644 --- a/hypercoast/pace.py +++ b/hypercoast/pace.py @@ -976,12 +976,13 @@ def apply_sam( n_components: int = 3, n_clusters: int = 6, random_state: int = 0, + spectral_library: Union[str, list[str]] = None, filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None, plot: bool = True, figsize: tuple[int, int] = (8, 6), extent: list[float] = None, colors: list[str] = None, - title: str = "Spectral Angle Mapper", + title: str = None, **kwargs, ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: """ @@ -992,6 +993,7 @@ def apply_sam( n_components (int, optional): Number of principal components to compute. Defaults to 3. n_clusters (int, optional): Number of clusters for K-means. Defaults to 6. random_state (int, optional): Random state for K-means. Defaults to 0. + spectral_library (Union[str, list[str]], optional): Path to the spectral library or a list of paths. Defaults to None. filter_condition (Callable[[xr.DataArray], xr.DataArray], optional): A function to filter the data. Defaults to None. plot (bool, optional): Whether to plot the data. Defaults to True. figsize (Tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6). @@ -1003,11 +1005,14 @@ def apply_sam( Returns: Tuple[np.ndarray, np.ndarray, np.ndarray]: The best match classification, latitudes, and longitudes. 
""" + import glob + import pandas as pd from sklearn.cluster import KMeans import matplotlib.colors as mcolors import cartopy.crs as ccrs import cartopy.feature as cfeature from sklearn.decomposition import PCA + from scipy.interpolate import interp1d if isinstance(dataset, str): dataset = read_pace(dataset) @@ -1017,6 +1022,7 @@ def apply_sam( raise ValueError("dataset must be an xarray Dataset") da = dataset["Rrs"] + pace_wavelengths = da["wavelength"].values # Reshape data to (n_pixels, n_bands) reshaped_data = da.values.reshape(-1, da.shape[-1]) @@ -1024,16 +1030,53 @@ def apply_sam( # Handle NaNs by removing them reshaped_data_no_nan = reshaped_data[~np.isnan(reshaped_data).any(axis=1)] - # Apply PCA to reduce dimensionality - pca = PCA(n_components=n_components) - pca_data = pca.fit_transform(reshaped_data_no_nan) + if isinstance(spectral_library, str): + endmember_paths = sorted(glob.glob(spectral_library)) + elif isinstance(spectral_library, list): + endmember_paths = spectral_library + else: + endmember_paths = None + + # Function to load and resample a single CSV spectral library file + def load_and_resample_spectral_library(csv_path, target_wavelengths): + df = pd.read_csv(csv_path) + original_wavelengths = df.iloc[:, 0].values # First column is wavelength + spectra_values = df.iloc[:, 1].values # Second column is spectral values + + # Interpolation function + interp_func = interp1d( + original_wavelengths, + spectra_values, + kind="linear", + fill_value="extrapolate", + ) + + # Resample to the target (PACE) wavelengths + resampled_spectra = interp_func(target_wavelengths) + + return resampled_spectra + + if endmember_paths is not None: + endmembers = np.array( + [ + load_and_resample_spectral_library(path, pace_wavelengths) + for path in endmember_paths + ] + ) + else: + endmembers = None + + if endmembers is None: + # Apply PCA to reduce dimensionality + pca = PCA(n_components=n_components) + pca_data = pca.fit_transform(reshaped_data_no_nan) - # Apply 
K-means to find clusters representing endmembers - kmeans = KMeans(n_clusters=n_clusters, random_state=random_state) - kmeans.fit(pca_data) + # Apply K-means to find clusters representing endmembers + kmeans = KMeans(n_clusters=n_clusters, random_state=random_state) + kmeans.fit(pca_data) - # The cluster centers in the original spectral space are your endmembers - endmembers = pca.inverse_transform(kmeans.cluster_centers_) + # The cluster centers in the original spectral space are your endmembers + endmembers = pca.inverse_transform(kmeans.cluster_centers_) def spectral_angle_mapper(pixel, reference): norm_pixel = np.linalg.norm(pixel) @@ -1066,12 +1109,50 @@ def spectral_angle_mapper(pixel, reference): latitudes = da.coords["latitude"].values longitudes = da.coords["longitude"].values + # Plot sample spectra from the CSV files and their resampled versions + def plot_sample_spectra(csv_paths, pace_wavelengths): + plt.figure(figsize=figsize) + + for i, csv_path in enumerate(csv_paths): + df = pd.read_csv(csv_path) + original_wavelengths = df.iloc[:, 0].values + spectra_values = df.iloc[:, 1].values + resampled_spectra = load_and_resample_spectral_library( + csv_path, pace_wavelengths + ) + + plt.plot( + original_wavelengths, + spectra_values, + label=f"Original Spectra {i+1}", + linestyle="--", + ) + plt.plot( + pace_wavelengths, resampled_spectra, label=f"Resampled Spectra {i+1}" + ) + + plt.xlabel("Wavelength (nm)") + plt.ylabel("Spectral Reflectance") + plt.title("Comparison of Original and Resampled Spectra") + plt.legend() + plt.grid(True) + plt.show() + if plot: + if endmember_paths is not None: + + plot_sample_spectra(endmember_paths, pace_wavelengths) + if colors is None: colors = ["#377eb8", "#ff7f00", "#4daf4a", "#f781bf", "#a65628", "#984ea3"] + + if title is None: + title = "Spectral Angle Mapper Water Type Classification" # Create a custom discrete color map cmap = mcolors.ListedColormap(colors) + if spectral_library is not None: + n_clusters = 
def apply_sam_spectral(
    dataset: Union[xr.Dataset, str],
    spectral_library: Union[str, list[str]] = None,
    filter_condition: Optional[Callable[[xr.DataArray], xr.DataArray]] = None,
    plot: bool = True,
    figsize: tuple[int, int] = (8, 6),
    extent: list[float] = None,
    colors: list[str] = None,
    title: str = None,
    **kwargs,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
    """
    Applies Spectral Angle Mapper (SAM) to the dataset using endmembers read
    from a spectral library, and optionally plots the results.

    Each library file is read as a CSV whose first column holds wavelengths and
    whose second column holds the spectral values. Every spectrum is linearly
    interpolated (and extrapolated where needed) onto the wavelengths of the
    dataset's "Rrs" variable, and every pixel is assigned the index of the
    endmember with the smallest spectral angle.

    Args:
        dataset (Union[xr.Dataset, str]): The dataset containing the PACE data or the file path to the dataset.
        spectral_library (Union[str, list[str]]): A file path or glob pattern, or a list of file paths,
            pointing to the spectral library CSV files. Required.
        filter_condition (Callable[[xr.DataArray], xr.DataArray], optional): A function that receives the
            "Rrs" DataArray and returns a boolean mask; pixels where the mask is False are set to NaN.
            A precomputed boolean mask is also accepted. The mask must broadcast to the spatial shape
            of the result. Defaults to None.
        plot (bool, optional): Whether to plot the data. Defaults to True.
        figsize (Tuple[int, int], optional): Figure size for the plot. Defaults to (8, 6).
        extent (List[float], optional): The extent to zoom in to the specified region. Defaults to None.
        colors (List[str], optional): Colors for the classes, one per endmember in library order. The list
            is cycled if it is shorter than the number of endmembers. Defaults to None (built-in palette).
        title (str, optional): Title for the plot. Defaults to None, which uses
            "Spectral Angle Mapper Water Type Classification".
        **kwargs: Additional keyword arguments to pass to the `plt.subplots` function.

    Returns:
        Tuple[np.ndarray, np.ndarray, np.ndarray]: The best match classification, latitudes, and longitudes.
            The classification is a float array with the spatial shape of "Rrs"; pixels that cannot be
            classified (NaN spectra, zero-norm spectra, or filtered out) are NaN.

    Raises:
        ValueError: If `dataset` is not a supported type, if `spectral_library` is missing or of an
            unsupported type, if no library files are found, or if the filter mask does not match
            the spatial shape of the data.
    """
    import glob
    import pandas as pd
    from scipy.interpolate import interp1d

    # Resolve the library before touching the dataset: it is cheap, and a
    # missing library would otherwise only surface after reading the granule.
    if isinstance(spectral_library, str):
        endmember_paths = sorted(glob.glob(spectral_library))
    elif isinstance(spectral_library, (list, tuple)):
        endmember_paths = list(spectral_library)
    else:
        raise ValueError(
            "spectral_library must be a file path/glob pattern or a list of file paths"
        )
    if not endmember_paths:
        raise ValueError(f"No spectral library files found for: {spectral_library!r}")

    if isinstance(dataset, str):
        dataset = read_pace(dataset)
    elif isinstance(dataset, xr.DataArray):
        dataset = dataset.to_dataset()
    elif not isinstance(dataset, xr.Dataset):
        raise ValueError("dataset must be an xarray Dataset")

    da = dataset["Rrs"]
    pace_wavelengths = da["wavelength"].values

    # Read every library file once; the raw spectra are reused for plotting.
    library = []
    for csv_path in endmember_paths:
        df = pd.read_csv(csv_path)
        # First column is wavelength, second column is the spectral value.
        library.append((df.iloc[:, 0].values, df.iloc[:, 1].values))

    def resample(wavelengths, values):
        """Linearly interpolate one library spectrum onto the PACE wavelengths."""
        interp_func = interp1d(
            wavelengths,
            values,
            kind="linear",
            fill_value="extrapolate",
        )
        return interp_func(pace_wavelengths)

    endmembers = np.array([resample(w, v) for w, v in library], dtype=float)
    n_classes = endmembers.shape[0]

    # Reshape data to (n_pixels, n_bands)
    pixels = da.values.reshape(-1, da.shape[-1]).astype(float, copy=False)

    # Spectral angle between every pixel and every endmember, computed with a
    # single matrix product instead of a Python loop over pixels.
    with np.errstate(divide="ignore", invalid="ignore"):
        norms = np.linalg.norm(pixels, axis=1, keepdims=True) * np.linalg.norm(
            endmembers, axis=1
        )
        angles = np.arccos(np.clip(pixels @ endmembers.T / norms, -1, 1))

    # np.argmin would report index 0 for rows containing NaN, silently labelling
    # invalid pixels as the first class. Leave such pixels as NaN instead.
    classified = ~np.isnan(angles).all(axis=1)
    labels = np.full(pixels.shape[0], np.nan)
    if classified.any():
        labels[classified] = np.nanargmin(angles[classified], axis=1)

    # Reshape best_match back to the original spatial dimensions
    best_match = labels.reshape(da.shape[:-1])

    if filter_condition is not None:
        # The condition must be evaluated: passing the function object itself
        # to np.where is always truthy and filters nothing.
        mask = filter_condition(da) if callable(filter_condition) else filter_condition
        mask = np.asarray(mask)
        try:
            mask = np.broadcast_to(mask, best_match.shape)
        except ValueError as err:
            raise ValueError(
                f"filter_condition produced a mask of shape {mask.shape}, which does "
                f"not match the spatial shape {best_match.shape}"
            ) from err
        best_match = np.where(mask, best_match, np.nan)

    latitudes = da.coords["latitude"].values
    longitudes = da.coords["longitude"].values

    if plot:
        # Plotting dependencies are only needed (and imported) when plotting.
        import matplotlib.colors as mcolors
        import cartopy.crs as ccrs
        import cartopy.feature as cfeature

        # Plot sample spectra from the CSV files and their resampled versions
        plt.figure(figsize=figsize)
        for i, ((wavelengths, values), resampled) in enumerate(zip(library, endmembers)):
            plt.plot(
                wavelengths,
                values,
                label=f"Original Spectra {i+1}",
                linestyle="--",
            )
            plt.plot(pace_wavelengths, resampled, label=f"Resampled Spectra {i+1}")

        plt.xlabel("Wavelength (nm)")
        plt.ylabel("Spectral Reflectance")
        plt.title("Comparison of Original and Resampled Spectra")
        plt.legend()
        plt.grid(True)
        plt.show()

        if not colors:
            colors = ["#377eb8", "#e41a1c", "#4daf4a", "#f781bf", "#a65628", "#984ea3"]

        # Use exactly one color per class: BoundaryNorm raises when there are
        # fewer colors than bins, and spreads bins across the whole colormap
        # when there are more, so class i would not be drawn with colors[i].
        class_colors = [colors[i % len(colors)] for i in range(n_classes)]

        if title is None:
            title = "Spectral Angle Mapper Water Type Classification"

        # Create a custom discrete color map
        cmap = mcolors.ListedColormap(class_colors)
        bounds = np.arange(-0.5, n_classes, 1)
        norm = mcolors.BoundaryNorm(bounds, cmap.N)

        # Create a figure and axis with the correct map projection
        _, ax = plt.subplots(
            figsize=figsize, subplot_kw={"projection": ccrs.PlateCarree()}, **kwargs
        )

        # Plot the SAM classification results
        im = ax.pcolormesh(
            longitudes,
            latitudes,
            best_match,
            cmap=cmap,
            norm=norm,
            transform=ccrs.PlateCarree(),
        )

        # Add geographic features for context
        ax.add_feature(cfeature.COASTLINE)
        ax.add_feature(cfeature.BORDERS, linestyle=":")
        ax.add_feature(cfeature.STATES, linestyle="--")

        # Adding axis labels
        ax.set_xlabel("Longitude")
        ax.set_ylabel("Latitude")

        # Adding a title
        ax.set_title(title, fontsize=14)

        # Adding a color bar with discrete values
        cbar = plt.colorbar(
            im,
            ax=ax,
            orientation="vertical",
            fraction=0.05,
            ticks=np.arange(n_classes),
        )
        cbar.ax.set_yticklabels([f"Class {i+1}" for i in range(n_classes)])
        cbar.set_label("Water Types", rotation=270, labelpad=20)

        # Adding gridlines
        ax.gridlines(draw_labels=True, linestyle="--", linewidth=0.5)

        # Set the extent to zoom in to the specified region (adjust as needed)
        if extent is not None:
            ax.set_extent(extent, crs=ccrs.PlateCarree())

        # Show the plot
        plt.show()

    return best_match, latitudes, longitudes