import sys
import dask
import xarray as xr
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF
from util import generate_X_y, generate_3d_dataset
ds.map_blocks(...) applies a function to the chunks of a dask-backed xarray.Dataset. The following example demonstrates how ds.map_blocks(...) can be used for the pixel-wise application of a machine learning model.
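To get a feel for the semantics first, here is a minimal toy sketch (the data and the doubling function are made up for illustration): the supplied function receives each chunk as an in-memory xarray.Dataset, and the per-chunk results are reassembled lazily.

toy = xr.Dataset({"test": (("lat", "lon"), np.arange(16.0).reshape(4, 4))})
toy = toy.chunk({"lat": 2, "lon": 2})  # four 2x2 chunks

# the lambda runs once per chunk; nothing is computed until .compute()
doubled = toy.map_blocks(lambda chunk: chunk * 2)
doubled.compute()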
print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)
3.13.1 | packaged by conda-forge | (main, Dec 5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0
Generate new data to predict on
The time dimension in the following example is only a placeholder for any kind of predictor dimension. For the example to make sense (and work!), the predictor/feature dimension (here, time) must not be split across multiple chunks, i.e., it must form a single chunk!
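If a dataset happens to be chunked along the feature dimension, it can be consolidated beforehand; a minimal sketch (assuming the dimension is called time):

# chunk size -1 means "one chunk spanning the whole dimension"
ds = ds.chunk({"time": -1})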
n_classes = 2
n_features = 12
n_samples = 1000
lat = 4000
lon = 6000
time = n_features

# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)

# random features to predict on, in a "real" shape (x, y, time)
ds = generate_3d_dataset(lat, lon, time)
ds
<xarray.Dataset> Size: 2GB
Dimensions:  (lat: 4000, lon: 6000, time: 12)
Coordinates:
  * lat      (lat) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999
  * lon      (lon) int64 48kB 0 1 2 3 4 5 6 ... 5994 5995 5996 5997 5998 5999
  * time     (time) datetime64[ns] 96B 2021-01-01 2021-01-02 ... 2021-01-12
Data variables:
    test     (lat, lon, time) float64 2GB dask.array<chunksize=(400, 600, 12), meta=np.ndarray>
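The helpers imported from util are not shown in this post. Plausible stand-ins, consistent with the output above but assumptions rather than the original implementations, could look like this:

import dask.array as da

# hypothetical replacement for util.generate_X_y
def generate_X_y(n_samples, n_features, n_classes):
    rng = np.random.default_rng(0)
    X = rng.random((n_samples, n_features))          # (samples, features)
    y = rng.integers(0, n_classes, size=n_samples)   # class labels
    return X, y

# hypothetical replacement for util.generate_3d_dataset
def generate_3d_dataset(lat, lon, time):
    # time stays a single chunk, as required above
    data = da.random.random((lat, lon, time), chunks=(400, 600, time))
    return xr.Dataset(
        {"test": (("lat", "lon", "time"), data)},
        coords={
            "lat": np.arange(lat),
            "lon": np.arange(lon),
            "time": xr.date_range("2021-01-01", periods=time, freq="D"),
        },
    )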
Train a dummy model
rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)
RandomForestClassifier(n_estimators=50, n_jobs=-1, random_state=42)
Function for chunk-wise application
def generic_func(ds: xr.Dataset):
    """
    Flatten chunk
    Apply Random Forest model
    Recover original 2D shape
    """
    ds_stacked = ds.stack(ml=("lat", "lon")).transpose("ml", "time")

    # predict on input data
    X = ds_stacked.test.data
    y_hat_1d = rf.predict(X)
    y_hat_2d = y_hat_1d.reshape((ds.lat.size, ds.lon.size))

    # copy the chunk but remove (squeeze) the time dimension
    data_out = ds.isel(time=[0]).squeeze().copy(deep=True)
    data_out.test.data = y_hat_2d

    return data_out
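Before dispatching the function through ds.map_blocks(...), it can be worth running it eagerly on a single chunk-sized subset to verify shapes and dtypes; this optional check is not part of the pipeline itself:

# one 400 x 600 piece, loaded into memory, fed through the function directly
sample = ds.isel(lat=slice(0, 400), lon=slice(0, 600)).compute()
generic_func(sample)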
ds_pred = ds.map_blocks(generic_func, template=ds.isel(time=[0]).squeeze())
ds_pred
<xarray.Dataset> Size: 192MB
Dimensions:  (lat: 4000, lon: 6000)
Coordinates:
  * lat      (lat) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999
  * lon      (lon) int64 48kB 0 1 2 3 4 5 6 ... 5994 5995 5996 5997 5998 5999
    time     datetime64[ns] 8B dask.array<chunksize=(), meta=np.ndarray>
Data variables:
    test     (lat, lon) float64 192MB dask.array<chunksize=(400, 600), meta=np.ndarray>
ds_pred = ds_pred.compute()
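With the default local scheduler, wrapping the compute in dask's ProgressBar gives feedback while the blocks are processed; an optional variant of the call above:

from dask.diagnostics import ProgressBar

# prints a progress bar while all chunks are predicted
with ProgressBar():
    ds_pred = ds_pred.compute()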