import sys
import xarray as xr
import numpy as np
import dask
import sklearn
from sklearn.ensemble import RandomForestClassifier as RF
from util import generate_X_y, generate_3d_dataset
xr.apply_ufunc(…)
Example of how xr.apply_ufunc(...) can be used for pixel-wise prediction: a trained classifier is applied to the time series of every (lat, lon) pixel.
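A minimal sketch of the core-dimension mechanics on a tiny in-memory DataArray. The data and the per-pixel function are illustrative assumptions, not part of the notebook. The dimension listed in input_core_dims is moved to the last axis and handed to the function. With vectorize=True the function is looped over the remaining (lat, lon) positions, so the output keeps only those dims.

```python
import numpy as np
import xarray as xr

# Tiny illustrative cube with dims (lat, lon, time); values are arbitrary.
da = xr.DataArray(
    np.arange(2 * 3 * 4, dtype=float).reshape(2, 3, 4),
    dims=("lat", "lon", "time"),
)

def per_pixel(series):
    # Receives one pixel's 1-D time series (core dim "time" moved last)
    # and returns a single scalar for that pixel.
    return series.max() - series.min()

result = xr.apply_ufunc(
    per_pixel,
    da,
    input_core_dims=[["time"]],  # consumed by per_pixel, dropped from the output
    vectorize=True,              # loop per_pixel over every (lat, lon) position
)
print(result.dims)  # ('lat', 'lon')
print(result.values)
```

Each pixel here holds four consecutive integers, so every output value is 3.0.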
Note: ds.map_blocks() is likely faster than this approach. This notebook is only meant for demonstrative purposes, as a template to be adapted for other computations.
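A minimal sketch of the block-wise alternative mentioned in the note. It assumes a hypothetical predict_stub in place of rf.predict and a small in-memory dataset shaped like ds. Instead of one model call per pixel, the hypothetical predict_block flattens all (lat, lon) pixels of a block into a 2-D (n_pixels, n_time) batch and calls the model once. For an in-memory dataset map_blocks simply calls the function on the whole object. On a dask-backed dataset it is called once per chunk. Whether this is actually faster for a given model and chunking is not measured here.

```python
import numpy as np
import xarray as xr

def predict_stub(X):
    # Hypothetical stand-in for rf.predict: takes (n_samples, n_features),
    # returns one label per row.
    return (X.mean(axis=1) > 0.5).astype(np.float32)

def predict_block(block):
    # Hypothetical block function: one model call for all pixels in the block.
    arr = block["test"].transpose("lat", "lon", "time")
    n_lat, n_lon, n_time = arr.shape
    X = arr.values.reshape(n_lat * n_lon, n_time)  # one row per pixel
    labels = predict_stub(X).reshape(n_lat, n_lon)
    return xr.Dataset(
        {"test": (("lat", "lon"), labels)},
        coords={"lat": block["lat"], "lon": block["lon"]},
    )

# Small in-memory example; a dask-backed ds would be processed chunk by chunk.
rng = np.random.default_rng(0)
ds = xr.Dataset(
    {"test": (("lat", "lon", "time"), rng.random((4, 6, 12)))},
    coords={"lat": np.arange(4), "lon": np.arange(6)},
)
pred = ds.map_blocks(predict_block)
print(pred["test"].shape)
```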
print(sys.version)
print(dask.__version__)
print(xr.__version__)
print(np.__version__)
print(sklearn.__version__)
3.13.1 | packaged by conda-forge | (main, Dec 5 2024, 21:23:54) [GCC 13.3.0]
2025.4.0
2025.3.1
2.2.0
1.6.0
Generate some random data for training and inference
n_classes = 2
n_features = 12
n_samples = 1000
lat = 40
lon = 60
time = n_features  # each pixel's time series serves as its feature vector

# random training data
X_train, y_train = generate_X_y(n_samples, n_features, n_classes)
# random "real" data to predict on
ds = generate_3d_dataset(lat, lon, time)
ds
<xarray.Dataset> Size: 231kB
Dimensions:  (lat: 40, lon: 60, time: 12)
Coordinates:
  * lat      (lat) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
  * lon      (lon) int64 480B 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
  * time     (time) datetime64[ns] 96B 2021-01-01 2021-01-02 ... 2021-01-12
Data variables:
    test     (lat, lon, time) float64 230kB dask.array<chunksize=(4, 6, 12), meta=np.ndarray>
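The helpers generate_X_y and generate_3d_dataset come from a local util module that is not shown. The following hypothetical stand-ins (make_X_y, make_3d_dataset) only mimic the shapes and coordinates visible in the repr above; the real helpers may well differ. The constraint that matters for the prediction step is that the time length equals n_features, so each pixel's time series is a valid feature vector for the classifier.

```python
import numpy as np
import pandas as pd
import xarray as xr

def make_X_y(n_samples, n_features, n_classes, seed=0):
    # Hypothetical stand-in for util.generate_X_y: uniform random features
    # and integer labels in [0, n_classes).
    rng = np.random.default_rng(seed)
    X = rng.random((n_samples, n_features))
    y = rng.integers(0, n_classes, size=n_samples)
    return X, y

def make_3d_dataset(lat, lon, time, seed=1):
    # Hypothetical stand-in for util.generate_3d_dataset: one variable "test"
    # on a (lat, lon, time) grid with daily timestamps from 2021-01-01.
    rng = np.random.default_rng(seed)
    return xr.Dataset(
        {"test": (("lat", "lon", "time"), rng.random((lat, lon, time)))},
        coords={
            "lat": np.arange(lat),
            "lon": np.arange(lon),
            "time": pd.date_range("2021-01-01", periods=time, freq="D"),
        },
    )

X_train, y_train = make_X_y(1000, 12, 2)
ds = make_3d_dataset(40, 60, 12)
# ds.chunk({"lat": 4, "lon": 6}) would give dask chunks like the repr above
# (requires dask).
```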
Train a dummy model
rf = RF(random_state=42, n_estimators=50, n_jobs=-1)
rf.fit(X_train, y_train)
RandomForestClassifier(n_estimators=50, n_jobs=-1, random_state=42)
Define the function to be applied via xr.apply_ufunc(...)
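def generic_func(arr):
    # arr is one pixel's 1-D time series; scikit-learn expects a 2-D
    # (n_samples, n_features) array, so reshape it into a single sample.
    # [0] returns the lone prediction as a scalar (one label per pixel).
    return rf.predict(arr.reshape(1, -1))[0]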
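Why the reshape: inside generic_func, arr is the 1-D time series of a single pixel, while scikit-learn's predict expects a 2-D (n_samples, n_features) array and returns one label per row. The sketch below traces the shapes with a hypothetical predict_stub standing in for rf.predict.

```python
import numpy as np

n_time = 12  # must equal n_features used for training

def predict_stub(X):
    # Hypothetical stand-in for rf.predict: 2-D input, one label per row.
    if X.ndim != 2:
        raise ValueError("expected a 2-D (n_samples, n_features) array")
    return (X.mean(axis=1) > 0.5).astype(np.int64)

pixel_series = np.full(n_time, 0.8)  # one pixel's time series, shape (12,)
row = pixel_series.reshape(1, -1)    # a single sample, shape (1, 12)
pred = predict_stub(row)             # shape (1,)
label = pred[0]                      # scalar label for this pixel
print(row.shape, pred.shape, label)  # (1, 12) (1,) 1
```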
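ds_ag = xr.apply_ufunc(
    generic_func,
    ds,
    input_core_dims=[["time"]],  # "time" is consumed by generic_func
    dask="parallelized",
    output_dtypes=[np.float32],
    vectorize=True,  # loop generic_func over every (lat, lon) pixel
    dask_gufunc_kwargs={"allow_rechunk": True},  # permit rechunking along "time"
)
ds_ag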
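Per the xarray documentation, vectorize=True wraps the function with numpy.vectorize, so it only has to handle arrays over the core dimensions; the docs also note this is almost always slower than a pre-vectorized function. The sketch below approximates that wrapping on a plain NumPy cube, using a hypothetical predict_stub for rf.predict and the gufunc signature "(n)->()" (one 1-D series in, one scalar out). The cost is visible: the Python function runs once per pixel, i.e. 40 * 60 = 2400 model calls for the notebook's dataset.

```python
import numpy as np

def predict_stub(X):
    # Hypothetical stand-in for rf.predict (2-D in, 1-D out).
    return (X.mean(axis=1) > 0.5).astype(np.float32)

def generic_func(arr):
    return predict_stub(arr.reshape(1, -1))[0]

# Roughly what vectorize=True does: call generic_func once per (lat, lon)
# position, passing the trailing core dimension (time) as a 1-D array.
vec = np.vectorize(generic_func, otypes=[np.float32], signature="(n)->()")

rng = np.random.default_rng(0)
cube = rng.random((4, 6, 12))  # (lat, lon, time) with time last
out = vec(cube)
print(out.shape, out.dtype)  # (4, 6) float32
```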
<xarray.Dataset> Size: 10kB
Dimensions:  (lat: 40, lon: 60)
Coordinates:
  * lat      (lat) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
  * lon      (lon) int64 480B 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
Data variables:
    test     (lat, lon) float32 10kB dask.array<chunksize=(4, 6), meta=np.ndarray>
ds_ag.compute()
<xarray.Dataset> Size: 10kB
Dimensions:  (lat: 40, lon: 60)
Coordinates:
  * lat      (lat) int64 320B 0 1 2 3 4 5 6 7 8 9 ... 31 32 33 34 35 36 37 38 39
  * lon      (lon) int64 480B 0 1 2 3 4 5 6 7 8 9 ... 51 52 53 54 55 56 57 58 59
Data variables:
    test     (lat, lon) float32 10kB 0.0 1.0 1.0 1.0 1.0 ... 1.0 1.0 1.0 0.0 0.0
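A quick consistency check, sketched on a small in-memory dataset with a hypothetical predict_stub in place of rf.predict. The per-pixel apply_ufunc route and a single batched predict over all pixels (flattened to (n_pixels, n_time) and reshaped back) should give identical labels. With the real rf and ds, the same comparison could be run against ds_ag.compute().

```python
import numpy as np
import xarray as xr

def predict_stub(X):
    # Hypothetical stand-in for rf.predict (2-D in, 1-D out).
    return (X.mean(axis=1) > 0.5).astype(np.float32)

def generic_func(arr):
    return predict_stub(arr.reshape(1, -1))[0]

rng = np.random.default_rng(42)
ds = xr.Dataset({"test": (("lat", "lon", "time"), rng.random((4, 6, 12)))})

# Per-pixel route as in the notebook (in-memory here, so no dask arguments).
per_pixel = xr.apply_ufunc(
    generic_func,
    ds,
    input_core_dims=[["time"]],
    output_dtypes=[np.float32],
    vectorize=True,
)

# Batched route: one predict call over all pixels, then reshape to (lat, lon).
cube = ds["test"].transpose("lat", "lon", "time").values
batched = predict_stub(cube.reshape(-1, cube.shape[-1])).reshape(cube.shape[:2])

print(np.array_equal(per_pixel["test"].values, batched))  # True
```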