{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Scaling XGBoost with Dask and Coiled\n", "\n", "[XGBoost](https://xgboost.readthedocs.io/en/latest/) is a library for training gradient-boosted supervised machine learning models. In this guide, you'll learn how to train an XGBoost model in parallel in your own cloud account using Dask and Coiled. Download {download}`this Jupyter notebook ` to follow along." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Before you start\n", "\n", "You'll first need to create consistent local and remote software environments with `dask`, `coiled`, and the necessary dependencies installed. You can use [coiled-runtime](https://docs.coiled.io/user_guide/software_environment.html#coiled-runtime), a conda metapackage that already includes `xgboost` and `dask-ml`.\n", "\n", "You can install `coiled-runtime` locally in a new conda environment:\n", "\n", "```\n", "conda create -n xgboost-example -c conda-forge python=3.9 coiled-runtime\n", "```\n", "\n", "Then activate the conda environment you just created:\n", "\n", "```\n", "conda activate xgboost-example\n", "```" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Launch your Coiled cluster\n", "\n", "Create a Dask cluster in your cloud account with Coiled:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import coiled\n", "\n", "cluster = coiled.Cluster(\n", " n_workers=5,\n", " name=\"xgboost-example\"\n", ")" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Then connect Dask to your remote Coiled cluster:" ] }, { "cell_type": "code", "execution_count": 91, "metadata": {}, "outputs": [], "source": [ "import dask.distributed\n", "\n", "client = dask.distributed.Client(cluster)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Train your model\n", "\n", "You'll use the [Higgs dataset](https://archive.ics.uci.edu/ml/datasets/HIGGS), which is available on Amazon S3. This dataset contains 11 million simulated particle collisions, each described by 28 real-valued features and a binary label indicating which class the sample belongs to (i.e., whether the sample represents a signal or a background event).\n", "\n", "You'll use Dask's `read_csv` function to read all the CSV files in the dataset into a single Dask DataFrame:" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import dask.dataframe as dd\n", "\n", "# Load the entire dataset lazily using Dask\n", "ddf = dd.read_csv(\"s3://coiled-data/higgs/higgs-*.csv\", storage_options={\"anon\": True})" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, separate the classification label from the training features, then split the dataset into training and testing sets. Dask's machine learning library, [Dask-ML](https://ml.dask.org/), mimics scikit-learn's API and provides scalable versions of functions such as `sklearn.datasets.make_classification` and `sklearn.model_selection.train_test_split` that are designed to work with Dask Arrays and DataFrames larger than available RAM."
] }, { "cell_type": "code", "execution_count": 93, "metadata": {}, "outputs": [], "source": [ "from dask_ml.model_selection import train_test_split\n", "\n", "# Separate the training features from the \"labels\" target column\n", "X, y = ddf.iloc[:, 1:], ddf[\"labels\"]\n", "# Use Dask-ML to hold out roughly 20% of the rows as a testing set\n", "X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.2, shuffle=True, random_state=2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, persist your training and testing datasets in the cluster's memory to avoid recomputing them (see the Dask documentation for [best practices using *persist*](https://docs.dask.org/en/stable/best-practices.html#persist-when-you-can)):" ] }, { "cell_type": "code", "execution_count": 94, "metadata": {}, "outputs": [], "source": [ "import dask\n", "\n", "X_train, X_test, y_train, y_test = dask.persist(X_train, X_test, y_train, y_test)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "To train an XGBoost model in a distributed fashion, you'll use XGBoost's Dask integration (see the XGBoost tutorial on [using XGBoost with Dask](https://xgboost.readthedocs.io/en/stable/tutorials/dask.html)). First, construct an `xgboost.dask.DaskDMatrix` for your training data. `DaskDMatrix` is the distributed counterpart of `xgboost.DMatrix`, the internal data structure XGBoost uses to manage dataset features and targets, and you can pass your Dask training features and labels to it directly. The testing set doesn't need its own `DaskDMatrix`, because `xgboost.dask.predict()` also accepts Dask collections directly."
] }, { "cell_type": "code", "execution_count": 95, "metadata": {}, "outputs": [], "source": [ "import xgboost\n", "\n", "dtrain = xgboost.dask.DaskDMatrix(client=client, data=X_train, label=y_train)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Next, define the model's hyperparameters and train the model (see the [XGBoost documentation on parameters](https://xgboost.readthedocs.io/en/stable/parameter.html)):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "params = {\n", " 'objective': 'binary:logistic',\n", " 'max_depth': 3,\n", " 'min_child_weight': 0.5,\n", " 'eval_metric': 'logloss'\n", "}\n", "\n", "# xgboost.dask.train returns a dict holding the trained booster and the evaluation history\n", "output = xgboost.dask.train(client, params, dtrain, num_boost_round=3)\n", "bst = output[\"booster\"]" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Generate model predictions\n", "\n", "Now that your model has been trained, you can use it to make predictions on the testing dataset, which was *not* used to train the model:" ] }, { "cell_type": "code", "execution_count": 96, "metadata": {}, "outputs": [], "source": [ "# Build the predictions on the cluster, then bring the test labels and predicted probabilities back to the local session\n", "y_pred = xgboost.dask.predict(client, bst, X_test)\n", "y_test, y_pred = dask.compute(y_test, y_pred)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Voilà! Congratulations on training a gradient-boosted tree model in the cloud.\n", "\n", "Once you're done, you can shut down the cluster (it shuts down automatically after 20 minutes of inactivity):" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "client.close()\n", "cluster.close()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Next steps\n", "\n", "For a more in-depth look at what you can do with XGBoost, Dask, and Coiled, check out [this Coiled blog post](https://coiled.io/blog/dask-python-xgboost-example/)."
] } ], "metadata": { "kernelspec": { "display_name": "Python 3.9.13 ('dask-sql-example')", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.9.13" }, "vscode": { "interpreter": { "hash": "02c852c24d9f048ccdc209be0dc4985b81e663aaf523cefac5b7672a31b52420" } } }, "nbformat": 4, "nbformat_minor": 4 }