ds.map_blocks(…)

.map_blocks(...) applies a function to each chunk (block) of a dask-backed xarray.Dataset. Each chunk is passed to the function as an in-memory (numpy-backed) xarray.Dataset. The following example shows how ds.map_blocks(...) can be used to apply a machine learning model pixel by pixel.
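As a minimal sketch of the mechanics, here is a toy dataset that is not part of the original example. The function receives one chunk at a time. Its output has the same structure as its input, so map_blocks can infer the result without a template: it calls the function once on zero-sized mock data.

```python
import numpy as np
import xarray as xr

# Toy dask-backed Dataset: a 6 x 4 grid split into 2 x 2 chunks of size (3, 2).
ds_small = xr.Dataset(
    {"test": (("lat", "lon"), np.arange(24, dtype="float64").reshape(6, 4))},
    coords={"lat": np.arange(6), "lon": np.arange(4)},
).chunk({"lat": 3, "lon": 2})


def add_one(block: xr.Dataset) -> xr.Dataset:
    # Each call receives one in-memory (numpy-backed) chunk of ds_small.
    return block + 1


# No template needed: the output has the same variables, dims and chunks as the input.
result = ds_small.map_blocks(add_one)
computed = result.compute()
```

The result keeps the input's chunk layout. Nothing is computed until .compute() is called.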

import sys

import dask
import xarray as xr
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF

from util import generate_X_y, generate_3d_dataset
print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)
3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0

Generate new data to predict on

In the following example, the time dimension is only a placeholder for any kind of predictor dimension. For the example to make sense (and to work at all), the predictor/feature dimension (here: time) must not be split across chunks. It must form a single chunk, so that every block passed to the function contains all features of each of its pixels.
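Input may arrive with the time (feature) axis split across several chunks. One way to meet the single-chunk requirement is to rechunk that dimension with -1, which places the whole dimension in one chunk. The sketch below uses a hypothetical random dataset, and the helper time_is_single_chunk is illustrative only:

```python
import numpy as np
import pandas as pd
import xarray as xr

# Hypothetical input whose "time" (feature) axis is split into chunks of 4.
ds_multi = xr.Dataset(
    {"test": (("lat", "lon", "time"), np.random.rand(8, 10, 12))},
    coords={
        "lat": np.arange(8),
        "lon": np.arange(10),
        "time": pd.date_range("2021-01-01", periods=12),
    },
).chunk({"lat": 4, "lon": 5, "time": 4})

# -1 puts the whole "time" dimension into a single chunk; lat/lon chunks are unchanged.
ds_single = ds_multi.chunk({"time": -1})


def time_is_single_chunk(ds: xr.Dataset) -> bool:
    # Dataset.chunksizes maps each dimension name to its tuple of chunk sizes.
    return len(ds.chunksizes["time"]) == 1
```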

n_classes = 2
n_features = 12
n_samples = 1000

lat = 4000
lon = 6000
time = n_features
# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)
# random features to predict on, in a "real" (lat, lon, time) layout
ds = generate_3d_dataset(lat, lon, time)
ds
<xarray.Dataset> Size: 2GB
Dimensions:  (lat: 4000, lon: 6000, time: 12)
Coordinates:
  * lat      (lat) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999
  * lon      (lon) int64 48kB 0 1 2 3 4 5 6 ... 5994 5995 5996 5997 5998 5999
  * time     (time) datetime64[ns] 96B 2021-01-01 2021-01-02 ... 2021-01-12
Data variables:
    test     (lat, lon, time) float64 2GB dask.array<chunksize=(400, 600, 12), meta=np.ndarray>
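The util module imported above is not shown. To make the example reproducible, here is a hypothetical stand-in for its two helpers. It assumes random, learnable classification data from scikit-learn's make_classification. It also assumes a lazily generated float64 cube, chunked spatially with the whole time axis in one chunk, which matches the chunk sizes and time coordinates printed above. The real util implementation may differ, and the chunks parameter is an addition of this sketch.

```python
import dask.array as da
import numpy as np
import pandas as pd
import xarray as xr
from sklearn.datasets import make_classification


def generate_X_y(n_samples: int, n_features: int, n_classes: int):
    # Hypothetical stand-in: X has shape (n_samples, n_features),
    # y holds integer labels in [0, n_classes).
    return make_classification(
        n_samples=n_samples, n_features=n_features, n_classes=n_classes, random_state=0
    )


def generate_3d_dataset(lat: int, lon: int, time: int, chunks=(400, 600, -1)) -> xr.Dataset:
    # Hypothetical stand-in: lazy random float64 values; -1 keeps the whole
    # time (feature) axis in a single chunk.
    data = da.random.random((lat, lon, time), chunks=chunks)
    return xr.Dataset(
        {"test": (("lat", "lon", "time"), data)},
        coords={
            "lat": np.arange(lat),
            "lon": np.arange(lon),
            "time": pd.date_range("2021-01-01", periods=time),
        },
    )
```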

Train a dummy model

rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)
RandomForestClassifier(n_estimators=50, n_jobs=-1, random_state=42)

Function for chunk-wise application

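def generic_func(ds: xr.Dataset) -> xr.Dataset:
    """
    Flatten the chunk to a (pixels, features) array,
    apply the Random Forest model,
    and recover the chunk's original 2D (lat, lon) shape.
    """
    # one row per pixel (lat-lon pair), one column per feature (time step)
    ds_stacked = ds.stack(ml=("lat", "lon")).transpose("ml", "time")

    # predict on input data; the chunk arrives as an in-memory (numpy-backed) Dataset
    X = ds_stacked.test.data
    y_hat_1d = rf.predict(X)
    # stack() puts lat in the outer and lon in the inner position,
    # so a C-order reshape restores the chunk's grid
    y_hat_2d = y_hat_1d.reshape((ds.lat.size, ds.lon.size))

    # copy the chunk but drop the time dimension (time remains a scalar coordinate);
    # squeeze("time") only removes time, never a chunk that happens to have lat or lon of size 1
    data_out = ds.isel(time=[0]).squeeze("time").copy(deep=True)
    data_out["test"] = (("lat", "lon"), y_hat_2d)

    return data_out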
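generic_func assumes that stack(ml=("lat", "lon")) followed by a C-order reshape to (lat, lon) restores the original grid. In other words, lat must be the outer and lon the inner level of the stacked dimension. The toy check below illustrates that assumption. It uses hypothetical data and a stand-in "model" that simply returns each pixel's first feature:

```python
import numpy as np
import xarray as xr

# Toy chunk: 3 x 4 grid with 2 "time" features; values encode their position.
n_lat, n_lon, n_t = 3, 4, 2
values = np.arange(n_lat * n_lon * n_t, dtype="float64").reshape(n_lat, n_lon, n_t)
chunk = xr.Dataset(
    {"test": (("lat", "lon", "time"), values)},
    coords={"lat": np.arange(n_lat), "lon": np.arange(n_lon), "time": np.arange(n_t)},
)

# Same flattening as generic_func: one row per pixel, one column per feature.
X = chunk.stack(ml=("lat", "lon")).transpose("ml", "time").test.data

# Stand-in "model": return the first feature of every pixel.
y_hat_1d = X[:, 0]
y_hat_2d = y_hat_1d.reshape((chunk.lat.size, chunk.lon.size))
```

If the reshape order did not match the stacking order, y_hat_2d would come out scrambled relative to values[:, :, 0].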
ds_pred = ds.map_blocks(generic_func, template=ds.isel(time=[0]).squeeze())
ds_pred
<xarray.Dataset> Size: 192MB
Dimensions:  (lat: 4000, lon: 6000)
Coordinates:
  * lat      (lat) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999
  * lon      (lon) int64 48kB 0 1 2 3 4 5 6 ... 5994 5995 5996 5997 5998 5999
    time     datetime64[ns] 8B dask.array<chunksize=(), meta=np.ndarray>
Data variables:
    test     (lat, lon) float64 192MB dask.array<chunksize=(400, 600), meta=np.ndarray>
ds_pred = ds_pred.compute()
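The model is applied independently to every pixel, so the chunk-wise result should not depend on how the dataset is chunked. Below is a small, self-contained sketch of that check. It uses a hypothetical 20 x 30 dataset and a freshly trained stand-in model in place of rf. It compares map_blocks on unevenly chunked data (time kept in one chunk) with a single direct call on the in-memory data. The function predict_block mirrors generic_func and is illustrative only.

```python
import numpy as np
import xarray as xr
from sklearn.ensemble import RandomForestClassifier

rng = np.random.default_rng(0)
n_lat, n_lon, n_feat = 20, 30, 12

# Hypothetical training data and model (stand-ins for X_train, y_train and rf).
X_fit = rng.random((200, n_feat))
y_fit = (X_fit[:, 0] > 0.5).astype(int)
model = RandomForestClassifier(n_estimators=10, random_state=42).fit(X_fit, y_fit)

small = xr.Dataset(
    {"test": (("lat", "lon", "time"), rng.random((n_lat, n_lon, n_feat)))},
    coords={"lat": np.arange(n_lat), "lon": np.arange(n_lon), "time": np.arange(n_feat)},
)


def predict_block(block: xr.Dataset) -> xr.Dataset:
    # Same flatten / predict / reshape pattern as generic_func, using `model`.
    X = block.stack(ml=("lat", "lon")).transpose("ml", "time").test.data
    # Cast labels to float64 so the values match the dtype declared by the template.
    y = model.predict(X).astype("float64").reshape((block.lat.size, block.lon.size))
    out = block.isel(time=0).copy()
    out["test"] = (("lat", "lon"), y)
    return out


# Uneven lat/lon chunks, whole time axis in one chunk.
chunked = small.chunk({"lat": 7, "lon": 11, "time": -1})
blockwise = chunked.map_blocks(predict_block, template=chunked.isel(time=0)).compute()
direct = predict_block(small)
```

The cast to float64 is a choice made for this sketch. RandomForestClassifier.predict returns labels in the dtype of the training targets, while the template here (as in the original example) declares test as float64.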