import sys
import dask
import xarray as xr
import numpy as np
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF
from util import generate_X_y, generate_3d_dataset

ds.map_blocks(…)
.map_blocks(...) applies a function to each chunk of a dask-backed xarray.Dataset, passing every chunk to the function as an in-memory xarray object. The following example demonstrates how ds.map_blocks(...) can be used for the pixel-wise application of a machine learning model.
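As a minimal sketch of the call pattern (assuming a chunked, dask-backed Dataset ds; add_one is a hypothetical stand-in for any chunk-wise function):

# Sketch: the function receives each chunk as an in-memory xarray.Dataset
# and returns an object of the same structure (or one matching a
# user-supplied template, as shown later in this post).
def add_one(block: xr.Dataset) -> xr.Dataset:
    return block + 1

result = ds.map_blocks(add_one)  # lazy; evaluated on .compute()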
print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)

3.13.1 | packaged by conda-forge | (main, Dec 5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0
Generate new data to predict on
The time dimension in the following example is only a placeholder for any kind of predictor dimension. For the example to make sense (and work!), the predictor/feature dimension (here, time) must not be split across multiple chunks, i.e., it must form a single chunk!
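If a dataset happens to arrive chunked along the feature dimension, it can be rechunked first (a sketch, assuming a dask-backed Dataset ds; -1 requests a single chunk spanning the whole dimension):

# Rechunk so the feature dimension forms exactly one chunk.
ds = ds.chunk({"time": -1})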
n_classes = 2
n_features = 12
n_samples = 1000
lat = 4000
lon = 6000
time = n_features

# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)

# random features to predict on, in a "real" shape (lat, lon, time)
ds = generate_3d_dataset(lat, lon, time)
ds

<xarray.Dataset> Size: 2GB
Dimensions: (lat: 4000, lon: 6000, time: 12)
Coordinates:
* lat (lat) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999
* lon (lon) int64 48kB 0 1 2 3 4 5 6 ... 5994 5995 5996 5997 5998 5999
* time (time) datetime64[ns] 96B 2021-01-01 2021-01-02 ... 2021-01-12
Data variables:
    test     (lat, lon, time) float64 2GB dask.array<chunksize=(400, 600, 12), meta=np.ndarray>

Train a dummy model
rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)
RandomForestClassifier(n_estimators=50, n_jobs=-1, random_state=42)
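As a quick sanity check (illustrative, not part of the workflow above), the fitted model can be probed on a few training rows before mapping it over the full dataset:

# Illustrative check: expect one class label per input sample.
print(rf.predict(X_train[:5]).shape)  # (5,)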
Function for chunk-wise application
def generic_func(ds: xr.Dataset) -> xr.Dataset:
    """
    Flatten the chunk, apply the Random Forest model,
    and recover the original 2D shape.
    """
    # flatten (lat, lon) into a single sample dimension: (ml, time)
    ds_stacked = ds.stack(ml=("lat", "lon")).transpose("ml", "time")
    # predict on the input data: one class label per pixel
    X = ds_stacked.test.data
    y_hat_1d = rf.predict(X)
    y_hat_2d = y_hat_1d.reshape((ds.lat.size, ds.lon.size))
    # copy the chunk but remove (squeeze) the time dimension
    data_out = ds.isel(time=[0]).squeeze().copy(deep=True)
    data_out.test.data = y_hat_2d
    return data_out

Since generic_func drops the time dimension, the output has a different structure than the input; the template argument describes the expected structure of the result so the lazy output can be assembled correctly.

ds_pred = ds.map_blocks(generic_func, template=ds.isel(time=[0]).squeeze())

ds_pred

<xarray.Dataset> Size: 192MB
Dimensions: (lat: 4000, lon: 6000)
Coordinates:
* lat (lat) int64 32kB 0 1 2 3 4 5 6 ... 3994 3995 3996 3997 3998 3999
* lon (lon) int64 48kB 0 1 2 3 4 5 6 ... 5994 5995 5996 5997 5998 5999
time datetime64[ns] 8B dask.array<chunksize=(), meta=np.ndarray>
Data variables:
    test     (lat, lon) float64 192MB dask.array<chunksize=(400, 600), meta=np.ndarray>

ds_pred = ds_pred.compute()
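A possible follow-up (illustrative): once the result is in memory, the distribution of predicted classes can be inspected:

# Illustrative: count how often each class label was predicted.
classes, counts = np.unique(ds_pred.test.values, return_counts=True)
print(dict(zip(classes, counts)))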