Skip to content

Commit

Permalink
Add chla_predict notebook
Browse files Browse the repository at this point in the history
  • Loading branch information
giswqs committed Dec 8, 2024
1 parent 7dd2b13 commit 4f276d3
Show file tree
Hide file tree
Showing 2 changed files with 238 additions and 0 deletions.
236 changes: 236 additions & 0 deletions docs/examples/chla_predict.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,236 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"[![image](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/opengeos/HyperCoast/blob/main/docs/examples/chla_predict.ipynb)\n",
"\n",
"# Chlorophyll-a prediction with Deep Learning\n",
"\n",
"## Install packages\n",
"\n",
"Uncomment the following cell to install the required packages."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# %pip install \"hypercoast[all]\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Import libraries"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import torch\n",
"import hypercoast\n",
"from hypercoast.chla import *"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Download sample data"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chla_data_url = \"https://github.com/opengeos/datasets/releases/download/hypercoast/chla_test_data.zip\"\n",
"pace_data_url = \"https://github.com/opengeos/datasets/releases/download/hypercoast/PACE_OCI.20241024T182127.L2.OC_AOP.V2_0.NRT.nc\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"hypercoast.download_file(chla_data_url, quiet=True)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"pace_filepath = hypercoast.download_file(pace_data_url, quiet=True)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Train the model"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"device = torch.device(\"cuda\" if torch.cuda.is_available() else \"cpu\")\n",
"# Load the training dataset\n",
"train_real_dl, test_real_dl, input_dim, output_dim = load_real_data(\n",
" \"data/Chl_RC_PACE.csv\", \"data/Rrs_RC_PACE.csv\"\n",
")\n",
"# Load the validation dataset\n",
"test_real_Sep, _, _ = load_real_test(\n",
" \"data/Chl_RC_PACE_Sep.csv\", \"data/Rrs_RC_PACE_Sep.csv\"\n",
")\n",
"# Model output path.\n",
"save_dir = \"model/VAE_Chla_PACE\"\n",
"os.makedirs(save_dir, exist_ok=True)\n",
"\n",
"# Create the VAE model and optimizer\n",
"model = VAE(input_dim, output_dim).to(device)\n",
"opt = torch.optim.Adam(model.parameters(), lr=0.001, weight_decay=1e-3)\n",
"\n",
"best_model_path = \"model/vae_trans_model_best_Chl_PACE.pth\"\n",
"train(model, train_real_dl, epochs=400, best_model_path=best_model_path)\n",
"# Load the optimal model\n",
"model.load_state_dict(torch.load(best_model_path, map_location=device))\n",
"\n",
"predictions, actuals = evaluate(model, test_real_dl)\n",
"\n",
"predictions_Sep, actuals_Sep = evaluate(model, test_real_Sep)\n",
"\n",
"save_to_csv(predictions, os.path.join(save_dir, \"predictions.csv\"))\n",
"save_to_csv(actuals, os.path.join(save_dir, \"actuals.csv\"))"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_results(predictions, actuals, save_dir, mode=\"test\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![image](https://github.com/user-attachments/assets/8766a2a9-8bdf-4348-a544-7115cfb95f63)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"plot_results(predictions_Sep, actuals_Sep, save_dir, mode=\"Sep\")"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![image](https://github.com/user-attachments/assets/2e16bbc5-9b8c-46ce-88f5-9fb6187fe5af)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Predict chlorophyll-a concentration"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chla_data_file = pace_filepath.replace(\".nc\", \".npy\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chla_predict(pace_filepath, best_model_path, chla_data_file, device)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Visualize the results"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"rgb_image_tif_file = \"data/snapshot-2024-08-10T00_00_00Z.tif\"\n",
"output_file = \"20241024-2.png\"\n",
"title = \"PACE Chla Prediction\"\n",
"figsize = (12, 8)\n",
"cmap = \"jet\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"chla_viz(rgb_image_tif_file, chla_data_file, output_file, title, figsize, cmap)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"![image](https://github.com/user-attachments/assets/93944a8d-8784-441a-9a51-9050780eb22e)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "hyper",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.12.8"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
2 changes: 2 additions & 0 deletions mkdocs.yml
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ plugins:
"image_slicing.ipynb",
"temperature.ipynb",
"pace_oci_l1.ipynb",
"chla_predict.ipynb",
]

markdown_extensions:
Expand Down Expand Up @@ -101,6 +102,7 @@ nav:
- examples/pace.ipynb
- examples/ecostress.ipynb
- examples/chlorophyll_a.ipynb
- examples/chla_predict.ipynb
- examples/temperature.ipynb
- examples/pace_oci_l1.ipynb
- examples/pace_oci_l2.ipynb
Expand Down

0 comments on commit 4f276d3

Please sign in to comment.