from __future__ import annotations
import asyncio
import datetime
import json
import logging
import time
import weakref
from collections import namedtuple
from typing import (
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
NoReturn,
Set,
Tuple,
Union,
overload,
)
import dask.config
import dask.distributed
from aiohttp import ContentTypeError
from dask.utils import parse_timedelta
from distributed.utils import Log, Logs
from rich.progress import Progress
from typing_extensions import TypeAlias
from coiled.cli.setup.entry import do_setup_wizard
from coiled.context import track_context
from coiled.core import Async, IsAsynchronous, Sync, delete_docstring, list_docstring
from coiled.core import Cloud as OldCloud
from coiled.errors import ClusterCreationError, DoesNotExist, ServerError
from coiled.exceptions import PermissionsError
from coiled.types import (
ArchitectureTypesEnum,
AWSOptions,
GCPOptions,
PackageLevel,
PackageSchema,
ResolvedPackageInfo,
SoftwareEnvironmentAlias,
)
from coiled.utils import (
COILED_LOGGER_NAME,
GatewaySecurity,
get_grafana_url,
validate_backend_options,
)
from .states import (
InstanceStateEnum,
ProcessStateEnum,
flatten_log_states,
get_process_instance_state,
log_states,
)
from .widgets.util import simple_progress
logger = logging.getLogger(COILED_LOGGER_NAME)
def setup_logging(level=logging.INFO):
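    """Attach a basic stream handler to the Coiled logger if the user hasn't configured logging themselves."""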
# We want to be able to give info-level messages to users.
    # For users who haven't set up a log handler, this requires creating one (b/c the handler of "last resort",
    # logging.lastResort, has a level of "warning").
#
# Conservatively, we only do anything here if the user hasn't set up any log handlers on the root logger
# or the Coiled logger. If they have any handler, we assume logging is configured how they want it.
coiled_logger = logging.getLogger(COILED_LOGGER_NAME)
root_logger = logging.getLogger()
if coiled_logger.handlers == [] and root_logger.handlers == []:
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(fmt="[%(asctime)s][%(levelname)-8s][%(name)s] %(message)s"))
        # Conservatively, only change the Coiled logger level if there's no log level specified yet.
if coiled_logger.level == 0:
coiled_logger.setLevel(level)
coiled_logger.addHandler(stream_handler)
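# Example (a sketch): a script can opt into Coiled's info-level messages with
#   setup_logging()               # defaults to logging.INFO
#   setup_logging(logging.DEBUG)  # or something more verbose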
async def handle_api_exception(response, exception_cls=ServerError) -> NoReturn:
try:
error_body = await response.json()
except ContentTypeError:
raise exception_cls(
f"Unexpected status code ({response.status}) to {response.method}:{response.url}, contact support@coiled.io"
) from None
if error_body.get("code") == PermissionsError.code:
exception_cls = PermissionsError
if "message" in error_body:
raise exception_cls(error_body["message"])
if "detail" in error_body:
raise exception_cls(error_body["detail"])
raise exception_cls(error_body)
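# Typical call pattern (a sketch, mirroring the request helpers below): callers check the
# status themselves and hand non-2xx responses off, e.g.
#   response = await self._do_request("GET", url)
#   if response.status >= 400:
#       await handle_api_exception(response)  # always raises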
CloudV2SyncAsync: TypeAlias = Union["CloudV2[Async]", "CloudV2[Sync]"]
class CloudV2(OldCloud, Generic[IsAsynchronous]):
_recent_sync: List[weakref.ReferenceType[CloudV2[Sync]]] = list()
_recent_async: List[weakref.ReferenceType[CloudV2[Async]]] = list()
# just overriding to get the right signature (CloudV2, not Cloud)
def __enter__(self: CloudV2[Sync]) -> CloudV2[Sync]:
return self
def __exit__(self: CloudV2[Sync], typ, value, tb) -> None:
self.close()
async def __aenter__(self: CloudV2[Async]) -> CloudV2[Async]:
return await self._start()
async def __aexit__(self: CloudV2[Async], typ, value, tb) -> None:
await self._close()
# these overloads are necessary for the typechecker to know that we really have a CloudV2, not a Cloud
# without them, CloudV2.current would be typed to return a Cloud
#
# https://www.python.org/dev/peps/pep-0673/ would remove the need for this.
# That PEP also mentions a workaround with type vars, which doesn't work for us because type vars aren't
    # subscriptable.
@overload
@classmethod
def current(cls, asynchronous: Sync) -> CloudV2[Sync]: ...
@overload
@classmethod
def current(cls, asynchronous: Async) -> CloudV2[Async]: ...
@overload
@classmethod
def current(cls, asynchronous: bool) -> CloudV2: ...
@classmethod
def current(cls, asynchronous: bool) -> CloudV2:
recent: List[weakref.ReferenceType[CloudV2]]
if asynchronous:
recent = cls._recent_async
else:
recent = cls._recent_sync
try:
cloud = recent[-1]()
while cloud is None or cloud.status != "running":
recent.pop()
cloud = recent[-1]()
except IndexError:
if asynchronous:
return cls(asynchronous=True)
else:
return cls(asynchronous=False)
else:
return cloud
@track_context
async def _get_default_instance_types(self, provider: str, gpu: bool = False, arch: str = "x86_64") -> List[str]:
if arch not in ("arm64", "x86_64"):
raise ValueError(f"arch '{arch}' is not supported for default instance types")
if provider == "aws":
if arch == "arm64":
if gpu:
return ["g5g.xlarge"] # has NVIDIA T4G
else:
return ["m7g.xlarge", "m6g.xlarge"]
if gpu:
return ["g4dn.xlarge"]
else:
return ["m6i.xlarge", "m5.xlarge"]
elif provider == "gcp":
if arch != "x86_64":
return ["t2a-standard-4"]
if gpu:
# n1-standard-8 with 30GB of memory might be best, but that's big for a default
return ["n1-standard-4"]
else:
return ["e2-standard-4"]
elif provider == "azure":
if arch != "x86_64":
raise ValueError(f"no default instance type for Azure with {arch} architecture")
if gpu:
raise ValueError("no default GPU instance type for Azure")
return ["Standard_D4_v5"]
else:
raise ValueError(f"unexpected provider {provider}; cannot determine default instance types")
async def _list_dask_scheduler_page(
self,
page: int,
workspace: str | None = None,
since: str | None = "7 days",
user: str | None = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
kwargs = {}
if since:
kwargs["since"] = parse_timedelta(since)
if user:
kwargs["user"] = user
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{workspace}/clusters/list",
params={
"limit": page_size,
"offset": page_size * page,
**kwargs,
},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@track_context
async def _list_dask_scheduler(
self,
workspace: str | None = None,
since: str | None = "7 days",
user: str | None = None,
):
return await self._depaginate_list(
self._list_dask_scheduler_page,
workspace=workspace,
since=since,
user=user,
)
@overload
def list_dask_scheduler(
self: Cloud[Sync],
account: str | None = None,
workspace: str | None = None,
since: str | None = "7 days",
user: str | None = None,
) -> list: ...
@overload
def list_dask_scheduler(
self: Cloud[Async],
account: str | None = None,
workspace: str | None = None,
since: str | None = "7 days",
        user: str | None = None,
) -> Awaitable[list]: ...
def list_dask_scheduler(
self,
account: str | None = None,
workspace: str | None = None,
since: str | None = "7 days",
        user: str | None = None,
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_dask_scheduler, workspace or account, since=since, user=user)
async def _list_computations(
self, cluster_id: int | None = None, scheduler_id: int | None = None, workspace: str | None = None
):
return await self._depaginate_list(
self._list_computations_page, cluster_id=cluster_id, scheduler_id=scheduler_id, workspace=workspace
)
async def _list_computations_page(
self,
page: int,
cluster_id: int | None = None,
scheduler_id: int | None = None,
workspace: str | None = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
if not scheduler_id and not cluster_id:
raise ValueError("either cluster_id or scheduler_id must be specified")
api = (
f"/api/v2/analytics/{workspace}/{scheduler_id}/computations/list"
if scheduler_id
else f"/api/v2/analytics/{workspace}/cluster/{cluster_id}/computations/list"
)
response = await self._do_request(
"GET",
self.server + api,
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@overload
def list_computations(
self: Cloud[Sync],
cluster_id: int | None = None,
scheduler_id: int | None = None,
account: str | None = None,
workspace: str | None = None,
) -> list: ...
@overload
def list_computations(
self: Cloud[Async],
cluster_id: int | None = None,
scheduler_id: int | None = None,
account: str | None = None,
workspace: str | None = None,
) -> Awaitable[list]: ...
def list_computations(
self,
cluster_id: int | None = None,
scheduler_id: int | None = None,
account: str | None = None,
workspace: str | None = None,
) -> Union[list, Awaitable[list]]:
return self._sync(
self._list_computations, cluster_id=cluster_id, scheduler_id=scheduler_id, workspace=workspace or account
)
def list_exceptions(
self,
cluster_id: int | None = None,
scheduler_id: int | None = None,
account: str | None = None,
workspace: str | None = None,
since: str | None = None,
user: str | None = None,
) -> Union[list, Awaitable[list]]:
return self._sync(
self._list_exceptions,
cluster_id=cluster_id,
scheduler_id=scheduler_id,
workspace=workspace or account,
since=since,
user=user,
)
async def _list_exceptions(
self,
cluster_id: int | None = None,
scheduler_id: int | None = None,
workspace: str | None = None,
since: str | None = None,
user: str | None = None,
):
return await self._depaginate_list(
self._list_exceptions_page,
cluster_id=cluster_id,
scheduler_id=scheduler_id,
workspace=workspace,
since=since,
user=user,
)
async def _list_exceptions_page(
self,
page: int,
cluster_id: int | None = None,
scheduler_id: int | None = None,
workspace: str | None = None,
since: str | None = None,
user: str | None = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
kwargs = {}
if since:
kwargs["since"] = parse_timedelta(since)
if user:
kwargs["user"] = user
if cluster_id:
kwargs["cluster"] = cluster_id
if scheduler_id:
kwargs["scheduler"] = scheduler_id
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{workspace}/exceptions/list",
params={"limit": page_size, "offset": page_size * page, **kwargs},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
async def _list_events_page(
self,
page: int,
cluster_id: int,
workspace: str | None = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{workspace}/{cluster_id}/events/list",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
async def _list_events(self, cluster_id: int, workspace: str | None = None):
return await self._depaginate_list(self._list_events_page, cluster_id=cluster_id, workspace=workspace)
def list_events(
self,
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_events, cluster_id, workspace or account)
async def _send_state(self, cluster_id: int, desired_status: str, workspace: str | None = None):
workspace = workspace or self.default_workspace
response = await self._do_request(
"POST",
self.server + f"/api/v2/analytics/{workspace}/{cluster_id}/desired-state",
json={"desired_status": desired_status},
)
if response.status >= 400:
await handle_api_exception(response)
def send_state(
self,
cluster_id: int,
desired_status: str,
account: str | None = None,
workspace: str | None = None,
) -> Union[None, Awaitable[None]]:
return self._sync(self._send_state, cluster_id, desired_status, workspace or account)
@track_context
async def _list_clusters(self, workspace: str | None = None, max_pages: int | None = None, just_mine: bool = False):
return await self._depaginate_list(
self._list_clusters_page, workspace=workspace, max_pages=max_pages, just_mine=just_mine
)
@overload
def list_clusters(
self: Cloud[Sync],
account: str | None = None,
workspace: str | None = None,
max_pages: int | None = None,
just_mine: bool = False,
) -> list: ...
@overload
def list_clusters(
self: Cloud[Async],
account: str | None = None,
workspace: str | None = None,
max_pages: int | None = None,
just_mine: bool = False,
) -> Awaitable[list]: ...
@list_docstring
def list_clusters(
self,
account: str | None = None,
workspace: str | None = None,
max_pages: int | None = None,
just_mine: bool = False,
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_clusters, workspace=workspace or account, max_pages=max_pages, just_mine=just_mine)
async def _list_clusters_page(
self, page: int, workspace: str | None = None, just_mine: bool = False
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/",
params={"limit": page_size, "offset": page_size * page, "just_mine": "1" if just_mine else "0"},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@staticmethod
async def _depaginate_list(
func: Callable[..., Awaitable[Tuple[list, bool]]],
max_pages: int | None = None,
*args,
**kwargs,
) -> list:
results_all = []
page = 0
while True:
kwargs["page"] = page
            results, has_more = await func(*args, **kwargs)
            results_all += results
            page += 1
            if (not results) or not has_more:
                break
            # page is now the number of pages we've already fetched (since pages are 0-indexed)
if max_pages and page >= max_pages:
break
return results_all
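    # Usage sketch: `_depaginate_list` pairs with any of the `*_page` coroutines in this class, e.g.
    #   events = await self._depaginate_list(self._list_events_page, cluster_id=cluster_id, workspace=workspace)
    # where the page function takes a `page` index and returns (results, has_more).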
@track_context
async def _create_package_sync_env(
self,
packages: List[ResolvedPackageInfo],
progress: Progress | None = None,
workspace: str | None = None,
gpu_enabled: bool = False,
architecture: ArchitectureTypesEnum = ArchitectureTypesEnum.X86_64,
region_name: str | None = None,
use_uv_installer: bool = True,
) -> SoftwareEnvironmentAlias:
workspace = workspace or self.default_workspace
prepared_packages: List[PackageSchema] = []
for pkg in packages:
if pkg["sdist"] and pkg["md5"]:
with simple_progress(f"Uploading {pkg['name']}", progress=progress):
file_id = await self._create_senv_package(
pkg["sdist"],
contents_md5=pkg["md5"],
workspace=workspace,
region_name=region_name,
)
else:
file_id = None
prepared_packages.append({
"name": pkg["name"],
"source": pkg["source"],
"channel": pkg["channel"],
"conda_name": pkg["conda_name"],
"specifier": pkg["specifier"],
"include": pkg["include"],
"client_version": pkg["client_version"],
"file": file_id,
})
with simple_progress("Requesting package sync build", progress=progress):
result = await self._create_software_environment_v2(
senv={
"packages": prepared_packages,
"raw_pip": None,
"raw_conda": None,
},
workspace=workspace,
architecture=architecture,
gpu_enabled=gpu_enabled,
region_name=region_name,
use_uv_installer=use_uv_installer,
)
return result
@track_context
async def _create_custom_certificate(self, subdomain: str, workspace: str | None = None):
workspace = workspace or self.default_workspace
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{workspace}/https-certificate",
json={"subdomain": subdomain},
)
if response.status >= 400:
await handle_api_exception(response)
    async def _check_custom_certificate(self, subdomain: str, workspace: str | None = None):
        workspace = workspace or self.default_workspace
        response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/https-certificate/{subdomain}",
)
if response.status >= 400:
await handle_api_exception(response)
response_json = await response.json()
cert_status = response_json.get("status")
return cert_status
@track_context
async def _create_cluster(
self,
# todo: make name optional and pick one for them, like pre-declarative?
# https://gitlab.com/coiled/cloud/-/issues/4305
name: str,
*,
software_environment: str | None = None,
senv_v2_id: int | None = None,
worker_class: str | None = None,
worker_options: dict | None = None,
scheduler_options: dict | None = None,
workspace: str | None = None,
workers: int = 0,
environ: Dict | None = None,
tags: Dict | None = None,
dask_config: Dict | None = None,
scheduler_vm_types: list | None = None,
gcp_worker_gpu_type: str | None = None,
gcp_worker_gpu_count: int | None = None,
worker_vm_types: list | None = None,
worker_disk_size: int | None = None,
worker_disk_throughput: int | None = None,
scheduler_disk_size: int | None = None,
backend_options: Union[AWSOptions, GCPOptions, dict] | None = None,
use_scheduler_public_ip: bool | None = None,
use_dashboard_https: bool | None = None,
private_to_creator: bool | None = None,
extra_worker_on_scheduler: bool | None = None,
n_worker_specs_per_host: int | None = None,
custom_subdomain: str | None = None,
batch_job_ids: List[int] | None = None,
extra_user_container: str | None = None,
host_setup_script_content: str | None = None,
) -> Tuple[int, bool]:
# TODO (Declarative): support these args, or decide not to
# https://gitlab.com/coiled/cloud/-/issues/4305
workspace = workspace or self.default_workspace
account, name = self._normalize_name(name, context_workspace=workspace, allow_uppercase=True)
await self._verify_workspace(account)
data = {
"name": name,
"workers": workers,
"worker_instance_types": worker_vm_types,
"scheduler_instance_types": scheduler_vm_types,
"worker_options": worker_options,
"worker_class": worker_class,
"worker_disk_size": worker_disk_size,
"worker_disk_throughput": worker_disk_throughput,
"scheduler_disk_size": scheduler_disk_size,
"scheduler_options": scheduler_options,
"environ": environ,
"tags": tags,
"dask_config": dask_config,
"private_to_creator": private_to_creator,
"env_id": senv_v2_id,
"env_name": software_environment,
"extra_worker_on_scheduler": extra_worker_on_scheduler,
"n_worker_specs_per_host": n_worker_specs_per_host,
"use_aws_creds_endpoint": dask.config.get("coiled.use_aws_creds_endpoint", False),
"custom_subdomain": custom_subdomain,
"batch_job_ids": batch_job_ids,
"extra_user_container": extra_user_container,
"host_setup_script": host_setup_script_content,
}
backend_options = backend_options if backend_options else {}
if gcp_worker_gpu_type is not None:
# for backwards compatibility with v1 options
backend_options = {
**backend_options,
"worker_accelerator_count": gcp_worker_gpu_count or 1,
"worker_accelerator_type": gcp_worker_gpu_type,
}
elif gcp_worker_gpu_count:
# not ideal but v1 only supported T4 and `worker_gpu=1` would give you one
backend_options = {
**backend_options,
"worker_accelerator_count": gcp_worker_gpu_count,
"worker_accelerator_type": "nvidia-tesla-t4",
}
if use_scheduler_public_ip is False:
if "use_dashboard_public_ip" not in backend_options and not use_dashboard_https:
backend_options["use_dashboard_public_ip"] = False
if use_dashboard_https is False:
if "use_dashboard_https" not in backend_options:
backend_options["use_dashboard_https"] = False
if backend_options:
# for backwards compatibility with v1 options
if "region" in backend_options and "region_name" not in backend_options:
backend_options["region_name"] = backend_options["region"] # type: ignore
del backend_options["region"] # type: ignore
if "zone" in backend_options and "zone_name" not in backend_options:
backend_options["zone_name"] = backend_options["zone"] # type: ignore
del backend_options["zone"] # type: ignore
# firewall just lets you specify a single CIDR block to open for ingress
# we want to support a list of ingress CIDR blocks
if "firewall" in backend_options:
backend_options["ingress"] = [backend_options.pop("firewall")] # type: ignore
# convert the list of ingress rules to the FirewallSpec expected server-side
if "ingress" in backend_options:
fw_spec = {"ingress": backend_options.pop("ingress")}
backend_options["firewall_spec"] = fw_spec # type: ignore
validate_backend_options(backend_options)
data["options"] = backend_options
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{account}/",
json=data,
)
response_json = await response.json()
if response.status >= 400:
from .widgets import EXECUTION_CONTEXT
if response_json.get("code") == "NO_CLOUD_SETUP":
server_error_message = response_json.get("message")
error_message = f"{server_error_message} or by running `coiled setup`"
if EXECUTION_CONTEXT == "terminal":
# maybe not interactive so just raise
raise ClusterCreationError(error_message)
else:
# interactive session so let's try running the cloud setup wizard
if await do_setup_wizard():
                        # the user set up their cloud backend, so let's try creating the cluster again!
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{account}/",
json=data,
)
if response.status >= 400:
await handle_api_exception(response) # always raises exception, no return
response_json = await response.json()
else:
raise ClusterCreationError(error_message)
else:
error_class = PermissionsError if response_json.get("code") == PermissionsError.code else ServerError
if "message" in response_json:
raise error_class(response_json["message"])
if "detail" in response_json:
raise error_class(response_json["detail"])
raise error_class(response_json)
return response_json["id"], response_json["existing"]
@overload
def create_cluster(
self: Cloud[Sync],
name: str,
*,
software: str | None = None,
worker_class: str | None = None,
worker_options: dict | None = None,
scheduler_options: dict | None = None,
account: str | None = None,
workspace: str | None = None,
workers: int = 0,
environ: Dict | None = None,
tags: Dict | None = None,
dask_config: Dict | None = None,
private_to_creator: bool | None = None,
scheduler_vm_types: list | None = None,
worker_gpu_type: str | None = None,
worker_vm_types: list | None = None,
worker_disk_size: int | None = None,
worker_disk_throughput: int | None = None,
scheduler_disk_size: int | None = None,
backend_options: Union[dict, AWSOptions, GCPOptions] | None = None,
) -> Tuple[int, bool]: ...
@overload
def create_cluster(
self: Cloud[Async],
name: str,
*,
software: str | None = None,
worker_class: str | None = None,
worker_options: dict | None = None,
scheduler_options: dict | None = None,
account: str | None = None,
workspace: str | None = None,
workers: int = 0,
environ: Dict | None = None,
tags: Dict | None = None,
dask_config: Dict | None = None,
private_to_creator: bool | None = None,
scheduler_vm_types: list | None = None,
worker_gpu_type: str | None = None,
worker_vm_types: list | None = None,
worker_disk_size: int | None = None,
worker_disk_throughput: int | None = None,
scheduler_disk_size: int | None = None,
backend_options: Union[dict, AWSOptions, GCPOptions] | None = None,
) -> Awaitable[Tuple[int, bool]]: ...
def create_cluster(
self,
name: str,
*,
software: str | None = None,
worker_class: str | None = None,
worker_options: dict | None = None,
scheduler_options: dict | None = None,
account: str | None = None,
workspace: str | None = None,
workers: int = 0,
environ: Dict | None = None,
tags: Dict | None = None,
private_to_creator: bool | None = None,
dask_config: Dict | None = None,
scheduler_vm_types: list | None = None,
worker_gpu_type: str | None = None,
worker_vm_types: list | None = None,
worker_disk_size: int | None = None,
worker_disk_throughput: int | None = None,
scheduler_disk_size: int | None = None,
backend_options: Union[dict, AWSOptions, GCPOptions] | None = None,
) -> Union[Tuple[int, bool], Awaitable[Tuple[int, bool]]]:
return self._sync(
self._create_cluster,
name=name,
software_environment=software,
worker_class=worker_class,
worker_options=worker_options,
scheduler_options=scheduler_options,
workspace=workspace or account,
workers=workers,
environ=environ,
tags=tags,
dask_config=dask_config,
private_to_creator=private_to_creator,
scheduler_vm_types=scheduler_vm_types,
worker_vm_types=worker_vm_types,
gcp_worker_gpu_type=worker_gpu_type,
worker_disk_size=worker_disk_size,
worker_disk_throughput=worker_disk_throughput,
scheduler_disk_size=scheduler_disk_size,
backend_options=backend_options,
)
@track_context
async def _delete_cluster(
self, cluster_id: int, workspace: str | None = None, reason: str | None = None, pause: bool = False
) -> None:
workspace = workspace or self.default_workspace
route = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}"
params = {}
if reason:
params["reason"] = reason[:6000] # reason is sometimes long, we need to keep URL length under 8192 bytes
if pause:
params["pause"] = 1
if not params:
params = None
response = await self._do_request_idempotent(
"DELETE",
self.server + route,
params=params,
)
if response.status >= 400:
await handle_api_exception(response)
else:
# multiple deletes sometimes fail if we don't await response here
await response.json()
logger.info(f"Cluster {cluster_id} deleted successfully.")
@overload
def delete_cluster(
self: Cloud[Sync],
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
reason: str | None = None,
pause: bool = False,
) -> None: ...
@overload
def delete_cluster(
self: Cloud[Async],
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
reason: str | None = None,
pause: bool = False,
) -> Awaitable[None]: ...
@delete_docstring # TODO: this docstring erroneously says "Name of cluster" when it really accepts an ID
def delete_cluster(
self,
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
reason: str | None = None,
pause: bool = False,
) -> Awaitable[None] | None:
return self._sync(
self._delete_cluster, cluster_id=cluster_id, workspace=workspace or account, reason=reason, pause=pause
)
async def _get_cluster_state(self, cluster_id: int, workspace: str | None = None) -> dict:
workspace = workspace or self.default_workspace
        # Make the request directly instead of using `_do_request` because we don't want any retries.
        # Retry logic doesn't make sense here because this is called by a (frequent) periodic callback;
        # we'll just wait for the next callback, since retries would overlap with it and build up.
session = self._ensure_session()
response = await session.request(
"GET", self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/state"
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
async def _get_cluster_details(self, cluster_id: int, workspace: str | None = None):
workspace = workspace or self.default_workspace
r = await self._do_request_idempotent(
"GET", self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}"
)
if r.status >= 400:
await handle_api_exception(r)
return await r.json()
def _get_cluster_details_synced(self, cluster_id: int, workspace: str | None = None):
return self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
workspace=workspace,
)
def _cluster_grafana_url(self, cluster_id: int, workspace: str | None = None):
"""for internal Coiled use"""
workspace = workspace or self.default_workspace
details = self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
workspace=workspace,
)
return get_grafana_url(details, account=workspace, cluster_id=cluster_id)
def cluster_details(self, cluster_id: int, account: str | None = None, workspace: str | None = None):
details = self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
workspace=workspace or account,
)
state_keys = ["state", "reason", "updated"]
def get_state(state: dict):
return {k: v for k, v in state.items() if k in state_keys}
def get_instance(instance):
if instance is None:
return None
else:
return {
"id": instance["id"],
"created": instance["created"],
"name": instance["name"],
"public_ip_address": instance["public_ip_address"],
"private_ip_address": instance["private_ip_address"],
"current_state": get_state(instance["current_state"]),
}
def get_process(process: dict):
if process is None:
return None
else:
return {
"created": process["created"],
"name": process["name"],
"current_state": get_state(process["current_state"]),
"instance": get_instance(process["instance"]),
}
return {
"id": details["id"],
"name": details["name"],
"workers": [get_process(w) for w in details["workers"]],
"scheduler": get_process(details["scheduler"]),
"current_state": get_state(details["current_state"]),
"created": details["created"],
}
async def _get_workers_page(self, cluster_id: int, page: int, workspace: str | None = None) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
response = await self._do_request(
"GET",
self.server + f"/api/v2/workers/account/{workspace}/cluster/{cluster_id}/",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@track_context
async def _get_worker_names(
self,
workspace: str,
cluster_id: int,
statuses: List[ProcessStateEnum] | None = None,
) -> Set[str]:
worker_infos = await self._depaginate_list(self._get_workers_page, cluster_id=cluster_id, workspace=workspace)
logger.debug(f"workers: {worker_infos}")
return {w["name"] for w in worker_infos if statuses is None or w["current_state"]["state"] in statuses}
@track_context
async def _security(self, cluster_id: int, workspace: str | None = None, client_wants_public_ip: bool = True):
cluster_info = await self._get_cluster_details(cluster_id=cluster_id, workspace=workspace)
if ProcessStateEnum(cluster_info["scheduler"]["current_state"]["state"]) != ProcessStateEnum.started:
scheduler_state = cluster_info["scheduler"]["current_state"]["state"]
raise RuntimeError(
f"Cannot get security info for cluster {cluster_id}, scheduler state is {scheduler_state}"
)
public_ip = cluster_info["scheduler"]["instance"]["public_ip_address"]
private_ip = cluster_info["scheduler"]["instance"]["private_ip_address"]
tls_cert = cluster_info["cluster_options"]["tls_cert"]
tls_key = cluster_info["cluster_options"]["tls_key"]
scheduler_port = cluster_info["scheduler_port"]
dashboard_address = cluster_info["scheduler"]["dashboard_address"]
give_scheduler_public_ip = cluster_info["cluster_infra"]["give_scheduler_public_ip"]
private_address = f"tls://{private_ip}:{scheduler_port}"
public_address = f"tls://{public_ip}:{scheduler_port}"
use_public_address = give_scheduler_public_ip and client_wants_public_ip
if use_public_address:
if not public_ip:
raise RuntimeError(
"Your Coiled client is configured to use the public IP address, but the scheduler VM does not "
"have a public IP address.\n\n"
"If you're expecting to connect on private IP address, you can run\n"
" coiled config set coiled.use_scheduler_public_ip False\n"
"to configure your local Client to use the private IP address, "
"or contact support@coiled.io if you'd like help."
)
address_to_use = public_address
else:
address_to_use = private_address
logger.info(f"Connecting to scheduler on its internal address: {address_to_use}")
# TODO (Declarative): pass extra_conn_args if we care about proxying through Coiled to the scheduler
security = GatewaySecurity(tls_key, tls_cert)
return security, {
"address_to_use": address_to_use,
"private_address": private_address,
"public_address": public_address,
"dashboard_address": dashboard_address,
}
@track_context
async def _requested_workers(self, cluster_id: int, account: str | None = None) -> Set[str]:
raise NotImplementedError("TODO")
@track_context
async def _get_cluster_by_name(self, name: str, workspace: str | None = None, include_paused: bool = False) -> int:
workspace, name = self._normalize_name(
name, context_workspace=workspace, allow_uppercase=True, property_name="cluster name"
)
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/name/{name}",
params={"include_paused": int(include_paused)},
)
if response.status == 404:
raise DoesNotExist
elif response.status >= 400:
await handle_api_exception(response)
cluster = await response.json()
return cluster["id"]
@overload
def get_cluster_by_name(
self: Cloud[Sync],
name: str,
account: str | None = None,
workspace: str | None = None,
include_paused: bool = False,
) -> int: ...
@overload
def get_cluster_by_name(
self: Cloud[Async],
name: str,
account: str | None = None,
workspace: str | None = None,
include_paused: bool = False,
) -> Awaitable[int]: ...
def get_cluster_by_name(
self,
name: str,
account: str | None = None,
workspace: str | None = None,
include_paused: bool = False,
) -> Union[int, Awaitable[int]]:
return self._sync(
self._get_cluster_by_name,
name=name,
workspace=workspace or account,
include_paused=include_paused,
)
@track_context
async def _cluster_status(
self,
cluster_id: int,
account: str | None = None,
exclude_stopped: bool = True,
) -> dict:
raise NotImplementedError("TODO?")
@track_context
async def _get_cluster_states_declarative(
self,
cluster_id: int,
workspace: str | None = None,
start_time: datetime.datetime | None = None,
) -> dict:
workspace = workspace or self.default_workspace
params = {"start_time": start_time.isoformat()} if start_time is not None else {}
response = await self._do_request_idempotent(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/states",
params=params,
)
        # A 403 from this endpoint is most likely temporary, so swallow it, unless we've never gotten
        # a good response or it's been too long since the last good response from the endpoint.
if (
response.status == 403
and time.monotonic() - getattr(self, "_get_cluster_states_declarative_last_good_response", 0) < 60
):
return {}
elif response.status >= 400:
await handle_api_exception(response)
self._get_cluster_states_declarative_last_good_response = time.monotonic()
return await response.json()
def get_cluster_states(
self,
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
start_time: datetime.datetime | None = None,
) -> Union[dict, Awaitable[dict]]:
return self._sync(
self._get_cluster_states_declarative,
cluster_id=cluster_id,
workspace=workspace or account,
start_time=start_time,
)
def get_clusters_by_name(
self,
name: str,
account: str | None = None,
workspace: str | None = None,
) -> List[dict]:
"""Get all clusters matching name."""
return self._sync(
self._get_clusters_by_name,
name=name,
workspace=workspace or account,
)
@track_context
async def _get_clusters_by_name(self, name: str, workspace: str | None = None) -> List[dict]:
workspace, name = self._normalize_name(name, context_workspace=workspace, allow_uppercase=True)
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/",
params={"name": name},
)
if response.status == 404:
raise DoesNotExist
elif response.status >= 400:
await handle_api_exception(response)
cluster = await response.json()
return cluster
@overload
def cluster_logs(
self: Cloud[Sync],
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Logs: ...
@overload
def cluster_logs(
self: Cloud[Async],
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Awaitable[Logs]: ...
@track_context
async def _cluster_logs(
self,
cluster_id: int,
workspace: str | None = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Logs:
def is_errored(process):
process_state, instance_state = get_process_instance_state(process)
return process_state == ProcessStateEnum.error or instance_state == InstanceStateEnum.error
workspace = workspace or self.default_workspace
# hits endpoint in order to get scheduler and worker instance names
cluster_info = await self._get_cluster_details(cluster_id=cluster_id, workspace=workspace)
try:
scheduler_name = cluster_info["scheduler"]["instance"]["name"]
except (TypeError, KeyError):
# no scheduler instance name in cluster info
logger.warning("No scheduler found when attempting to retrieve cluster logs.")
scheduler_name = None
worker_names = [
worker["instance"]["name"]
for worker in cluster_info["workers"]
if worker["instance"] and (not errors_only or is_errored(worker))
]
LabeledInstance = namedtuple("LabeledInstance", ("name", "label"))
instances = []
if scheduler and scheduler_name and (not errors_only or is_errored(cluster_info["scheduler"])):
instances.append(LabeledInstance(scheduler_name, "Scheduler"))
if workers and worker_names:
instances.extend([LabeledInstance(worker_name, worker_name) for worker_name in worker_names])
        async def instance_log_with_semaphore(semaphore, **kwargs):
            async with semaphore:
                return await self._instance_logs(**kwargs)

        # only fetch 100 logs at a time; the limit here is redundant since the aiohttp session already limits
        # concurrent connections, but let's be safe just in case
        semaphore = asyncio.Semaphore(value=100)
        results = await asyncio.gather(*[
            instance_log_with_semaphore(semaphore=semaphore, workspace=workspace, instance_name=inst.name)
            for inst in instances
        ])
out = {
instance_label: instance_log
for (_, instance_label), instance_log in zip(instances, results)
if len(instance_log)
}
return Logs(out)
def cluster_logs(
self,
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Union[Logs, Awaitable[Logs]]:
return self._sync(
self._cluster_logs,
cluster_id=cluster_id,
workspace=workspace or account,
scheduler=scheduler,
workers=workers,
errors_only=errors_only,
)
async def _instance_logs(self, workspace: str, instance_name: str, safe=True) -> Log:
response = await self._do_request(
"GET",
self.server + "/api/v2/instances/{}/instance/{}/logs".format(workspace, instance_name),
)
if response.status >= 400:
if safe:
logger.warning(f"Error retrieving logs for {instance_name}")
return Log()
await handle_api_exception(response)
data = await response.json()
messages = "\n".join(logline.get("message", "") for logline in data)
return Log(messages)
    @overload
    def requested_workers(
        self: Cloud[Sync], cluster_id: int, account: str | None = None, workspace: str | None = None
    ) -> Set[str]: ...
    @overload
    def requested_workers(
        self: Cloud[Async], cluster_id: int, account: str | None = None, workspace: str | None = None
    ) -> Awaitable[Set[str]]: ...
def requested_workers(
self, cluster_id: int, account: str | None = None, workspace: str | None = None
) -> Union[
Set[str],
Awaitable[Set[str]],
]:
return self._sync(self._requested_workers, cluster_id, workspace or account)
@overload
def scale_up(
self: Cloud[Sync], cluster_id: int, n: int, account: str | None = None, workspace: str | None = None
) -> Dict | None: ...
@overload
def scale_up(
self: Cloud[Async], cluster_id: int, n: int, account: str | None = None, workspace: str | None = None
) -> Awaitable[Dict | None]: ...
def scale_up(
self, cluster_id: int, n: int, account: str | None = None, workspace: str | None = None
) -> Union[Dict | None, Awaitable[Dict | None]]:
"""Scale cluster to ``n`` workers
Parameters
----------
cluster_id
Unique cluster identifier.
n
Number of workers to scale cluster size to.
account
**DEPRECATED**. Use ``workspace`` instead.
workspace
The Coiled workspace (previously "account") to use. If not specified,
will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
or will use your default workspace if those aren't set.
"""
return self._sync(self._scale_up, cluster_id, n, workspace or account)
@overload
def scale_down(
self: Cloud[Sync],
cluster_id: int,
workers: Set[str],
account: str | None = None,
workspace: str | None = None,
) -> None: ...
@overload
def scale_down(
self: Cloud[Async],
cluster_id: int,
workers: Set[str],
account: str | None = None,
workspace: str | None = None,
) -> Awaitable[None]: ...
def scale_down(
self,
cluster_id: int,
workers: Set[str],
account: str | None = None,
workspace: str | None = None,
) -> Awaitable[None] | None:
"""Scale cluster to ``n`` workers
Parameters
----------
cluster_id
Unique cluster identifier.
workers
Set of workers to scale down to.
account
**DEPRECATED**. Use ``workspace`` instead.
workspace
The Coiled workspace (previously "account") to use. If not specified,
will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
or will use your default workspace if those aren't set.
"""
return self._sync(self._scale_down, cluster_id, workers, workspace or account)
@track_context
async def _better_cluster_logs(
self,
cluster_id: int,
workspace: str | None = None,
instance_ids: List[int] | None = None,
dask: bool = False,
system: bool = False,
since_ms: int | None = None,
until_ms: int | None = None,
filter: str | None = None,
):
workspace = workspace or self.default_workspace
url_params = {}
if dask:
url_params["dask"] = "True"
if system:
url_params["system"] = "True"
if since_ms:
url_params["since_ms"] = f"{since_ms}"
if until_ms:
url_params["until_ms"] = f"{until_ms}"
if filter:
url_params["filter_pattern"] = f"{filter}"
if instance_ids:
url_params["instance_ids"] = ",".join(map(str, instance_ids))
url_path = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/better-logs"
response = await self._do_request(
"GET",
f"{self.server}{url_path}",
params=url_params,
)
if response.status >= 400:
await handle_api_exception(response)
data = await response.json()
return data
def better_cluster_logs(
self,
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
instance_ids: List[int] | None = None,
dask: bool = False,
system: bool = False,
since_ms: int | None = None,
until_ms: int | None = None,
filter: str | None = None,
) -> Logs:
return self._sync(
self._better_cluster_logs,
cluster_id=cluster_id,
workspace=workspace or account,
instance_ids=instance_ids,
dask=dask,
system=system,
since_ms=since_ms,
until_ms=until_ms,
filter=filter,
)
@track_context
async def _scale_up(self, cluster_id: int, n: int, workspace: str | None = None, reason: str | None = None) -> Dict:
"""
Increases the number of workers by ``n``.
"""
workspace = workspace or self.default_workspace
data = {"n_workers": n}
if reason:
# pyright is annoying
data["reason"] = reason # type: ignore
response = await self._do_request(
"POST", f"{self.server}/api/v2/workers/account/{workspace}/cluster/{cluster_id}/", json=data
)
if response.status >= 400:
await handle_api_exception(response)
workers_info = await response.json()
return {"workers": {w["name"] for w in workers_info}}
@track_context
async def _scale_down(
self, cluster_id: int, workers: Iterable[str], workspace: str | None = None, reason: str | None = None
) -> None:
workspace = workspace or self.default_workspace
workers = list(workers) # yarl, used by aiohttp, expects list (not set) of strings
reason_dict = {"reason": reason} if reason else {}
response = await self._do_request(
"DELETE",
f"{self.server}/api/v2/workers/account/{workspace}/cluster/{cluster_id}/",
params={"name": workers, **reason_dict},
)
if response.status >= 400:
await handle_api_exception(response)
@overload
def security(
self: Cloud[Sync], cluster_id: int, account: str | None = None, workspace: str | None = None
) -> Tuple[dask.distributed.Security, dict]: ...
@overload
def security(
self: Cloud[Async], cluster_id: int, account: str | None = None, workspace: str | None = None
) -> Awaitable[Tuple[dask.distributed.Security, dict]]: ...
def security(
self, cluster_id: int, account: str | None = None, workspace: str | None = None
) -> Union[
Tuple[dask.distributed.Security, dict],
Awaitable[Tuple[dask.distributed.Security, dict]],
]:
return self._sync(self._security, cluster_id, workspace or account)
@track_context
async def _fetch_package_levels(self, workspace: str | None = None) -> List[PackageLevel]:
workspace = workspace or self.default_workspace
response = await self._do_request("GET", f"{self.server}/api/v2/packages/", params={"account": workspace})
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_ssh_key(
self,
cluster_id: int,
workspace: str | None = None,
worker: str | None = None,
) -> dict:
workspace = workspace or self.default_workspace
return self._sync(
self._get_ssh_key,
cluster_id=cluster_id,
workspace=workspace,
worker=worker,
)
@track_context
async def _get_ssh_key(self, cluster_id: int, workspace: str, worker: str | None) -> dict:
workspace = workspace or self.default_workspace
route = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/ssh-key"
url = f"{self.server}{route}"
response = await self._do_request("GET", url, params={"worker": worker} if worker else None)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_cluster_log_info(
self,
cluster_id: int,
workspace: str | None = None,
) -> dict:
workspace = workspace or self.default_workspace
return self._sync(
self._get_cluster_log_info,
cluster_id=cluster_id,
workspace=workspace,
)
@track_context
async def _get_cluster_log_info(
self,
cluster_id: int,
workspace: str,
) -> dict:
workspace = workspace or self.default_workspace
route = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/log-info"
url = f"{self.server}{route}"
response = await self._do_request("GET", url)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
@track_context
async def _get_cluster_aggregated_metric(
self,
cluster_id: int,
workspace: str | None,
query: str,
over_time: str,
start_ts: int | None,
end_ts: int | None,
):
workspace = workspace or self.default_workspace
route = f"/api/v2/metrics/account/{workspace}/cluster/{cluster_id}/value"
url = f"{self.server}{route}"
params = {"query": query, "over_time": over_time}
if start_ts:
params["start_ts"] = str(start_ts)
if end_ts:
params["end_ts"] = str(end_ts)
response = await self._do_request("GET", url, params=params)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
@track_context
async def _add_cluster_span(self, cluster_id: int, workspace: str | None, span_identifier: str, data: dict):
workspace = workspace or self.default_workspace
route = f"/api/v2/analytics/{workspace}/cluster/{cluster_id}/span/{span_identifier}"
url = f"{self.server}{route}"
response = await self._do_request("POST", url, json=data)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def _sync_request(self, route, method: str = "GET", json_result: bool = False, **kwargs):
url = f"{self.server}{route}"
response = self._sync(self._do_request, url=url, method=method, **kwargs)
if response.status >= 400:
raise ServerError(f"{url} returned {response.status}")
async def get_result(r):
return await (r.json() if json_result else r.text())
return self._sync(
get_result,
response,
)
Cloud = CloudV2
def cluster_logs(
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
):
"""
    Returns cluster logs as a dictionary, with a key for the scheduler and each worker.

    .. versionchanged:: 0.2.0
       ``cluster_name`` is no longer accepted, use ``cluster_id`` instead.
"""
with CloudV2() as cloud:
return cloud.cluster_logs(
cluster_id=cluster_id,
workspace=workspace or account,
scheduler=scheduler,
workers=workers,
errors_only=errors_only,
)
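# Example (a sketch; the cluster ID is hypothetical):
#   logs = cluster_logs(cluster_id=1234)
#   print(logs["Scheduler"])  # keys are "Scheduler" and the worker instance names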
def better_cluster_logs(
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
instance_ids: List[int] | None = None,
dask: bool = False,
system: bool = False,
since_ms: int | None = None,
until_ms: int | None = None,
filter: str | None = None,
):
"""
    Pull logs for the cluster using the "better-logs" endpoint.

    Logs for recent clusters are split between system and container (dask) logs; you can
    request either or both (or neither).

    ``since_ms`` and ``until_ms`` are both inclusive (you'll get logs whose timestamps match those values).
"""
with Cloud() as cloud:
return cloud.better_cluster_logs(
cluster_id=cluster_id,
workspace=workspace or account,
instance_ids=instance_ids,
dask=dask,
system=system,
since_ms=since_ms,
until_ms=until_ms,
filter=filter,
)
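# Example (a sketch; the cluster ID is hypothetical): pull the last five minutes of dask logs
#   import time
#   now_ms = int(time.time() * 1000)
#   logs = better_cluster_logs(cluster_id=1234, dask=True, since_ms=now_ms - 5 * 60 * 1000)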
def cluster_details(
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
) -> dict:
"""
Get details of a cluster as a dictionary.
"""
with CloudV2() as cloud:
return cloud.cluster_details(
cluster_id=cluster_id,
workspace=workspace or account,
)
def log_cluster_debug_info(
cluster_id: int,
account: str | None = None,
workspace: str | None = None,
):
with CloudV2() as cloud:
details = cloud.cluster_details(cluster_id, workspace or account)
logger.debug("Cluster details:")
logger.debug(json.dumps(details, indent=2))
states_by_type = cloud.get_cluster_states(cluster_id, workspace or account)
logger.debug("cluster state history:")
log_states(flatten_log_states(states_by_type), level=logging.DEBUG)
def create_cluster(
name: str,
*,
software: str | None = None,
worker_options: dict | None = None,
scheduler_options: dict | None = None,
account: str | None = None,
workspace: str | None = None,
workers: int = 0,
environ: Dict | None = None,
tags: Dict | None = None,
dask_config: Dict | None = None,
private_to_creator: bool | None = None,
scheduler_vm_types: list | None = None,
worker_vm_types: list | None = None,
worker_disk_size: int | None = None,
worker_disk_throughput: int | None = None,
scheduler_disk_size: int | None = None,
backend_options: Union[dict, AWSOptions, GCPOptions] | None = None,
) -> int:
"""Create a cluster
Parameters
---------
name
Name of cluster.
software
Identifier of the software environment to use, in the format (<account>/)<name>. If the software environment
is owned by the same account as that passed into "account", the (<account>/) prefix is optional.
For example, suppose your account is "wondercorp", but your friends at "friendlycorp" have an environment
named "xgboost" that you want to use; you can specify this with "friendlycorp/xgboost". If you simply
entered "xgboost", this is shorthand for "wondercorp/xgboost".
The "name" portion of (<account>/)<name> can only contain ASCII letters, hyphens and underscores.
worker_options
Mapping with keyword arguments to pass to ``worker_class``. Defaults to ``{}``.
scheduler_options
Mapping with keyword arguments to pass to the Scheduler ``__init__``. Defaults to ``{}``.
account
**DEPRECATED**. Use ``workspace`` instead.
workspace
The Coiled workspace (previously "account") to use. If not specified,
will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
or will use your default workspace if those aren't set.
workers
        Number of workers to launch.
environ
Dictionary of environment variables.
tags
Dictionary of tags. Can also be set using the ``coiled.tags``
Dask configuration option. Tags specified for cluster using keyword argument
take precedence over those from Dask configuration.
dask_config
Dictionary of dask config to put on cluster
See Also
--------
coiled.Cluster
"""
with CloudV2(account=workspace or account) as cloud:
cluster, _existing = cloud.create_cluster(
name=name,
software=software,
worker_options=worker_options,
scheduler_options=scheduler_options,
workspace=workspace or account,
workers=workers,
environ=environ,
tags=tags,
dask_config=dask_config,
private_to_creator=private_to_creator,
backend_options=backend_options,
worker_vm_types=worker_vm_types,
worker_disk_size=worker_disk_size,
worker_disk_throughput=worker_disk_throughput,
scheduler_disk_size=scheduler_disk_size,
scheduler_vm_types=scheduler_vm_types,
)
return cluster
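# Example (a sketch): create a two-worker cluster in your default workspace
#   cluster_id = create_cluster("my-cluster", workers=2)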
@list_docstring
def list_clusters(account=None, workspace=None, max_pages: int | None = None):
with CloudV2() as cloud:
return cloud.list_clusters(workspace=workspace or account, max_pages=max_pages)
@delete_docstring
def delete_cluster(name: str, account: str | None = None, workspace: str | None = None, pause: bool = False):
with CloudV2() as cloud:
cluster_id = cloud.get_cluster_by_name(name=name, workspace=workspace or account)
if cluster_id is not None:
return cloud.delete_cluster(cluster_id=cluster_id, workspace=workspace or account, pause=pause)
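# Example (a sketch): pass pause=True to pause a cluster instead of fully deleting it
#   delete_cluster("my-cluster", pause=True)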
def create_package_sync_software_env(
workspace=None, gpu=False, arm=False, strict=False, force_rich_widget=False, **kwargs
):
from coiled.capture_environment import scan_and_create
with Cloud(workspace=workspace) as cloud:
package_sync_env_alias = cloud._sync(
scan_and_create,
cloud=cloud,
workspace=workspace,
gpu_enabled=gpu,
architecture=ArchitectureTypesEnum.ARM64 if arm else ArchitectureTypesEnum.X86_64,
package_sync_strict=strict,
force_rich_widget=force_rich_widget,
**kwargs,
)
return package_sync_env_alias
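# Example (a sketch): build a GPU-enabled ARM software environment from the local Python environment
#   alias = create_package_sync_software_env(gpu=True, arm=True)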
def get_dask_client_from_batch_node():
"""
Get Dask client for a Coiled Batch cluster.
This function can be run on a node of a Coiled Batch cluster,
and returns a dask.distributed.Client object connected to the Dask scheduler running
on the scheduler node of the Batch cluster (which can also be a Dask cluster).
"""
import os
import dask.config
from dask.distributed import Client
scheduler_address = os.environ.get("COILED_INTERNAL_DASK_SCHEDULER_ADDRESS")
if not scheduler_address:
raise ValueError(
"Unable to get scheduler address from COILED_INTERNAL_DASK_SCHEDULER_ADDRESS environment variable, "
"this should automatically be set on Coiled Batch nodes."
)
security_config = {
"scheduler-address": scheduler_address,
"distributed.comm.tls.client.key": "/dask-tls/key",
"distributed.comm.tls.client.cert": "/dask-tls/cert",
"distributed.comm.tls.ca_file": "/dask-tls/cert",
"distributed.comm.require-encryption": True,
}
with dask.config.set(security_config):
return Client()
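# Example (a sketch, meant to run on a Coiled Batch node, where the scheduler address
# environment variable is set automatically):
#   client = get_dask_client_from_batch_node()
#   print(client.scheduler_info())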