xr.apply_ufunc(…)

An example of how xr.apply_ufunc(...) can be used for pixel-wise prediction with a trained scikit-learn model.

Note: ds.map_blocks() is likely faster than this approach, because with vectorize=True the model's predict method is called once per pixel. This notebook is mainly for demonstration purposes, as a template to be adapted for other computations.
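For comparison, the block-wise idea behind ds.map_blocks() can be sketched as follows. Each block is flattened to a 2-D (n_pixels, n_time) matrix, so predict runs once per block instead of once per pixel. This is a minimal, unbenchmarked sketch with several assumptions. MeanThresholdModel is a hypothetical stand-in for the fitted RandomForestClassifier. The variable name test mirrors the dataset used below. ds_small is a made-up, NumPy-backed dataset.

```python
import numpy as np
import xarray as xr


class MeanThresholdModel:
    """Hypothetical stand-in for a fitted classifier such as RandomForestClassifier."""

    def predict(self, X):
        # X has shape (n_samples, n_features); returns one label per row
        return (np.asarray(X).mean(axis=1) > 0).astype(np.int64)


model = MeanThresholdModel()


def predict_block(block: xr.Dataset) -> xr.Dataset:
    """Predict one label per (lat, lon) pixel for a single block."""
    da = block["test"].transpose("lat", "lon", "time")
    n_lat, n_lon, n_time = da.shape
    # Flatten pixels into rows of shape (lat * lon, time): one predict call per block
    labels = model.predict(da.values.reshape(-1, n_time))
    out = xr.DataArray(
        labels.reshape(n_lat, n_lon).astype(np.float32),
        coords={"lat": da["lat"].values, "lon": da["lon"].values},
        dims=("lat", "lon"),
    )
    return out.to_dataset(name="test")


# Small made-up NumPy-backed dataset with the same layout as `ds`
rng = np.random.default_rng(0)
ds_small = xr.Dataset(
    {"test": (("lat", "lon", "time"), rng.normal(size=(4, 6, 12)))},
    coords={"lat": np.arange(4), "lon": np.arange(6), "time": np.arange(12)},
)

# For a dask-backed dataset, map_blocks uses the template (same lat/lon chunks,
# no time dimension) to build the lazy result; for NumPy-backed data it simply
# calls predict_block once on the whole dataset.
template = ds_small.isel(time=0, drop=True).astype(np.float32)
result = ds_small.map_blocks(predict_block, template=template)
print(result["test"].shape)  # (4, 6)
```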

import sys

import xarray as xr
import numpy as np
import dask
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF

# local helper module for random test data (not shown here)
from util import generate_X_y, generate_3d_dataset

print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)
3.13.1 | packaged by conda-forge | (main, Dec  5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0

Generate some random data for training and inference

n_classes = 2
n_features = 12
n_samples = 1000

lat = 40
lon = 60
# each time step is one model feature, so the time axis needs n_features entries
time = n_features
# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)
# random "real" data to predict on
ds = generate_3d_dataset(lat, lon, time)
ds
<xarray.Dataset> Size: 231kB
Dimensions:  (lat: 40, lon: 60, time: 12)
Coordinates:
  * lat      (lat) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
  * lon      (lon) int64 480B 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
  * time     (time) datetime64[ns] 96B 2021-01-01 2021-01-02 ... 2021-01-12
Data variables:
    test     (lat, lon, time) float64 230kB dask.array<chunksize=(4, 6, 12), meta=np.ndarray>
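The util module is not shown in this notebook. The following is a hypothetical reimplementation of its two helpers, written only to match the shapes and coordinates printed above: integer lat/lon, daily time starting 2021-01-01, and one float64 variable named test. The real generate_X_y and generate_3d_dataset may differ. The chunks and seed parameters are assumptions, and chunking requires dask.

```python
import numpy as np
import pandas as pd
import xarray as xr


def generate_X_y(n_samples, n_features, n_classes, seed=0):
    """Hypothetical: random features and integer class labels for training."""
    rng = np.random.default_rng(seed)
    X = rng.normal(size=(n_samples, n_features))
    y = rng.integers(0, n_classes, size=n_samples)
    return X, y


def generate_3d_dataset(lat, lon, time, chunks=None, seed=0):
    """Hypothetical: random (lat, lon, time) cube stored in variable "test"."""
    rng = np.random.default_rng(seed)
    ds = xr.Dataset(
        {"test": (("lat", "lon", "time"), rng.normal(size=(lat, lon, time)))},
        coords={
            "lat": np.arange(lat),
            "lon": np.arange(lon),
            "time": pd.date_range("2021-01-01", periods=time, freq="D"),
        },
    )
    # e.g. chunks={"lat": 4, "lon": 6} gives (4, 6, 12) chunks like the output above
    return ds.chunk(chunks) if chunks is not None else ds


X_train, y_train = generate_X_y(1000, 12, 2)
ds = generate_3d_dataset(40, 60, 12)
print(X_train.shape, y_train.shape, dict(ds.sizes))
```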

Train a dummy model

rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)
RandomForestClassifier(n_estimators=50, n_jobs=-1, random_state=42)

Define the function to be applied via xr.apply_ufunc(...)

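def generic_func(arr):
    # with vectorize=True, arr is a single pixel's time series of shape (time,);
    # reshape it to (1, n_features) and return the scalar class label
    return rf.predict(arr.reshape(1, -1))[0]


ds_ag = xr.apply_ufunc(
    generic_func,
    ds,
    input_core_dims=[["time"]],  # consume the time axis; output has dims (lat, lon)
    dask="parallelized",  # apply generic_func to each dask chunk
    output_dtypes=[np.float32],
    vectorize=True,  # loop over every (lat, lon) pixel via np.vectorize
    dask_gufunc_kwargs={"allow_rechunk": True},  # let dask merge time into one chunk if needed
)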
ds_ag
<xarray.Dataset> Size: 10kB
Dimensions:  (lat: 40, lon: 60)
Coordinates:
  * lat      (lat) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
  * lon      (lon) int64 480B 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
Data variables:
    test     (lat, lon) float32 10kB dask.array<chunksize=(4, 6), meta=np.ndarray>
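With vectorize=True, xarray wraps the function in np.vectorize using a signature equivalent to "(time)->()". As a result, generic_func receives one pixel's 1-D time series at a time, and reshape(1, -1) turns it into the single-row 2-D input that predict expects. The sketch below mimics that loop directly with NumPy. MeanThresholdModel is a hypothetical stand-in for the fitted forest, and the random cube is made up.

```python
import numpy as np


class MeanThresholdModel:
    """Hypothetical stand-in for a fitted classifier such as RandomForestClassifier."""

    def predict(self, X):
        # X must be 2-D: (n_samples, n_features)
        return (np.asarray(X).mean(axis=1) > 0).astype(np.int64)


model = MeanThresholdModel()


def generic_func(arr):
    # arr is one pixel's time series, shape (time,); reshape to (1, time)
    return model.predict(arr.reshape(1, -1))[0]


# Roughly what vectorize=True does: call generic_func for every (lat, lon) position
per_pixel = np.vectorize(generic_func, signature="(t)->()", otypes=[np.float32])

rng = np.random.default_rng(0)
cube = rng.normal(size=(4, 6, 12))  # (lat, lon, time), core dim last
labels = per_pixel(cube)
print(labels.shape)  # (4, 6)
```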
ds_ag.compute()
<xarray.Dataset> Size: 10kB
Dimensions:  (lat: 40, lon: 60)
Coordinates:
  * lat      (lat) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
  * lon      (lon) int64 480B 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
Data variables:
    test     (lat, lon) float32 10kB 0.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 0.0 0.0
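Because vectorize=True calls predict once per pixel (2,400 calls for a 40 × 60 grid), a common variant passes vectorize=False instead. In that case the function receives an array whose last axis is time, flattens the leading axes into rows, predicts once, and restores the grid shape. The sketch below checks that both variants give the same labels. As before, MeanThresholdModel is a hypothetical stand-in for the fitted model, and the dataset is random, NumPy-backed data. With a dask-backed dataset the same calls would run chunk by chunk.

```python
import numpy as np
import xarray as xr


class MeanThresholdModel:
    """Hypothetical stand-in for a fitted classifier such as RandomForestClassifier."""

    def predict(self, X):
        return (np.asarray(X).mean(axis=1) > 0).astype(np.int64)


model = MeanThresholdModel()


def per_pixel(arr):
    # vectorize=True: arr has shape (time,)
    return model.predict(arr.reshape(1, -1))[0]


def per_block(arr):
    # vectorize=False: arr has shape (..., time); apply_ufunc moves core dims to the end
    flat = arr.reshape(-1, arr.shape[-1])  # (n_pixels, time)
    return model.predict(flat).reshape(arr.shape[:-1]).astype(np.float32)


rng = np.random.default_rng(1)
ds = xr.Dataset(
    {"test": (("lat", "lon", "time"), rng.normal(size=(40, 60, 12)))},
    coords={"lat": np.arange(40), "lon": np.arange(60), "time": np.arange(12)},
)

common = dict(input_core_dims=[["time"]], dask="parallelized", output_dtypes=[np.float32])
slow = xr.apply_ufunc(per_pixel, ds, vectorize=True, **common)
fast = xr.apply_ufunc(per_block, ds, vectorize=False, **common)
print(bool((slow["test"] == fast["test"]).all()))  # True
```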