from __future__ import annotations
import asyncio
import datetime
import logging
import uuid
import warnings
from contextlib import suppress
from typing import Any
from packaging.version import parse as parse_version
from tornado.ioloop import IOLoop
import dask.config
from dask.utils import _deprecated, format_bytes, parse_timedelta, typename
from dask.widgets import get_template
from distributed.compatibility import PeriodicCallback
from distributed.core import Status
from distributed.deploy.adaptive import Adaptive
from distributed.metrics import time
from distributed.objects import SchedulerInfo
from distributed.utils import (
Log,
Logs,
LoopRunner,
NoOpAwaitable,
SyncMethodMixin,
format_dashboard_link,
log_errors,
)
logger = logging.getLogger(__name__)
class Cluster(SyncMethodMixin):
"""Superclass for cluster objects
This class contains common functionality for Dask Cluster manager classes.
    To implement this class, you must provide

    1. A ``scheduler_comm`` attribute, which is a connection to the scheduler
       following the ``distributed.core.rpc`` API.
    2. A ``scale`` method, which takes an integer and scales the cluster to
       that many workers; alternatively, set ``_supports_scaling`` to False.

    In return, you get the following:
1. A standard ``__repr__``
2. A live IPython widget
3. Adaptive scaling
4. Integration with dask-labextension
    5. A ``scheduler_info`` attribute containing an up-to-date copy of
       ``Scheduler.identity()``, which is used for much of the above
6. Methods to gather logs
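
    Examples
    --------
    A minimal sketch of a subclass (the ``MyCluster`` name and the
    resource-manager helper are hypothetical; only ``scheduler_comm`` and
    ``scale`` are required):

    >>> from distributed.core import rpc
    >>> class MyCluster(Cluster):
    ...     def __init__(self, scheduler_address, **kwargs):
    ...         super().__init__(**kwargs)
    ...         self.scheduler_comm = rpc(scheduler_address)
    ...     def scale(self, n):
    ...         request_workers_from_resource_manager(n)  # hypothetical helper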
"""
_supports_scaling = True
__loop: IOLoop | None = None
def __init__(
self,
asynchronous=False,
loop=None,
quiet=False,
name=None,
scheduler_sync_interval=1,
):
self._loop_runner = LoopRunner(loop=loop, asynchronous=asynchronous)
self.__asynchronous = asynchronous
self.scheduler_info = {"workers": {}}
self.periodic_callbacks = {}
self._watch_worker_status_comm = None
self._watch_worker_status_task = None
self._cluster_manager_logs = []
self.quiet = quiet
self.scheduler_comm = None
self._adaptive = None
self._sync_interval = parse_timedelta(
scheduler_sync_interval, default="seconds"
)
self._sync_cluster_info_task = None
if name is None:
name = str(uuid.uuid4())[:8]
self._cluster_info = {
"name": name,
"type": typename(type(self)),
}
self.status = Status.created
@property
def loop(self) -> IOLoop | None:
loop = self.__loop
if loop is None:
# If the loop is not running when this is called, the LoopRunner.loop
# property will raise a DeprecationWarning
# However subsequent calls might occur - eg atexit, where a stopped
# loop is still acceptable - so we cache access to the loop.
self.__loop = loop = self._loop_runner.loop
return loop
@loop.setter
def loop(self, value: IOLoop) -> None:
warnings.warn(
"setting the loop property is deprecated", DeprecationWarning, stacklevel=2
)
if value is None:
raise ValueError("expected an IOLoop, got None")
self.__loop = value
@property
def called_from_running_loop(self):
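        """Whether the current call is running on this cluster's event loop.

        Falls back to the ``asynchronous`` flag if no event loop is running.
        """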
try:
return (
getattr(self.loop, "asyncio_loop", None) is asyncio.get_running_loop()
)
except RuntimeError:
return self.__asynchronous
@property
def name(self):
return self._cluster_info["name"]
@name.setter
def name(self, name):
self._cluster_info["name"] = name
async def _start(self):
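        """Subscribe to worker status updates on the scheduler, start the
        cluster-info sync task, and start periodic callbacks."""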
comm = await self.scheduler_comm.live_comm()
comm.name = "Cluster worker status"
await comm.write({"op": "subscribe_worker_status"})
self.scheduler_info = SchedulerInfo(await comm.read())
self._watch_worker_status_comm = comm
self._watch_worker_status_task = asyncio.ensure_future(
self._watch_worker_status(comm)
)
info = await self.scheduler_comm.get_metadata(
keys=["cluster-manager-info"], default={}
)
self._cluster_info.update(info)
# Start a background task for syncing cluster info with the scheduler
self._sync_cluster_info_task = asyncio.ensure_future(self._sync_cluster_info())
for pc in self.periodic_callbacks.values():
pc.start()
self.status = Status.running
async def _sync_cluster_info(self):
err_count = 0
warn_at = 5
max_interval = 10 * self._sync_interval
# Loop until the cluster is shutting down. We shouldn't really need
# this check (the `CancelledError` should be enough), but something
# deep in the comms code is silencing `CancelledError`s _some_ of the
# time, resulting in a cancellation not always bubbling back up to
# here. Relying on the status is fine though, not worth changing.
while self.status == Status.running:
try:
await self.scheduler_comm.set_metadata(
keys=["cluster-manager-info"],
value=self._cluster_info.copy(),
)
err_count = 0
except Exception:
err_count += 1
# Only warn if multiple subsequent attempts fail, and only once
# per set of subsequent failed attempts. This way we're not
# excessively noisy during a connection blip, but we also don't
# silently fail.
if err_count == warn_at:
logger.warning(
"Failed to sync cluster info multiple times - perhaps "
"there's a connection issue? Error:",
exc_info=True,
)
# Sleep, with error backoff
interval = _exponential_backoff(
err_count, self._sync_interval, 1.5, max_interval
)
await asyncio.sleep(interval)
async def _close(self):
if self.status == Status.closed:
return
self.status = Status.closing
with suppress(AttributeError):
self._adaptive.stop()
if self._watch_worker_status_comm:
await self._watch_worker_status_comm.close()
if self._watch_worker_status_task:
await self._watch_worker_status_task
if self._sync_cluster_info_task:
self._sync_cluster_info_task.cancel()
with suppress(asyncio.CancelledError):
await self._sync_cluster_info_task
if self.scheduler_comm:
await self.scheduler_comm.close_rpc()
for pc in self.periodic_callbacks.values():
pc.stop()
self.status = Status.closed
def close(self, timeout: float | None = None) -> Any:
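        """Close the cluster.

        Returns an awaitable if the cluster is asynchronous, otherwise blocks
        until the cluster is closed and returns None.
        """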
# If the cluster is already closed, we're already done
if self.status == Status.closed:
if self.asynchronous:
return NoOpAwaitable()
return None
try:
return self.sync(self._close, callback_timeout=timeout)
except RuntimeError: # loop closed during process shutdown
return None
def __del__(self, _warn=warnings.warn):
if getattr(self, "status", Status.closed) != Status.closed:
try:
self_r = repr(self)
except Exception:
self_r = f"with a broken __repr__ {object.__repr__(self)}"
_warn(f"unclosed cluster {self_r}", ResourceWarning, source=self)
async def _watch_worker_status(self, comm):
"""Listen to scheduler for updates on adding and removing workers"""
while True:
try:
msgs = await comm.read()
except OSError:
break
with log_errors():
for op, msg in msgs:
self._update_worker_status(op, msg)
await comm.close()
def _update_worker_status(self, op, msg):
if op == "add":
workers = msg.pop("workers")
self.scheduler_info["workers"].update(workers)
self.scheduler_info.update(msg)
elif op == "remove":
del self.scheduler_info["workers"][msg]
else: # pragma: no cover
raise ValueError("Invalid op", op, msg)
def adapt(self, Adaptive: type[Adaptive] = Adaptive, **kwargs: Any) -> Adaptive:
"""Turn on adaptivity
For keyword arguments see dask.distributed.Adaptive
Examples
--------
>>> cluster.adapt(minimum=0, maximum=10, interval='500ms')
"""
with suppress(AttributeError):
self._adaptive.stop()
if not hasattr(self, "_adaptive_options"):
self._adaptive_options = {}
self._adaptive_options.update(kwargs)
self._adaptive = Adaptive(self, **self._adaptive_options)
return self._adaptive
def scale(self, n: int) -> None:
"""Scale cluster to n workers
Parameters
----------
n : int
Target number of workers
Examples
--------
>>> cluster.scale(10) # scale cluster to ten workers
"""
raise NotImplementedError()
def _log(self, log):
"""Log a message.
        Output a message to the user and also store it for future retrieval.
        For use in subclasses where initialisation may take a while and it is
        helpful to report progress back to the user.
Examples
--------
>>> self._log("Submitted job X to batch scheduler")
"""
self._cluster_manager_logs.append((datetime.datetime.now(), log))
if not self.quiet:
print(log)
async def _get_logs(self, cluster=True, scheduler=True, workers=True):
logs = Logs()
if cluster:
logs["Cluster"] = Log(
"\n".join(line[1] for line in self._cluster_manager_logs)
)
if scheduler:
L = await self.scheduler_comm.get_logs()
logs["Scheduler"] = Log("\n".join(line for level, line in L))
if workers:
if workers is True:
workers = None
d = await self.scheduler_comm.worker_logs(workers=workers)
for k, v in d.items():
logs[k] = Log("\n".join(line for level, line in v))
return logs
def get_logs(self, cluster=True, scheduler=True, workers=True):
"""Return logs for the cluster, scheduler and workers
Parameters
----------
cluster : boolean
Whether or not to collect logs for the cluster manager
scheduler : boolean
Whether or not to collect logs for the scheduler
workers : boolean or Iterable[str], optional
A list of worker addresses to select.
Defaults to all workers if `True` or no workers if `False`
Returns
-------
        logs : Dict[str, str]
            A dictionary of logs, with one item for the cluster manager, one
            for the scheduler, and one for each worker
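
        Examples
        --------
        >>> cluster.get_logs()  # cluster manager, scheduler, and all workers
        >>> cluster.get_logs(cluster=False, workers=False)  # scheduler only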
"""
return self.sync(
self._get_logs, cluster=cluster, scheduler=scheduler, workers=workers
)
@_deprecated(use_instead="get_logs")
def logs(self, *args, **kwargs):
return self.get_logs(*args, **kwargs)
def get_client(self):
"""Return client for the cluster
        If a client has already been initialized for the cluster, return that
        client; otherwise, initialize a new client object.
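
        Examples
        --------
        >>> client = cluster.get_client()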
"""
from distributed.client import Client
try:
current_client = Client.current()
if current_client and current_client.cluster == self:
return current_client
except ValueError:
pass
return Client(self)
@property
def dashboard_link(self):
try:
port = self.scheduler_info["services"]["dashboard"]
except KeyError:
return ""
else:
host = self.scheduler_address.split("://")[1].split("/")[0].split(":")[0]
return format_dashboard_link(host, port)
def _scaling_status(self):
if self._adaptive and self._adaptive.periodic_callback:
mode = "Adaptive"
else:
mode = "Manual"
workers = len(self.scheduler_info["workers"])
if hasattr(self, "worker_spec"):
requested = sum(
1 if "group" not in each else len(each["group"])
for each in self.worker_spec.values()
)
elif hasattr(self, "workers"):
requested = len(self.workers)
else:
requested = workers
worker_count = workers if workers == requested else f"{workers} / {requested}"
return f"""
<table>
<tr><td style="text-align: left;">Scaling mode: {mode}</td></tr>
<tr><td style="text-align: left;">Workers: {worker_count}</td></tr>
</table>
"""
def _widget(self):
"""Create IPython widget for display within a notebook"""
try:
return self._cached_widget
except AttributeError:
pass
try:
from ipywidgets import (
HTML,
Accordion,
Button,
HBox,
IntText,
Layout,
Tab,
VBox,
)
except ImportError:
self._cached_widget = None
return None
layout = Layout(width="150px")
status = HTML(self._repr_html_())
if self._supports_scaling:
request = IntText(0, description="Workers", layout=layout)
scale = Button(description="Scale", layout=layout)
minimum = IntText(0, description="Minimum", layout=layout)
maximum = IntText(0, description="Maximum", layout=layout)
adapt = Button(description="Adapt", layout=layout)
accordion = Accordion(
[HBox([request, scale]), HBox([minimum, maximum, adapt])],
layout=Layout(min_width="500px"),
)
accordion.selected_index = None
accordion.set_title(0, "Manual Scaling")
accordion.set_title(1, "Adaptive Scaling")
def adapt_cb(b):
self.adapt(minimum=minimum.value, maximum=maximum.value)
update()
adapt.on_click(adapt_cb)
@log_errors
def scale_cb(b):
n = request.value
with suppress(AttributeError):
self._adaptive.stop()
self.scale(n)
update()
scale.on_click(scale_cb)
else: # pragma: no cover
accordion = HTML("")
scale_status = HTML(self._scaling_status())
tab = Tab()
tab.children = [status, VBox([scale_status, accordion])]
tab.set_title(0, "Status")
tab.set_title(1, "Scaling")
self._cached_widget = tab
def update():
status.value = self._repr_html_()
scale_status.value = self._scaling_status()
cluster_repr_interval = parse_timedelta(
dask.config.get("distributed.deploy.cluster-repr-interval", default="ms")
)
def install():
pc = PeriodicCallback(update, cluster_repr_interval * 1000)
self.periodic_callbacks["cluster-repr"] = pc
pc.start()
self.loop.add_callback(install)
return tab
def _repr_html_(self, cluster_status=None):
try:
scheduler_info_repr = self.scheduler_info._repr_html_()
except AttributeError:
scheduler_info_repr = "Scheduler not started yet."
return get_template("cluster.html.j2").render(
type=type(self).__name__,
name=self.name,
workers=self.scheduler_info["workers"],
dashboard_link=self.dashboard_link,
scheduler_info_repr=scheduler_info_repr,
cluster_status=cluster_status,
)
def _ipython_display_(self, **kwargs):
"""Display the cluster rich IPython repr"""
# Note: it would be simpler to just implement _repr_mimebundle_,
# but we cannot do that until we drop ipywidgets 7 support, as
# it does not provide a public way to get the mimebundle for a
# widget. So instead we fall back on the more customizable _ipython_display_
# and display as a side-effect.
from IPython.display import display
widget = self._widget()
if widget:
import ipywidgets
if parse_version(ipywidgets.__version__) >= parse_version("8.0.0"):
mimebundle = widget._repr_mimebundle_(**kwargs) or {}
mimebundle["text/plain"] = repr(self)
mimebundle["text/html"] = self._repr_html_()
display(mimebundle, raw=True)
else:
display(widget, **kwargs)
else:
mimebundle = {"text/plain": repr(self), "text/html": self._repr_html_()}
display(mimebundle, raw=True)
def __enter__(self):
if self.asynchronous:
raise TypeError(
"Used 'with' with asynchronous class; please use 'async with'"
)
return self.sync(self.__aenter__)
def __exit__(self, exc_type, exc_value, traceback):
aw = self.close()
assert aw is None, aw
def __await__(self):
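        # The unreachable ``yield`` below makes this method a generator
        # function, so ``await cluster`` resolves immediately to the cluster
        # itself without suspending. Subclasses that need to start work on
        # ``await`` override ``__await__``.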
return self
yield
async def __aenter__(self):
await self
return self
async def __aexit__(self, exc_type, exc_value, traceback):
await self._close()
@property
def scheduler_address(self) -> str:
if not self.scheduler_comm:
return "<Not Connected>"
return self.scheduler_comm.address
@property
def _cluster_class_name(self):
return getattr(self, "_name", type(self).__name__)
def __repr__(self):
text = "%s(%s, %r, workers=%d, threads=%d" % (
self._cluster_class_name,
self.name,
self.scheduler_address,
len(self.scheduler_info["workers"]),
sum(w["nthreads"] for w in self.scheduler_info["workers"].values()),
)
memory = [w["memory_limit"] for w in self.scheduler_info["workers"].values()]
if all(memory):
text += ", memory=" + format_bytes(sum(memory))
text += ")"
return text
@property
def plan(self):
return set(self.workers)
@property
def requested(self):
return set(self.workers)
@property
def observed(self):
return {d["name"] for d in self.scheduler_info["workers"].values()}
def __eq__(self, other):
        return type(other) is type(self) and self.name == other.name
def __hash__(self):
return id(self)
async def _wait_for_workers(self, n_workers=0, timeout=None):
self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
if timeout:
deadline = time() + parse_timedelta(timeout)
else:
deadline = None
def running_workers(info):
return len(
[
ws
for ws in info["workers"].values()
if ws["status"] == Status.running.name
]
)
while n_workers and running_workers(self.scheduler_info) < n_workers:
if deadline and time() > deadline:
raise TimeoutError(
"Only %d/%d workers arrived after %s"
% (running_workers(self.scheduler_info), n_workers, timeout)
)
await asyncio.sleep(0.1)
self.scheduler_info = SchedulerInfo(await self.scheduler_comm.identity())
def wait_for_workers(self, n_workers: int, timeout: float | None = None) -> None:
"""Blocking call to wait for n workers before continuing
Parameters
----------
n_workers : int
The number of workers
timeout : number, optional
Time in seconds after which to raise a
``dask.distributed.TimeoutError``
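
        Examples
        --------
        >>> cluster.wait_for_workers(4, timeout=300)  # wait up to 300s for 4 workers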
"""
if not isinstance(n_workers, int) or n_workers < 1:
raise ValueError(
f"`n_workers` must be a positive integer. Instead got {n_workers}."
)
return self.sync(self._wait_for_workers, n_workers, timeout=timeout)
def _exponential_backoff(
attempt: int, multiplier: float, exponential_base: float, max_interval: float
) -> float:
"""Calculate the duration of an exponential backoff"""
try:
interval = multiplier * exponential_base**attempt
except OverflowError:
return max_interval
return min(max_interval, interval)