Machine Learning

Dask is tightly integrated with common Python machine learning libraries like scikit-learn, XGBoost, LightGBM, Optuna, and PyTorch. This makes it easy to train models at scale and to run predictions on large-scale data.
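As a rough illustration of the prediction side, a model fitted in memory with scikit-learn can be applied chunk by chunk to a much larger Dask array with `map_blocks`. This is a minimal sketch, assuming scikit-learn and Dask are installed locally; the synthetic data, the `LinearRegression` model, and the chunk sizes are illustrative choices, not part of the original page.

```python
import numpy as np
import dask.array as da
from sklearn.linear_model import LinearRegression

# Illustrative only: fit a small model in memory on synthetic, noiseless data
rng = np.random.default_rng(0)
X_small = rng.random((1_000, 20))
y_small = X_small @ np.arange(20.0)
model = LinearRegression().fit(X_small, y_small)

# A larger Dask array split into row chunks; each chunk fits in memory
X_large = da.random.random(size=(100_000, 20), chunks=(10_000, 20))

# Call model.predict on every chunk independently. predict returns one value
# per row, so the feature axis is dropped; passing meta tells Dask the output
# type without calling predict on an empty array.
predictions = X_large.map_blocks(
    model.predict,
    drop_axis=1,
    meta=np.empty((0,), dtype=np.float64),
)

result = predictions.compute()
print(result.shape)  # (100000,)
```

Because each chunk is scored independently, this pattern suits any estimator whose `predict` works row by row; the work is spread across whatever Dask scheduler is active.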

import xgboost as xgb
import dask.array as da
import coiled

# Start a 20-worker cluster on Coiled and connect a Dask client to it
cluster = coiled.Cluster(n_workers=20)
client = cluster.get_client()

# X and y must be Dask DataFrames or Arrays
# example with 100,000 observations and 20 features
X = da.random.random(size=(100_000, 20), chunks=(1000, 20))
y = da.random.random(size=(100_000,), chunks=(1000,))

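# DaskDMatrix keeps references to the chunks of X and y on the workers
dtrain = xgb.dask.DaskDMatrix(client, X, y)

# train returns a dict with the fitted "booster" and the evaluation "history"
output = xgb.dask.train(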
    client,
    {"tree_method": "hist", "objective": "reg:squarederror"},
    dtrain,
    num_boost_round=4,
    evals=[(dtrain, "train")],
)

More Examples

For more in-depth machine learning examples, consider the following: