Skip to content

Commit

Permalink
?
Browse files Browse the repository at this point in the history
  • Loading branch information
eeholmes committed Aug 22, 2024
2 parents e1749ab + eaae67e commit 9919838
Showing 1 changed file with 283 additions and 0 deletions.
283 changes: 283 additions & 0 deletions notebooks/Data_Drive_PINN.ipynb
Original file line number Diff line number Diff line change
@@ -0,0 +1,283 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import xarray as xr\n",
"import matplotlib.pyplot as plt\n",
"import cartopy.crs as ccrs\n",
"import cartopy.feature as cfeature\n",
"\n",
"# Load the saved Zarr data\n",
"zarr_path = '~/shared-public/mind_the_chl_gap/U-Net_with_CHL_pred.zarr'\n",
"zarr_ds = xr.open_zarr(zarr_path)['gapfree_pred']\n",
"\n",
"# Select the date you want to plot\n",
"date_to_plot = '2022-01-01' # Replace with the desired date\n",
"zarr_date = zarr_ds.sel(time=date_to_plot)\n",
"\n",
"# Load the Level 3 CHL data\n",
"level3_path = '~/shared-public/mind_the_chl_gap/IO.zarr'\n",
"level3_ds = xr.open_zarr(level3_path)\n",
"level3_chl = level3_ds['CHL_cmes-level3'].sel(time=date_to_plot)\n",
"sst = level3_ds['sst'].sel(time=date_to_plot)\n",
"u_wind = level3_ds['u_wind'].sel(time=date_to_plot)\n",
"v_wind = level3_ds['v_wind'].sel(time=date_to_plot)\n",
"air_temp = level3_ds['air_temp'].sel(time=date_to_plot)\n",
"ug_curr = level3_ds['ug_curr'].sel(time=date_to_plot)\n",
"# Plot the data\n",
"fig, axes = plt.subplots(nrows=7, ncols=1, figsize=(12, 6), subplot_kw={'projection': ccrs.PlateCarree()})\n",
"\n",
"# Plot the log-scaled Level 3 CHL data\n",
"ax = axes[0]\n",
"level3_chl_log = np.log(level3_chl.where(~np.isnan(level3_chl), np.nan))\n",
"im = ax.imshow(level3_chl_log, vmin=np.nanmin(level3_chl_log), vmax=np.nanmax(level3_chl_log), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('Log-scaled Level 3 CHL')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"ax = axes[1]\n",
"im = ax.imshow(sst, vmin=np.nanmin(sst), vmax=np.nanmax(sst), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('SST')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"ax = axes[2]\n",
"im = ax.imshow(u_wind, vmin=np.nanmin(u_wind), vmax=np.nanmax(u_wind), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('u_wind')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"ax = axes[3]\n",
"im = ax.imshow(v_wind, vmin=np.nanmin(v_wind), vmax=np.nanmax(v_wind), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('v_wind')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"ax = axes[4]\n",
"im = ax.imshow(air_temp, vmin=np.nanmin(air_temp), vmax=np.nanmax(air_temp), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('air_temp')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"ax = axes[5]\n",
"im = ax.imshow(ug_curr, vmin=np.nanmin(ug_curr), vmax=np.nanmax(ug_curr), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('ug_curr')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"ax = axes[6]\n",
"gapfill_chl_log = zarr_date\n",
"im = ax.imshow(gapfill_chl_log, vmin=np.nanmin(gapfill_chl_log), vmax=np.nanmax(gapfill_chl_log), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax.set_title('Log-scaled U-Net Gapfilled CHL Prediction')\n",
"ax.add_feature(cfeature.COASTLINE)\n",
"ax.set_xlabel('Longitude')\n",
"ax.set_ylabel('Latitude')\n",
"\n",
"fig.colorbar(im, ax=axes.ravel().tolist(), location='right', shrink=0.9)\n",
"plt.show()"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"Requirement already satisfied: deepxde in /srv/conda/envs/notebook/lib/python3.11/site-packages (1.12.0)\n",
"Requirement already satisfied: matplotlib in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (3.8.0)\n",
"Requirement already satisfied: numpy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.24.4)\n",
"Requirement already satisfied: scikit-learn in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.3.0)\n",
"Requirement already satisfied: scikit-optimize>=0.9.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (0.10.2)\n",
"Requirement already satisfied: scipy in /srv/conda/envs/notebook/lib/python3.11/site-packages (from deepxde) (1.11.2)\n",
"Requirement already satisfied: joblib>=0.11 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (1.3.2)\n",
"Requirement already satisfied: pyaml>=16.9 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (24.7.0)\n",
"Requirement already satisfied: packaging>=21.3 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-optimize>=0.9.0->deepxde) (23.1)\n",
"Requirement already satisfied: threadpoolctl>=2.0.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from scikit-learn->deepxde) (3.2.0)\n",
"Requirement already satisfied: contourpy>=1.0.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (1.1.1)\n",
"Requirement already satisfied: cycler>=0.10 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (0.11.0)\n",
"Requirement already satisfied: fonttools>=4.22.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (4.42.1)\n",
"Requirement already satisfied: kiwisolver>=1.0.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (1.4.5)\n",
"Requirement already satisfied: pillow>=6.2.0 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (9.5.0)\n",
"Requirement already satisfied: pyparsing>=2.3.1 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (3.1.1)\n",
"Requirement already satisfied: python-dateutil>=2.7 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from matplotlib->deepxde) (2.8.2)\n",
"Requirement already satisfied: PyYAML in /srv/conda/envs/notebook/lib/python3.11/site-packages (from pyaml>=16.9->scikit-optimize>=0.9.0->deepxde) (6.0.1)\n",
"Requirement already satisfied: six>=1.5 in /srv/conda/envs/notebook/lib/python3.11/site-packages (from python-dateutil>=2.7->matplotlib->deepxde) (1.16.0)\n",
"Note: you may need to restart the kernel to use updated packages.\n"
]
}
],
"source": [
"pip install deepxde"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {
"scrolled": true
},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Using backend: pytorch\n",
"Other supported backends: tensorflow.compat.v1, tensorflow, jax, paddle.\n",
"paddle supports more examples now and is recommended.\n"
]
}
],
"source": [
"import numpy as np\n",
"import xarray as xr\n",
"import os\n",
"os.environ[\"DDEBACKEND\"] = \"pytorch\"\n",
"import deepxde as dde\n",
"import matplotlib.pyplot as plt\n",
"import cartopy.crs as ccrs\n",
"import cartopy.feature as cfeature\n",
"from tqdm import tqdm\n",
"import numpy as np\n",
"import xarray as xr\n",
"import deepxde as dde\n",
"\n",
"# Load the data\n",
"zarr_path = '~/shared-public/mind_the_chl_gap/U-Net_with_CHL_pred.zarr'\n",
"zarr_ds = xr.open_zarr(zarr_path)['gapfree_pred']\n",
"\n",
"level3_path = '~/shared-public/mind_the_chl_gap/IO.zarr'\n",
"level3_ds = xr.open_zarr(level3_path)\n",
"\n",
"# Prepare the input data (v)\n",
"variables = ['CHL_cmes-level3', 'sst', 'u_wind', 'v_wind', 'air_temp', 'ug_curr']\n",
"input_data = []\n",
"\n",
"for var in variables:\n",
" data = level3_ds[var].values\n",
" data = np.log(data) if var == 'CHL_cmes-level3' else data\n",
" input_data.append(data)\n",
"\n",
"v = np.stack(input_data, axis=-1)\n",
"\n",
"# Prepare the output data (u)\n",
"u = zarr_ds.values\n",
"\n",
"# Reshape the data\n",
"v = v.reshape(-1, v.shape[-1]) # (num_points, num_variables)\n",
"u = u.reshape(-1) # (num_points,)\n",
"\n",
"# Split the data into training and testing sets\n",
"n_train = int(0.8 * len(u))\n",
"X_train, y_train = (v[:n_train], np.zeros((n_train, 1))), u[:n_train]\n",
"X_test, y_test = (v[n_train:], np.zeros((len(u) - n_train, 1))), u[n_train:]\n",
"# Set up the data\n",
"data = dde.data.TripleCartesianProd(\n",
" X_train=X_train, y_train=y_train, X_test=X_test, y_test=y_test\n",
")\n",
"\n",
"# Define the DeepONet architecture\n",
"m = v.shape[1] # number of input variables\n",
"dim_x = 1 # dimension of spatial input (in this case, just a placeholder)\n",
"\n",
"net = dde.nn.DeepONetCartesianProd(\n",
" [m, 64, 64], # branch net\n",
" [dim_x, 64, 64], # trunk net\n",
" \"relu\",\n",
" \"Glorot normal\",\n",
")\n",
"\n",
"# Create the model\n",
"model = dde.Model(data, net)\n",
"\n",
"# Compile the model\n",
"model.compile(\"adam\", lr=0.001, metrics=[\"mean l2 relative error\"])\n",
"\n",
"# Create a custom loss function\n",
"def custom_loss(inputs, outputs, targets):\n",
" return torch.mean((outputs - targets)**2)\n",
"\n",
"# Create the model\n",
"model = dde.Model(data, net)\n",
"print(v.dtype)\n",
"print(x.dtype)\n",
"print(u.dtype)\n",
"\n",
"# Compile the model\n",
"model.compile(\"adam\", lr=0.001, loss=custom_loss, metrics=[\"mean l2 relative error\"])\n",
"\n",
"print(\"Training the model...\")\n",
"\n",
"# Train the model\n",
"losshistory, train_state = model.train(iterations=10000, batch_size=32) \n",
"\n",
"print(\"Making predictions...\")\n",
"\n",
"# Make predictions\n",
"y_pred = model.predict(X_test)\n",
"\n",
"# Reshape the predictions back to the original shape\n",
"y_pred = y_pred.reshape(zarr_ds.shape[1:])\n",
"\n",
"print(\"Visualizing results...\")\n",
"\n",
"# Visualize the results\n",
"fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(20, 10), subplot_kw={'projection': ccrs.PlateCarree()})\n",
"\n",
"# True gap-filled CHL\n",
"im1 = ax1.imshow(zarr_ds.isel(time=0), extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax1.set_title('True Gap-filled CHL')\n",
"ax1.add_feature(cfeature.COASTLINE)\n",
"ax1.set_xlabel('Longitude')\n",
"ax1.set_ylabel('Latitude')\n",
"fig.colorbar(im1, ax=ax1)\n",
"\n",
"# Predicted gap-filled CHL\n",
"im2 = ax2.imshow(y_pred, extent=(42, 101.75, -11.75, 32), origin='upper', transform=ccrs.PlateCarree())\n",
"ax2.set_title('Predicted Gap-filled CHL')\n",
"ax2.add_feature(cfeature.COASTLINE)\n",
"ax2.set_xlabel('Longitude')\n",
"ax2.set_ylabel('Latitude')\n",
"fig.colorbar(im2, ax=ax2)\n",
"\n",
"plt.show()\n",
"\n",
"print(\"Process completed.\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.11.5"
}
},
"nbformat": 4,
"nbformat_minor": 4
}

0 comments on commit 9919838

Please sign in to comment.