Gap-filling with CNN#

Author: Yifei Hang (UW Varanasi intern 2024)

This notebook shows how to fit a basic Convolutional Neural Network (CNN) to fill gaps in Chlorophyll-a (CHL) data. Although you can run this tutorial on a CPU, it will be much faster on a GPU. We used the image quay.io/pangeo/ml-notebook:2024.08.18 to run the notebook.

import xarray as xr
import numpy as np

import dask.array as da

import matplotlib.pyplot as plt

import tensorflow as tf
from keras.callbacks import EarlyStopping
from keras.models import Sequential
from keras.layers import Input, Conv2D, BatchNormalization, Dropout
# list all the physical devices
physical_devices = tf.config.list_physical_devices()
print("All Physical Devices:", physical_devices)

# list all the available GPUs
gpus = tf.config.list_physical_devices('GPU')
print("Available GPUs:", gpus)

# Print information for each available GPU, if any exist
if gpus:
    for gpu in gpus:
        details = tf.config.experimental.get_device_details(gpu)
        print("GPU Details:", details)
else:
    print("No GPU available")
All Physical Devices: [PhysicalDevice(name='/physical_device:CPU:0', device_type='CPU'), PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
Available GPUs: [PhysicalDevice(name='/physical_device:GPU:0', device_type='GPU')]
GPU Details: {'compute_capability': (7, 5), 'device_name': 'Tesla T4'}
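If the GPU is shared with other users or processes, it can help to let TensorFlow allocate GPU memory on demand instead of reserving it all up front. A minimal optional sketch (it must run before the GPU is first used):

# Optional: allocate GPU memory on demand rather than reserving it all up front
for gpu in gpus:
    tf.config.experimental.set_memory_growth(gpu, True)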
zarr_ds = xr.open_zarr("~/shared-public/mind_the_chl_gap/IO.zarr")
zarr_sliced = zarr_ds.sel(lat=slice(35, -5), lon=slice(45,90))
all_nan_CHL = np.isnan(zarr_sliced["CHL_cmes-level3"]).all(dim=["lon", "lat"]).compute()  # days where the entire CHL image is NaN
zarr_CHL = zarr_sliced.sel(time=~all_nan_CHL)  # keep only days with at least some CHL data
zarr_CHL = zarr_CHL.sortby('time')
zarr_CHL = zarr_CHL.sel(time=slice('2020-01-01', '2020-12-31'))
zarr_CHL
<xarray.Dataset> Size: 958MB
Dimensions:                       (time: 366, lat: 149, lon: 181)
Coordinates:
  * lat                           (lat) float32 596B 32.0 31.75 ... -4.75 -5.0
  * lon                           (lon) float32 724B 45.0 45.25 ... 89.75 90.0
  * time                          (time) datetime64[ns] 3kB 2020-01-01 ... 20...
Data variables: (12/27)
    CHL                           (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    CHL_cmes-cloud                (time, lat, lon) uint8 10MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    CHL_cmes-gapfree              (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    CHL_cmes-land                 (lat, lon) uint8 27kB dask.array<chunksize=(149, 181), meta=np.ndarray>
    CHL_cmes-level3               (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    CHL_cmes_flags-gapfree        (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    ...                            ...
    ug_curr                       (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    v_curr                        (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    v_wind                        (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    vg_curr                       (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    wind_dir                      (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
    wind_speed                    (time, lat, lon) float32 39MB dask.array<chunksize=(23, 149, 181), meta=np.ndarray>
Attributes: (12/92)
    Conventions:                     CF-1.8, ACDD-1.3
    DPM_reference:                   GC-UD-ACRI-PUG
    IODD_reference:                  GC-UD-ACRI-PUG
    acknowledgement:                 The Licensees will ensure that original ...
    citation:                        The Licensees will ensure that original ...
    cmems_product_id:                OCEANCOLOUR_GLO_BGC_L3_MY_009_103
    ...                              ...
    time_coverage_end:               2024-04-18T02:58:23Z
    time_coverage_resolution:        P1D
    time_coverage_start:             2024-04-16T21:12:05Z
    title:                           cmems_obs-oc_glo_bgc-plankton_my_l3-mult...
    westernmost_longitude:           -180.0
    westernmost_valid_longitude:     -180.0
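As a quick sanity check (a sketch), we can count how many 2020 days were dropped by the all-NaN filter above:

# Compare the full 2020 time axis against the filtered dataset
n_total = zarr_sliced.sel(time=slice('2020-01-01', '2020-12-31')).sizes['time']
n_kept = zarr_CHL.sizes['time']
print(f"{n_total - n_kept} of {n_total} days dropped as all-NaN")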
p = zarr_CHL.sel(time='2020-09-02').CHL.plot(y='lat', x='lon')
../_images/c5fc1dd9f70f95ec9343fe99be816542b82dfd887d970a79755bb32e07583ef8.png
# log scale
np.log(zarr_CHL.sel(time='2020-09-02').CHL).plot(y='lat', x='lon')
<matplotlib.collections.QuadMesh at 0x7f31a3345a30>
../_images/d632210d5ac89bec04ba956af9ca6ceddab61c5433bea51fed3afd443c7feb01.png
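Surface chlorophyll is heavily right-skewed (it is often treated as log-normally distributed), which is why the plot above and the model label below use a log transform. A quick histogram sketch of raw versus logged values illustrates this:

# Compare the distribution of raw CHL values against log(CHL)
chl_vals = zarr_CHL['CHL_cmes-level3'].values.ravel()
chl_vals = chl_vals[~np.isnan(chl_vals)]

fig, axes = plt.subplots(1, 2, figsize=(10, 3))
axes[0].hist(chl_vals, bins=100)
axes[0].set_title('CHL (raw, mg/m**3)')
axes[1].hist(np.log(chl_vals), bins=100)
axes[1].set_title('log(CHL)')
plt.show()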
# How large can the CHL uncertainty get? (not used by the model for now)
zarr_CHL.CHL_uncertainty.max().compute()
<xarray.DataArray 'CHL_uncertainty' ()> Size: 4B
array(127.318726, dtype=float32)

Process the data#

We preprocess the data (log-transform the label and replace NaNs with 0) and split it into training, validation, and test sets.

def log_label(data, label):
    # Return a copy of the dataset with the label variable log-transformed
    data_logged = data.copy()
    data_logged[label] = np.log(data[label]).copy()
    return data_logged

# Add more preprocessing later
def preprocess_data(data, features, label):
    # Log-transform the label, then stack features and label into one dask
    # array of shape (n_vars, time, lat, lon), replacing NaNs with 0
    data_logged = log_label(data, label)

    sel_data_list = []
    for var in features + [label]:
        sel_var_data = data_logged[var]
        sel_var_data = da.where(da.isnan(sel_var_data), 0.0, sel_var_data)
        sel_data_list.append(sel_var_data)
    sel_data_da = da.array(sel_data_list)
    return sel_data_da

def time_series_split(data, split_ratio):
    # data has shape (n_vars, time, lat, lon): all but the last variable
    # are features; the last is the label
    X = data[:-1]
    y = data[-1]

    # Move the variable axis last: (time, lat, lon, n_features)
    X = np.moveaxis(X, 0, -1)

    total_length = X.shape[0]

    # Chronological split: no shuffling across the train/val/test boundaries
    train_end = int(total_length * split_ratio[0])
    val_end = int(total_length * (split_ratio[0] + split_ratio[1]))

    X_train, y_train = X[:train_end], y[:train_end]
    X_val, y_val = X[train_end:val_end], y[train_end:val_end]
    X_test, y_test = X[val_end:], y[val_end:]

    return (X_train, y_train,
            X_val, y_val,
            X_test, y_test)

Here we create our training, validation, and test data with 2 feature variables, using only 2020: 70% of the data for training, 20% for validation, and 10% for testing.

# Current features: sea surface temperature 'sst' (K) and sea surface salinity 'so' (PSU). [Excluding topography/bathymetry (m)]
features = ['sst', 'so']
label = 'CHL_cmes-level3'  # chlorophyll-a concentration (mg/m**3) [not taking uncertainty into consideration for now]
model_data = preprocess_data(zarr_CHL, features, label)

split_ratio = [.7, .2, .1]
X_train, y_train, X_val, y_val, X_test, y_test = time_series_split(model_data, split_ratio)
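A quick shape check (a sketch): each sample is one daily image with the two features stacked on the last axis, and 70% of the 366 days gives the 256 training samples that show up as 32 batches of 8 during training below.

print("X_train:", X_train.shape)  # expected (256, 149, 181, 2)
print("y_train:", y_train.shape)  # expected (256, 149, 181)
print("X_val:  ", X_val.shape)    # expected (73, 149, 181, 2)
print("X_test: ", X_test.shape)   # expected (37, 149, 181, 2)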

Create the CNN model#

def create_model_CNN(input_shape=(149, 181, 2)):
    model = Sequential()

    # Declare the input explicitly; passing input_shape to the first layer
    # triggers a Keras deprecation warning
    model.add(Input(shape=input_shape))

    model.add(Conv2D(filters=64,
                     kernel_size=(3, 3),
                     padding='same',
                     activation='relu'
                     ))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    model.add(Conv2D(filters=32,
                     kernel_size=(3, 3),
                     padding='same',
                     activation='relu'
                     ))
    model.add(BatchNormalization())
    model.add(Dropout(0.2))

    # Single linear output channel: the predicted log(CHL) image
    model.add(Conv2D(filters=1,
                     kernel_size=(3, 3),
                     padding='same',
                     activation='linear'
                     ))

    return model

model = create_model_CNN()
model.summary()
Model: "sequential"
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━┓
┃ Layer (type)                     Output Shape                  Param # ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━┩
│ conv2d (Conv2D)                 │ (None, 149, 181, 64)   │         1,216 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization             │ (None, 149, 181, 64)   │           256 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout (Dropout)               │ (None, 149, 181, 64)   │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_1 (Conv2D)               │ (None, 149, 181, 32)   │        18,464 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ batch_normalization_1           │ (None, 149, 181, 32)   │           128 │
│ (BatchNormalization)            │                        │               │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ dropout_1 (Dropout)             │ (None, 149, 181, 32)   │             0 │
├─────────────────────────────────┼────────────────────────┼───────────────┤
│ conv2d_2 (Conv2D)               │ (None, 149, 181, 1)    │           289 │
└─────────────────────────────────┴────────────────────────┴───────────────┘
 Total params: 20,353 (79.50 KB)
 Trainable params: 20,161 (78.75 KB)
 Non-trainable params: 192 (768.00 B)
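The parameter counts in the summary can be verified by hand: a Conv2D layer has (kernel_height × kernel_width × input_channels + 1 bias) × filters parameters, and BatchNormalization has 4 per channel (gamma and beta are trainable; the moving mean and variance are not). A sketch:

# Conv2D: (kernel_h * kernel_w * in_channels + 1 bias) * filters
print((3 * 3 * 2 + 1) * 64)    # conv2d   -> 1216
print((3 * 3 * 64 + 1) * 32)   # conv2d_1 -> 18464
print((3 * 3 * 32 + 1) * 1)    # conv2d_2 -> 289
# BatchNormalization: 4 params per channel, half trainable, half not
print(4 * 64, 4 * 32)          # -> 256 128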
model.compile(optimizer='adam', loss='mae', metrics=['mae'])

early_stop = EarlyStopping(patience=10, restore_best_weights=True)

train_dataset = tf.data.Dataset.from_tensor_slices((X_train, y_train))
train_dataset = train_dataset.shuffle(buffer_size=1024).batch(8)

val_dataset = tf.data.Dataset.from_tensor_slices((X_val, y_val))
val_dataset = val_dataset.batch(8)

history = model.fit(train_dataset, epochs=50, validation_data=val_dataset, callbacks=[early_stop])
Epoch 1/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 9s 38ms/step - loss: 1.1156 - mae: 1.1156 - val_loss: 1.0242 - val_mae: 1.0242
Epoch 2/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.8378 - mae: 0.8378 - val_loss: 0.6492 - val_mae: 0.6492
Epoch 3/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.7707 - mae: 0.7707 - val_loss: 0.5647 - val_mae: 0.5647
Epoch 4/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.7562 - mae: 0.7562 - val_loss: 0.5580 - val_mae: 0.5580
Epoch 5/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.7154 - mae: 0.7154 - val_loss: 0.5441 - val_mae: 0.5441
Epoch 6/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6769 - mae: 0.6769 - val_loss: 0.5646 - val_mae: 0.5646
Epoch 7/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6749 - mae: 0.6749 - val_loss: 0.5613 - val_mae: 0.5613
Epoch 8/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6768 - mae: 0.6768 - val_loss: 0.5505 - val_mae: 0.5505
Epoch 9/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6273 - mae: 0.6273 - val_loss: 0.5401 - val_mae: 0.5401
Epoch 10/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6231 - mae: 0.6231 - val_loss: 0.5355 - val_mae: 0.5355
Epoch 11/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6096 - mae: 0.6096 - val_loss: 0.5301 - val_mae: 0.5301
Epoch 12/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6194 - mae: 0.6194 - val_loss: 0.5297 - val_mae: 0.5297
Epoch 13/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5972 - mae: 0.5972 - val_loss: 0.5263 - val_mae: 0.5263
Epoch 14/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5797 - mae: 0.5797 - val_loss: 0.5257 - val_mae: 0.5257
Epoch 15/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5738 - mae: 0.5738 - val_loss: 0.5238 - val_mae: 0.5238
Epoch 16/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5855 - mae: 0.5855 - val_loss: 0.5233 - val_mae: 0.5233
Epoch 17/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5712 - mae: 0.5712 - val_loss: 0.5237 - val_mae: 0.5237
Epoch 18/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5904 - mae: 0.5904 - val_loss: 0.5287 - val_mae: 0.5287
Epoch 19/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5669 - mae: 0.5669 - val_loss: 0.5245 - val_mae: 0.5245
Epoch 20/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5600 - mae: 0.5600 - val_loss: 0.5226 - val_mae: 0.5226
Epoch 21/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5868 - mae: 0.5868 - val_loss: 0.5269 - val_mae: 0.5269
Epoch 22/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5809 - mae: 0.5809 - val_loss: 0.5224 - val_mae: 0.5224
Epoch 23/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5890 - mae: 0.5890 - val_loss: 0.5235 - val_mae: 0.5235
Epoch 24/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5930 - mae: 0.5930 - val_loss: 0.5217 - val_mae: 0.5217
Epoch 25/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5792 - mae: 0.5792 - val_loss: 0.5237 - val_mae: 0.5237
Epoch 26/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5630 - mae: 0.5630 - val_loss: 0.5237 - val_mae: 0.5237
Epoch 27/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5707 - mae: 0.5707 - val_loss: 0.5217 - val_mae: 0.5217
Epoch 28/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.6084 - mae: 0.6084 - val_loss: 0.5251 - val_mae: 0.5251
Epoch 29/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5886 - mae: 0.5886 - val_loss: 0.5233 - val_mae: 0.5233
Epoch 30/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5761 - mae: 0.5761 - val_loss: 0.5217 - val_mae: 0.5217
Epoch 31/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5830 - mae: 0.5830 - val_loss: 0.5227 - val_mae: 0.5227
Epoch 32/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5776 - mae: 0.5776 - val_loss: 0.5221 - val_mae: 0.5221
Epoch 33/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5871 - mae: 0.5871 - val_loss: 0.5226 - val_mae: 0.5226
Epoch 34/50
32/32 ━━━━━━━━━━━━━━━━━━━━ 1s 17ms/step - loss: 0.5803 - mae: 0.5803 - val_loss: 0.5228 - val_mae: 0.5228
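Training stops at epoch 34 rather than 50 because the validation loss last improved around epoch 24 and patience=10; with restore_best_weights=True, the model keeps the weights from that best epoch. A sketch to confirm which epoch that was:

# Best epoch according to validation loss (EarlyStopping restored these weights)
best_epoch = int(np.argmin(history.history['val_loss'])) + 1
print("best epoch:", best_epoch, "val_loss:", min(history.history['val_loss']))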

Plot training & validation loss values#

plt.figure(figsize=(10, 6))
plt.plot(history.history['loss'], label='Train Loss')
plt.plot(history.history['val_loss'], label='Validation Loss')
plt.title('Model Loss')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend(loc='upper right')
plt.grid(True)
plt.show()

# Plot training & validation MAE values
plt.figure(figsize=(10, 6))
plt.plot(history.history['mae'], label='Train MAE')
plt.plot(history.history['val_mae'], label='Validation MAE')
plt.title('Model Mean Absolute Error (MAE)')
plt.xlabel('Epoch')
plt.ylabel('MAE')
plt.legend(loc='upper right')
plt.grid(True)
plt.show()
../_images/89fe009dd720060fd9a1172e06f174c855fb007911656d1f8d0799c0280ddb7f.png ../_images/0acbdbd14609caefcabb47104422734df12437e081dbbacb79feceba02544907.png

Prepare the test dataset and evaluate the model#

test_dataset = tf.data.Dataset.from_tensor_slices((X_test, y_test))
test_dataset = test_dataset.batch(4)

# Evaluate the model on the test dataset
test_loss, test_mae = model.evaluate(test_dataset)
print(f"Test Loss: {test_loss}")
print(f"Test MAE: {test_mae}")
10/10 ━━━━━━━━━━━━━━━━━━━━ 1s 3ms/step - loss: 0.6469 - mae: 0.6469  
Test Loss: 0.6140401363372803
Test MAE: 0.6140401363372803
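Note that the label is log-transformed, so this MAE is in log space; exponentiating it gives a rough multiplicative error factor on the raw CHL scale. Keep in mind the metric also averages over pixels where the NaN label was filled with 0, so it should be read with some care. A sketch:

# MAE in log space ~ multiplicative error factor on the raw CHL scale
print(np.exp(test_mae))  # exp(0.614) ≈ 1.85x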
date_to_predict = '2020-09-02'
# Day offset from 2020-01-01 = index into the daily 2020 data
date_index = (np.datetime64(date_to_predict) - np.datetime64('2020-01-01')).item().days
true_output = np.log(zarr_CHL.sel(time=date_to_predict).CHL)
input_frame = np.array(np.moveaxis(model_data[:-1], 0, -1)[date_index])
predicted_output = model.predict(input_frame[np.newaxis, ...])[0]

predicted_output = predicted_output[:, :, 0]

# Mask out land using the land variable that ships with the dataset, assuming
# nonzero values mark land (the original notebook loaded a precomputed .npy
# mask from a local path)
land_mask = zarr_CHL['CHL_cmes-land'].values.astype(bool)
predicted_output[land_mask] = np.nan

vmax = np.nanmax((true_output, predicted_output))
vmin = np.nanmin((true_output, predicted_output))

plt.imshow(true_output, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.title(f'True CHL on {date_to_predict}')
plt.show()




plt.imshow(predicted_output, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.title(f'Predicted CHL on {date_to_predict}')
plt.show()
1/1 [==============================] - 0s 9ms/step
../_images/7c44154476e70db27d05918653b0724760b2e070833c38c28a4af408ae0d642b.png ../_images/49f93bf3f731cbd509f837de71625535ec48d403e306b2898c4795b19cde3a44.png
def compute_mae(y_true, y_pred):
    # MAE over pixels where both arrays are valid (NaNs, e.g. land, are ignored)
    mask = ~np.isnan(y_true) & ~np.isnan(y_pred)
    return np.mean(np.abs(y_true[mask] - y_pred[mask]))


predicted_mae = compute_mae(np.array(true_output), predicted_output)
print(f"MAE between Predicted Output and True Output: {predicted_mae}")

MAE between Predicted Output and True Output: 0.5261644721031189
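The notebook stops at evaluating the prediction against the full log-CHL image, but the actual gap-filling step would keep the observed Level-3 pixels and use the CNN output only where they are missing. A sketch of that final step (an illustration, not part of the original notebook):

# Keep observed Level-3 pixels; use the CNN prediction only in the gaps
observed = np.log(zarr_CHL['CHL_cmes-level3'].sel(time=date_to_predict).values)
filled = np.where(np.isnan(observed), predicted_output, observed)
filled[land_mask] = np.nan  # keep land masked out

plt.imshow(filled, vmin=vmin, vmax=vmax)
plt.colorbar()
plt.title(f'Gap-filled log(CHL) on {date_to_predict}')
plt.show()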