generated from geo-smart/sample_project_repository
-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
1 changed file
with
283 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,283 @@ | ||
{ | ||
"cells": [ | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [ | ||
"import xarray as xr\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import cartopy.crs as ccrs\n", | ||
"import cartopy.feature as cfeature\n", | ||
"\n", | ||
"# Load the saved Zarr data\n", | ||
"zarr_path = '~/shared-public/mind_the_chl_gap/U-Net_with_CHL_pred.zarr'\n", | ||
"zarr_ds = xr.open_zarr(zarr_path)['gapfree_pred']\n", | ||
"\n", | ||
"# Select the date you want to plot\n", | ||
"date_to_plot = '2022-01-01' # Replace with the desired date\n", | ||
"zarr_date = zarr_ds.sel(time=date_to_plot)\n", | ||
"\n", | ||
"# Load the Level 3 CHL data\n", | ||
"level3_path = '~/shared-public/mind_the_chl_gap/IO.zarr'\n", | ||
"level3_ds = xr.open_zarr(level3_path)\n", | ||
"level3_chl = level3_ds['CHL_cmes-level3'].sel(time=date_to_plot)\n", | ||
"sst = level3_ds['sst'].sel(time=date_to_plot)\n", | ||
"u_wind = level3_ds['u_wind'].sel(time=date_to_plot)\n", | ||
"v_wind = level3_ds['v_wind'].sel(time=date_to_plot)\n", | ||
"air_temp = level3_ds['air_temp'].sel(time=date_to_plot)\n", | ||
"ug_curr = level3_ds['ug_curr'].sel(time=date_to_plot)\n", | ||
"# Plot the data\n", | ||
"fig, axes = plt.subplots(nrows=7, ncols=1, figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()})\n", | ||
"\n", | ||
"# Plot the log-scaled Level 3 CHL data\n", | ||
"ax = axes[0]\n", | ||
"level3_chl_log = np.log(level3_chl.where(~np.isnan(level3_chl), np.nan))\n", | ||
"im = ax.imshow(level3_chl_log, vmin=np.nanmin(level3_chl_log), vmax=np.nanmax(level3_chl_log), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('Log-scaled Level 3 CHL')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"ax = axes[1]\n", | ||
"im = ax.imshow(sst, vmin=np.nanmin(sst), vmax=np.nanmax(sst), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('SST')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"ax = axes[2]\n", | ||
"im = ax.imshow(u_wind, vmin=np.nanmin(u_wind), vmax=np.nanmax(u_wind), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('u_wind')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"ax = axes[3]\n", | ||
"im = ax.imshow(v_wind, vmin=np.nanmin(v_wind), vmax=np.nanmax(v_wind), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('v_wind')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"ax = axes[4]\n", | ||
"im = ax.imshow(air_temp, vmin=np.nanmin(air_temp), vmax=np.nanmax(air_temp), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('air_temp')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"ax = axes[5]\n", | ||
"im = ax.imshow(ug_curr, vmin=np.nanmin(ug_curr), vmax=np.nanmax(ug_curr), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('ug_curr')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"ax = axes[6]\n", | ||
"gapfill_chl_log = zarr_date\n", | ||
"im = ax.imshow(gapfill_chl_log, vmin=np.nanmin(gapfill_chl_log), vmax=np.nanmax(gapfill_chl_log), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax.set_title('Log-scaled U-Net Gapfilled CHL Prediction')\n", | ||
"ax.add_feature(cfeature.COASTLINE)\n", | ||
"ax.set_xlabel('Longitude')\n", | ||
"ax.set_ylabel('Latitude')\n", | ||
"\n", | ||
"fig.colorbar(im, ax=axes.ravel().tolist(), location='right', shrink=0.9)\n", | ||
"plt.show()" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": 6, | ||
"metadata": {}, | ||
"outputs": [ | ||
{ | ||
"name": "stdout", | ||
"output_type": "stream", | ||
"text": [ | ||
"Requirement already satisfied: deepxde in /srv/conda/envs/notebook/lib/python3.11/site-packages (1.12.0)\n", | ||
"Requirement already satisfied: matplotlib in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (3.8.0)\n", | ||
"Requirement already satisfied: numpy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.24.4)\n", | ||
"Requirement already satisfied: scikit-learn in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.3.0)\n", | ||
"Requirement already satisfied: scikit-optimize>=0.9.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (0.10.2)\n", | ||
"Requirement already satisfied: scipy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.11.2)\n", | ||
"Requirement already satisfied: joblib>=0.11 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (1.3.2)\n", | ||
"Requirement already satisfied: pyaml>=16.9 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (24.7.0)\n", | ||
"Requirement already satisfied: packaging>=21.3 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (23.1)\n", | ||
"Requirement already satisfied: threadpoolctl>=2.0.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-learn->deepxde) (3.2.0)\n", | ||
"Requirement already satisfied: contourpy>=1.0.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (1.1.1)\n", | ||
"Requirement already satisfied: cycler>=0.10 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (0.11.0)\n", | ||
"Requirement already satisfied: fonttools>=4.22.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (4.42.1)\n", | ||
"Requirement already satisfied: kiwisolver>=1.0.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (1.4.5)\n", | ||
"Requirement already satisfied: pillow>=6.2.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (9.5.0)\n", | ||
"Requirement already satisfied: pyparsing>=2.3.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (3.1.1)\n", | ||
"Requirement already satisfied: python-dateutil>=2.7 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (2.8.2)\n", | ||
"Requirement already satisfied: PyYAML in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pyaml>=16.9->scikit-optimize>=0.9.0->deepxde) (6.0.1)\n", | ||
"Requirement already satisfied: six>=1.5 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib->deepxde) (1.16.0)\n", | ||
"Note: you may need to restart the kernel to use updated packages.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"pip install deepxde" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": { | ||
"scrolled": true | ||
}, | ||
"outputs": [ | ||
{ | ||
"name": "stderr", | ||
"output_type": "stream", | ||
"text": [ | ||
"Using backend: pytorch\n", | ||
"Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.\n", | ||
"paddle supports more examples now and is recommended.\n" | ||
] | ||
} | ||
], | ||
"source": [ | ||
"import numpy as np\n", | ||
"import xarray as xr\n", | ||
"import os\n", | ||
"os.environ[\"DDEBACKEND\"] = \"pytorch\"\n", | ||
"import deepxde as dde\n", | ||
"import matplotlib.pyplot as plt\n", | ||
"import cartopy.crs as ccrs\n", | ||
"import cartopy.feature as cfeature\n", | ||
"from tqdm import tqdm\n", | ||
"import numpy as np\n", | ||
"import xarray as xr\n", | ||
"import deepxde as dde\n", | ||
"\n", | ||
"# Load the data\n", | ||
"zarr_path = '~/shared-public/mind_the_chl_gap/U-Net_with_CHL_pred.zarr'\n", | ||
"zarr_ds = xr.open_zarr(zarr_path)['gapfree_pred']\n", | ||
"\n", | ||
"level3_path = '~/shared-public/mind_the_chl_gap/IO.zarr'\n", | ||
"level3_ds = xr.open_zarr(level3_path)\n", | ||
"\n", | ||
"# Prepare the input data (v)\n", | ||
"variables = ['CHL_cmes-level3', 'sst', 'u_wind', 'v_wind', 'air_temp', 'ug_curr']\n", | ||
"input_data = []\n", | ||
"\n", | ||
"for var in variables:\n", | ||
" data = level3_ds[var].values\n", | ||
" data = np.log(data) if var == 'CHL_cmes-level3' else data\n", | ||
" input_data.append(data)\n", | ||
"\n", | ||
"v = np.stack(input_data, axis=-1)\n", | ||
"\n", | ||
"# Prepare the output data (u)\n", | ||
"u = zarr_ds.values\n", | ||
"\n", | ||
"# Reshape the data\n", | ||
"v = v.reshape(-1, v.shape[-1]) # (num_points, num_variables)\n", | ||
"u = u.reshape(-1) # (num_points,)\n", | ||
"\n", | ||
"# Split the data into training and testing sets\n", | ||
"n_train = int(0.8 * len(u))\n", | ||
"X_train, y_train = (v[:n_train], np.zeros((n_train, 1))), u[:n_train]\n", | ||
"X_test, y_test = (v[n_train:], np.zeros((len(u) - n_train, 1))), u[n_train:]\n", | ||
"# Set up the data\n", | ||
"data = dde.data.TripleCartesianProd(\n", | ||
" X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test\n", | ||
")\n", | ||
"\n", | ||
"# Define the DeepONet architecture\n", | ||
"m = v.shape[1] # number of input variables\n", | ||
"dim_x = 1 # dimension of spatial input (in this case, just a placeholder)\n", | ||
"\n", | ||
"net = dde.nn.DeepONetCartesianProd(\n", | ||
" [m, 64, 64], # branch net\n", | ||
" [dim_x, 64, 64], # trunk net\n", | ||
" \"relu\",\n", | ||
" \"Glorot normal\",\n", | ||
")\n", | ||
"\n", | ||
"# Create the model\n", | ||
"model = dde.Model(data, net)\n", | ||
"\n", | ||
"# Compile the model\n", | ||
"model.compile(\"adam\", lr=0.001, metrics=[\"mean l2 relative error\"])\n", | ||
"\n", | ||
"# Create a custom loss function\n", | ||
"def custom_loss(inputs, outputs, targets):\n", | ||
" return torch.mean((outputs - targets)**2)\n", | ||
"\n", | ||
"# Create the model\n", | ||
"model = dde.Model(data, net)\n", | ||
"print(v.dtype)\n", | ||
"print(x.dtype)\n", | ||
"print(u.dtype)\n", | ||
"\n", | ||
"# Compile the model\n", | ||
"model.compile(\"adam\", lr=0.001, loss=custom_loss, metrics=[\"mean l2 relative error\"])\n", | ||
"\n", | ||
"print(\"Training the model...\")\n", | ||
"\n", | ||
"# Train the model\n", | ||
"losshistory, train_state = model.train(iterations=10000, batch_size=32) \n", | ||
"\n", | ||
"print(\"Making predictions...\")\n", | ||
"\n", | ||
"# Make predictions\n", | ||
"y_pred = model.predict(X_test)\n", | ||
"\n", | ||
"# Reshape the predictions back to the original shape\n", | ||
"y_pred = y_pred.reshape(zarr_ds.shape[1:])\n", | ||
"\n", | ||
"print(\"Visualizing results...\")\n", | ||
"\n", | ||
"# Visualize the results\n", | ||
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), subplot_kw={'projection': ccrs.PlateCarree()})\n", | ||
"\n", | ||
"# True gap-filled CHL\n", | ||
"im1 = ax1.imshow(zarr_ds.isel(time=0), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax1.set_title('True Gap-filled CHL')\n", | ||
"ax1.add_feature(cfeature.COASTLINE)\n", | ||
"ax1.set_xlabel('Longitude')\n", | ||
"ax1.set_ylabel('Latitude')\n", | ||
"fig.colorbar(im1, ax=ax1)\n", | ||
"\n", | ||
"# Predicted gap-filled CHL\n", | ||
"im2 = ax2.imshow(y_pred, extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n", | ||
"ax2.set_title('Predicted Gap-filled CHL')\n", | ||
"ax2.add_feature(cfeature.COASTLINE)\n", | ||
"ax2.set_xlabel('Longitude')\n", | ||
"ax2.set_ylabel('Latitude')\n", | ||
"fig.colorbar(im2, ax=ax2)\n", | ||
"\n", | ||
"plt.show()\n", | ||
"\n", | ||
"print(\"Process completed.\")" | ||
] | ||
}, | ||
{ | ||
"cell_type": "code", | ||
"execution_count": null, | ||
"metadata": {}, | ||
"outputs": [], | ||
"source": [] | ||
} | ||
], | ||
"metadata": { | ||
"kernelspec": { | ||
"display_name": "Python 3 (ipykernel)", | ||
"language": "python", | ||
"name": "python3" | ||
}, | ||
"language_info": { | ||
"codemirror_mode": { | ||
"name": "ipython", | ||
"version": 3 | ||
}, | ||
"file_extension": ".py", | ||
"mimetype": "text/x-python", | ||
"name": "python", | ||
"nbconvert_exporter": "python", | ||
"pygments_lexer": "ipython3", | ||
"version": "3.11.5" | ||
} | ||
}, | ||
"nbformat": 4, | ||
"nbformat_minor": 4 | ||
} |