from __future__ import annotations
import asyncio
import datetime
import json
import logging
import weakref
from collections import namedtuple
from typing import (
Awaitable,
Callable,
Dict,
Generic,
Iterable,
List,
NoReturn,
Optional,
Set,
Tuple,
Union,
overload,
)
import dask.config
import dask.distributed
from aiohttp import ContentTypeError
from dask.utils import parse_timedelta
from distributed.utils import Log, Logs
from rich.progress import Progress
from typing_extensions import TypeAlias
from coiled.cli.setup.entry import do_setup_wizard
from coiled.context import track_context
from coiled.core import Async, IsAsynchronous, Sync, delete_docstring, list_docstring
from coiled.core import Cloud as OldCloud
from coiled.errors import ClusterCreationError, DoesNotExist, ServerError
from coiled.exceptions import PermissionsError
from coiled.types import (
ArchitectureTypesEnum,
AWSOptions,
GCPOptions,
PackageLevel,
PackageSchema,
ResolvedPackageInfo,
SoftwareEnvironmentAlias,
)
from coiled.utils import (
COILED_LOGGER_NAME,
GatewaySecurity,
get_grafana_url,
validate_backend_options,
)
from .states import (
InstanceStateEnum,
ProcessStateEnum,
flatten_log_states,
get_process_instance_state,
log_states,
)
from .widgets.util import simple_progress
logger = logging.getLogger(COILED_LOGGER_NAME)
def setup_logging(level=logging.INFO):
# We want to be able to give info-level messages to users.
# For users who haven't set up a log handler, this requires creating one (b/c the handler of "last resort,
# logging.lastResort, has a level of "warning".
#
# Conservatively, we only do anything here if the user hasn't set up any log handlers on the root logger
# or the Coiled logger. If they have any handler, we assume logging is configured how they want it.
coiled_logger = logging.getLogger(COILED_LOGGER_NAME)
root_logger = logging.getLogger()
if coiled_logger.handlers == [] and root_logger.handlers == []:
stream_handler = logging.StreamHandler()
stream_handler.setFormatter(logging.Formatter(fmt="[%(asctime)s][%(levelname)-8s][%(name)s] %(message)s"))
# Conservatively, only change the Coiled logger level there's no log level specified yet.
if coiled_logger.level == 0:
coiled_logger.setLevel(level)
coiled_logger.addHandler(stream_handler)
async def handle_api_exception(response, exception_cls=ServerError) -> NoReturn:
try:
error_body = await response.json()
except ContentTypeError:
raise exception_cls(
f"Unexpected status code ({response.status}) to {response.method}:{response.url}, contact support@coiled.io"
) from None
if error_body.get("code") == PermissionsError.code:
exception_cls = PermissionsError
if "message" in error_body:
raise exception_cls(error_body["message"])
if "detail" in error_body:
raise exception_cls(error_body["detail"])
raise exception_cls(error_body)
CloudV2SyncAsync: TypeAlias = Union["CloudV2[Async]", "CloudV2[Sync]"]
class CloudV2(OldCloud, Generic[IsAsynchronous]):
_recent_sync: List[weakref.ReferenceType[CloudV2[Sync]]] = list()
_recent_async: List[weakref.ReferenceType[CloudV2[Async]]] = list()
# just overriding to get the right signature (CloudV2, not Cloud)
def __enter__(self: CloudV2[Sync]) -> CloudV2[Sync]:
return self
def __exit__(self: CloudV2[Sync], typ, value, tb) -> None:
self.close()
async def __aenter__(self: CloudV2[Async]) -> CloudV2[Async]:
return await self._start()
async def __aexit__(self: CloudV2[Async], typ, value, tb) -> None:
await self._close()
# these overloads are necessary for the typechecker to know that we really have a CloudV2, not a Cloud
# without them, CloudV2.current would be typed to return a Cloud
#
# https://www.python.org/dev/peps/pep-0673/ would remove the need for this.
# That PEP also mentions a workaround with type vars, which doesn't work for us because type vars aren't
# subscribtable
@overload
@classmethod
def current(cls, asynchronous: Sync) -> CloudV2[Sync]: ...
@overload
@classmethod
def current(cls, asynchronous: Async) -> CloudV2[Async]: ...
@overload
@classmethod
def current(cls, asynchronous: bool) -> CloudV2: ...
@classmethod
def current(cls, asynchronous: bool) -> CloudV2:
recent: List[weakref.ReferenceType[CloudV2]]
if asynchronous:
recent = cls._recent_async
else:
recent = cls._recent_sync
try:
cloud = recent[-1]()
while cloud is None or cloud.status != "running":
recent.pop()
cloud = recent[-1]()
except IndexError:
if asynchronous:
return cls(asynchronous=True)
else:
return cls(asynchronous=False)
else:
return cloud
@track_context
async def _get_default_instance_types(self, provider: str, gpu: bool = False, arch: str = "x86_64") -> List[str]:
if arch not in ("arm64", "x86_64"):
raise ValueError(f"arch '{arch}' is not supported for default instance types")
if provider == "aws":
if arch == "arm64":
if gpu:
return ["g5g.xlarge"] # has NVIDIA T4G
else:
return ["m7g.xlarge", "m6g.xlarge"]
if gpu:
return ["g4dn.xlarge"]
else:
return ["m6i.xlarge", "m5.xlarge"]
elif provider == "gcp":
if arch != "x86_64":
return ["t2a-standard-4"]
if gpu:
# n1-standard-8 with 30GB of memory might be best, but that's big for a default
return ["n1-standard-4"]
else:
return ["e2-standard-4"]
elif provider == "azure":
if arch != "x86_64":
raise ValueError(f"no default instance type for Azure with {arch} architecture")
if gpu:
raise ValueError("no default GPU instance type for Azure")
return ["Standard_D4_v5"]
else:
raise ValueError(f"unexpected provider {provider}; cannot determine default instance types")
async def _list_dask_scheduler_page(
self,
page: int,
workspace: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
kwargs = {}
if since:
kwargs["since"] = parse_timedelta(since)
if user:
kwargs["user"] = user
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{workspace}/clusters/list",
params={
"limit": page_size,
"offset": page_size * page,
**kwargs,
},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@track_context
async def _list_dask_scheduler(
self,
workspace: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = None,
):
return await self._depaginate_list(
self._list_dask_scheduler_page,
workspace=workspace,
since=since,
user=user,
)
@overload
def list_dask_scheduler(
self: Cloud[Sync],
account: Optional[str] = None,
workspace: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = None,
) -> list: ...
@overload
def list_dask_scheduler(
self: Cloud[Async],
account: Optional[str] = None,
workspace: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = "",
) -> Awaitable[list]: ...
def list_dask_scheduler(
self,
account: Optional[str] = None,
workspace: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = "",
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_dask_scheduler, workspace or account, since=since, user=user)
async def _list_computations(
self, cluster_id: Optional[int] = None, scheduler_id: Optional[int] = None, workspace: Optional[str] = None
):
return await self._depaginate_list(
self._list_computations_page, cluster_id=cluster_id, scheduler_id=scheduler_id, workspace=workspace
)
async def _list_computations_page(
self,
page: int,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
workspace: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
if not scheduler_id and not cluster_id:
raise ValueError("either cluster_id or scheduler_id must be specified")
api = (
f"/api/v2/analytics/{workspace}/{scheduler_id}/computations/list"
if scheduler_id
else f"/api/v2/analytics/{workspace}/cluster/{cluster_id}/computations/list"
)
response = await self._do_request(
"GET",
self.server + api,
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@overload
def list_computations(
self: Cloud[Sync],
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> list: ...
@overload
def list_computations(
self: Cloud[Async],
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Awaitable[list]: ...
def list_computations(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Union[list, Awaitable[list]]:
return self._sync(
self._list_computations, cluster_id=cluster_id, scheduler_id=scheduler_id, workspace=workspace or account
)
def list_exceptions(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
) -> Union[list, Awaitable[list]]:
return self._sync(
self._list_exceptions,
cluster_id=cluster_id,
scheduler_id=scheduler_id,
workspace=workspace or account,
since=since,
user=user,
)
async def _list_exceptions(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
workspace: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
):
return await self._depaginate_list(
self._list_exceptions_page,
cluster_id=cluster_id,
scheduler_id=scheduler_id,
workspace=workspace,
since=since,
user=user,
)
async def _list_exceptions_page(
self,
page: int,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
workspace: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
kwargs = {}
if since:
kwargs["since"] = parse_timedelta(since)
if user:
kwargs["user"] = user
if cluster_id:
kwargs["cluster"] = cluster_id
if scheduler_id:
kwargs["scheduler"] = scheduler_id
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{workspace}/exceptions/list",
params={"limit": page_size, "offset": page_size * page, **kwargs},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
async def _list_events_page(
self,
page: int,
cluster_id: int,
workspace: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{workspace}/{cluster_id}/events/list",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
async def _list_events(self, cluster_id: int, workspace: Optional[str] = None):
return await self._depaginate_list(self._list_events_page, cluster_id=cluster_id, workspace=workspace)
def list_events(
self,
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_events, cluster_id, workspace or account)
async def _send_state(self, cluster_id: int, desired_status: str, workspace: Optional[str] = None):
workspace = workspace or self.default_workspace
response = await self._do_request(
"POST",
self.server + f"/api/v2/analytics/{workspace}/{cluster_id}/desired-state",
json={"desired_status": desired_status},
)
if response.status >= 400:
await handle_api_exception(response)
def send_state(
self,
cluster_id: int,
desired_status: str,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Union[None, Awaitable[None]]:
return self._sync(self._send_state, cluster_id, desired_status, workspace or account)
@track_context
async def _list_clusters(
self, workspace: Optional[str] = None, max_pages: Optional[int] = None, just_mine: bool = False
):
return await self._depaginate_list(
self._list_clusters_page, workspace=workspace, max_pages=max_pages, just_mine=just_mine
)
@overload
def list_clusters(
self: Cloud[Sync],
account: Optional[str] = None,
workspace: Optional[str] = None,
max_pages: Optional[int] = None,
just_mine: bool = False,
) -> list: ...
@overload
def list_clusters(
self: Cloud[Async],
account: Optional[str] = None,
workspace: Optional[str] = None,
max_pages: Optional[int] = None,
just_mine: bool = False,
) -> Awaitable[list]: ...
@list_docstring
def list_clusters(
self,
account: Optional[str] = None,
workspace: Optional[str] = None,
max_pages: Optional[int] = None,
just_mine: bool = False,
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_clusters, workspace=workspace or account, max_pages=max_pages, just_mine=just_mine)
async def _list_clusters_page(
self, page: int, workspace: Optional[str] = None, just_mine: bool = False
) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/",
params={"limit": page_size, "offset": page_size * page, "just_mine": "1" if just_mine else "0"},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@staticmethod
async def _depaginate_list(
func: Callable[..., Awaitable[Tuple[list, bool]]],
max_pages: Optional[int] = None,
*args,
**kwargs,
) -> list:
results_all = []
page = 0
while True:
kwargs["page"] = page
results, next = await func(*args, **kwargs)
results_all += results
page += 1
if (not results) or next is None:
break
# page is the number of pages we've already fetched (since 0-indexed)
if max_pages and page >= max_pages:
break
return results_all
@track_context
async def _create_package_sync_env(
self,
packages: List[ResolvedPackageInfo],
progress: Optional[Progress] = None,
workspace: Optional[str] = None,
gpu_enabled: bool = False,
architecture: ArchitectureTypesEnum = ArchitectureTypesEnum.X86_64,
region_name: Optional[str] = None,
use_uv_installer: bool = True,
) -> SoftwareEnvironmentAlias:
workspace = workspace or self.default_workspace
prepared_packages: List[PackageSchema] = []
for pkg in packages:
if pkg["sdist"] and pkg["md5"]:
with simple_progress(f"Uploading {pkg['name']}", progress=progress):
file_id = await self._create_senv_package(
pkg["sdist"],
contents_md5=pkg["md5"],
workspace=workspace,
region_name=region_name,
)
else:
file_id = None
prepared_packages.append({
"name": pkg["name"],
"source": pkg["source"],
"channel": pkg["channel"],
"conda_name": pkg["conda_name"],
"specifier": pkg["specifier"],
"include": pkg["include"],
"client_version": pkg["client_version"],
"file": file_id,
})
with simple_progress("Requesting package sync build", progress=progress):
result = await self._create_software_environment_v2(
senv={
"packages": prepared_packages,
"raw_pip": None,
"raw_conda": None,
},
workspace=workspace,
architecture=architecture,
gpu_enabled=gpu_enabled,
region_name=region_name,
use_uv_installer=use_uv_installer,
)
return result
@track_context
async def _create_custom_certificate(self, subdomain: str, workspace: Optional[str] = None):
workspace = workspace or self.default_workspace
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{workspace}/https-certificate",
json={"subdomain": subdomain},
)
if response.status >= 400:
await handle_api_exception(response)
async def _check_custom_certificate(self, subdomain: str, workspace: Optional[str] = None):
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/https-certificate/{subdomain}",
)
if response.status >= 400:
await handle_api_exception(response)
response_json = await response.json()
cert_status = response_json.get("status")
return cert_status
@track_context
async def _create_cluster(
self,
# todo: make name optional and pick one for them, like pre-declarative?
# https://gitlab.com/coiled/cloud/-/issues/4305
name: str,
*,
software_environment: Optional[str] = None,
senv_v2_id: Optional[int] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
scheduler_options: Optional[dict] = None,
workspace: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
scheduler_vm_types: Optional[list] = None,
gcp_worker_gpu_type: Optional[str] = None,
gcp_worker_gpu_count: Optional[int] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
worker_disk_throughput: Optional[int] = None,
scheduler_disk_size: Optional[int] = None,
backend_options: Optional[Union[AWSOptions, GCPOptions, dict]] = None,
use_scheduler_public_ip: Optional[bool] = None,
use_dashboard_https: Optional[bool] = None,
private_to_creator: Optional[bool] = None,
extra_worker_on_scheduler: Optional[bool] = None,
n_worker_specs_per_host: Optional[int] = None,
custom_subdomain: Optional[str] = None,
batch_job_ids: Optional[List[int]] = None,
extra_user_container: Optional[str] = None,
) -> Tuple[int, bool]:
# TODO (Declarative): support these args, or decide not to
# https://gitlab.com/coiled/cloud/-/issues/4305
workspace = workspace or self.default_workspace
account, name = self._normalize_name(name, context_workspace=workspace, allow_uppercase=True)
await self._verify_workspace(account)
data = {
"name": name,
"workers": workers,
"worker_instance_types": worker_vm_types,
"scheduler_instance_types": scheduler_vm_types,
"worker_options": worker_options,
"worker_class": worker_class,
"worker_disk_size": worker_disk_size,
"worker_disk_throughput": worker_disk_throughput,
"scheduler_disk_size": scheduler_disk_size,
"scheduler_options": scheduler_options,
"environ": environ,
"tags": tags,
"dask_config": dask_config,
"private_to_creator": private_to_creator,
"env_id": senv_v2_id,
"env_name": software_environment,
"extra_worker_on_scheduler": extra_worker_on_scheduler,
"n_worker_specs_per_host": n_worker_specs_per_host,
"use_aws_creds_endpoint": dask.config.get("coiled.use_aws_creds_endpoint", False),
"custom_subdomain": custom_subdomain,
"batch_job_ids": batch_job_ids,
"extra_user_container": extra_user_container,
}
backend_options = backend_options if backend_options else {}
if gcp_worker_gpu_type is not None:
# for backwards compatibility with v1 options
backend_options = {
**backend_options,
"worker_accelerator_count": gcp_worker_gpu_count or 1,
"worker_accelerator_type": gcp_worker_gpu_type,
}
elif gcp_worker_gpu_count:
# not ideal but v1 only supported T4 and `worker_gpu=1` would give you one
backend_options = {
**backend_options,
"worker_accelerator_count": gcp_worker_gpu_count,
"worker_accelerator_type": "nvidia-tesla-t4",
}
if use_scheduler_public_ip is False:
if "use_dashboard_public_ip" not in backend_options and not use_dashboard_https:
backend_options["use_dashboard_public_ip"] = False
if use_dashboard_https is False:
if "use_dashboard_https" not in backend_options:
backend_options["use_dashboard_https"] = False
if backend_options:
# for backwards compatibility with v1 options
if "region" in backend_options and "region_name" not in backend_options:
backend_options["region_name"] = backend_options["region"] # type: ignore
del backend_options["region"] # type: ignore
if "zone" in backend_options and "zone_name" not in backend_options:
backend_options["zone_name"] = backend_options["zone"] # type: ignore
del backend_options["zone"] # type: ignore
# firewall just lets you specify a single CIDR block to open for ingress
# we want to support a list of ingress CIDR blocks
if "firewall" in backend_options:
backend_options["ingress"] = [backend_options.pop("firewall")] # type: ignore
# convert the list of ingress rules to the FirewallSpec expected server-side
if "ingress" in backend_options:
fw_spec = {"ingress": backend_options.pop("ingress")}
backend_options["firewall_spec"] = fw_spec # type: ignore
validate_backend_options(backend_options)
data["options"] = backend_options
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{account}/",
json=data,
)
response_json = await response.json()
if response.status >= 400:
from .widgets import EXECUTION_CONTEXT
if response_json.get("code") == "NO_CLOUD_SETUP":
server_error_message = response_json.get("message")
error_message = f"{server_error_message} or by running `coiled setup`"
if EXECUTION_CONTEXT == "terminal":
# maybe not interactive so just raise
raise ClusterCreationError(error_message)
else:
# interactive session so let's try running the cloud setup wizard
if await do_setup_wizard():
# the user setup their cloud backend, so let's try creating cluster again!
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{account}/",
json=data,
)
if response.status >= 400:
await handle_api_exception(response) # always raises exception, no return
response_json = await response.json()
else:
raise ClusterCreationError(error_message)
else:
error_class = PermissionsError if response_json.get("code") == PermissionsError.code else ServerError
if "message" in response_json:
raise error_class(response_json["message"])
if "detail" in response_json:
raise error_class(response_json["detail"])
raise error_class(response_json)
return response_json["id"], response_json["existing"]
@overload
def create_cluster(
self: Cloud[Sync],
name: str,
*,
software: Optional[str] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
scheduler_options: Optional[dict] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
scheduler_vm_types: Optional[list] = None,
worker_gpu_type: Optional[str] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
worker_disk_throughput: Optional[int] = None,
scheduler_disk_size: Optional[int] = None,
backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
) -> Tuple[int, bool]: ...
@overload
def create_cluster(
self: Cloud[Async],
name: str,
*,
software: Optional[str] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
scheduler_options: Optional[dict] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
scheduler_vm_types: Optional[list] = None,
worker_gpu_type: Optional[str] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
worker_disk_throughput: Optional[int] = None,
scheduler_disk_size: Optional[int] = None,
backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
) -> Awaitable[Tuple[int, bool]]: ...
def create_cluster(
self,
name: str,
*,
software: Optional[str] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
scheduler_options: Optional[dict] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
dask_config: Optional[Dict] = None,
scheduler_vm_types: Optional[list] = None,
worker_gpu_type: Optional[str] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
worker_disk_throughput: Optional[int] = None,
scheduler_disk_size: Optional[int] = None,
backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
) -> Union[Tuple[int, bool], Awaitable[Tuple[int, bool]]]:
return self._sync(
self._create_cluster,
name=name,
software_environment=software,
worker_class=worker_class,
worker_options=worker_options,
scheduler_options=scheduler_options,
workspace=workspace or account,
workers=workers,
environ=environ,
tags=tags,
dask_config=dask_config,
private_to_creator=private_to_creator,
scheduler_vm_types=scheduler_vm_types,
worker_vm_types=worker_vm_types,
gcp_worker_gpu_type=worker_gpu_type,
worker_disk_size=worker_disk_size,
worker_disk_throughput=worker_disk_throughput,
scheduler_disk_size=scheduler_disk_size,
backend_options=backend_options,
)
@track_context
async def _delete_cluster(
self, cluster_id: int, workspace: Optional[str] = None, reason: Optional[str] = None
) -> None:
workspace = workspace or self.default_workspace
route = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}"
if reason:
params = {"reason": reason}
else:
params = None
response = await self._do_request_idempotent(
"DELETE",
self.server + route,
params=params,
)
if response.status >= 400:
await handle_api_exception(response)
else:
# multiple deletes sometimes fail if we don't await response here
await response.json()
logger.info(f"Cluster {cluster_id} deleted successfully.")
@overload
def delete_cluster(
self: Cloud[Sync],
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
reason: Optional[str] = None,
) -> None: ...
@overload
def delete_cluster(
self: Cloud[Async],
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
reason: Optional[str] = None,
) -> Awaitable[None]: ...
@delete_docstring # TODO: this docstring erroneously says "Name of cluster" when it really accepts an ID
def delete_cluster(
self,
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
reason: Optional[str] = None,
) -> Optional[Awaitable[None]]:
return self._sync(self._delete_cluster, cluster_id=cluster_id, workspace=workspace or account, reason=reason)
async def _get_cluster_state(self, cluster_id: int, workspace: Optional[str] = None) -> dict:
workspace = workspace or self.default_workspace
# Make request directly instead of using `_do_request` because we don't want any retries.
# Retry logic doesn't make sense because this is called by (frequent) period callback, so we'll just wait
# for next periodic callback call, otherwise retries will overlap with the periodic callback and build up.
session = self._ensure_session()
response = await session.request(
"GET", self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/state"
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
async def _get_cluster_details(self, cluster_id: int, workspace: Optional[str] = None):
workspace = workspace or self.default_workspace
r = await self._do_request_idempotent(
"GET", self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}"
)
if r.status >= 400:
await handle_api_exception(r)
return await r.json()
def _get_cluster_details_synced(self, cluster_id: int, workspace: Optional[str] = None):
return self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
workspace=workspace,
)
def _cluster_grafana_url(self, cluster_id: int, workspace: Optional[str] = None):
"""for internal Coiled use"""
workspace = workspace or self.default_workspace
details = self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
workspace=workspace,
)
return get_grafana_url(details, account=workspace, cluster_id=cluster_id)
def cluster_details(self, cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None):
details = self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
workspace=workspace or account,
)
state_keys = ["state", "reason", "updated"]
def get_state(state: dict):
return {k: v for k, v in state.items() if k in state_keys}
def get_instance(instance):
if instance is None:
return None
else:
return {
"id": instance["id"],
"created": instance["created"],
"name": instance["name"],
"public_ip_address": instance["public_ip_address"],
"private_ip_address": instance["private_ip_address"],
"current_state": get_state(instance["current_state"]),
}
def get_process(process: dict):
if process is None:
return None
else:
return {
"created": process["created"],
"name": process["name"],
"current_state": get_state(process["current_state"]),
"instance": get_instance(process["instance"]),
}
return {
"id": details["id"],
"name": details["name"],
"workers": [get_process(w) for w in details["workers"]],
"scheduler": get_process(details["scheduler"]),
"current_state": get_state(details["current_state"]),
"created": details["created"],
}
async def _get_workers_page(self, cluster_id: int, page: int, workspace: Optional[str] = None) -> Tuple[list, bool]:
page_size = 100
workspace = workspace or self.default_workspace
response = await self._do_request(
"GET",
self.server + f"/api/v2/workers/account/{workspace}/cluster/{cluster_id}/",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@track_context
async def _get_worker_names(
self,
workspace: str,
cluster_id: int,
statuses: Optional[List[ProcessStateEnum]] = None,
) -> Set[str]:
worker_infos = await self._depaginate_list(self._get_workers_page, cluster_id=cluster_id, workspace=workspace)
logger.debug(f"workers: {worker_infos}")
return {w["name"] for w in worker_infos if statuses is None or w["current_state"]["state"] in statuses}
@track_context
async def _security(self, cluster_id: int, workspace: Optional[str] = None, client_wants_public_ip: bool = True):
cluster_info = await self._get_cluster_details(cluster_id=cluster_id, workspace=workspace)
if ProcessStateEnum(cluster_info["scheduler"]["current_state"]["state"]) != ProcessStateEnum.started:
scheduler_state = cluster_info["scheduler"]["current_state"]["state"]
raise RuntimeError(
f"Cannot get security info for cluster {cluster_id}, scheduler state is {scheduler_state}"
)
public_ip = cluster_info["scheduler"]["instance"]["public_ip_address"]
private_ip = cluster_info["scheduler"]["instance"]["private_ip_address"]
tls_cert = cluster_info["cluster_options"]["tls_cert"]
tls_key = cluster_info["cluster_options"]["tls_key"]
scheduler_port = cluster_info["scheduler_port"]
dashboard_address = cluster_info["scheduler"]["dashboard_address"]
give_scheduler_public_ip = cluster_info["cluster_infra"]["give_scheduler_public_ip"]
private_address = f"tls://{private_ip}:{scheduler_port}"
public_address = f"tls://{public_ip}:{scheduler_port}"
use_public_address = give_scheduler_public_ip and client_wants_public_ip
if use_public_address:
if not public_ip:
raise RuntimeError(
"Your Coiled client is configured to use the public IP address, but the scheduler VM does not "
"have a public IP address.\n\n"
"If you're expecting to connect on private IP address, you can run\n"
" coiled config set coiled.use_scheduler_public_ip False\n"
"to configure your local Client to use the private IP address, "
"or contact support@coiled.io if you'd like help."
)
address_to_use = public_address
else:
address_to_use = private_address
logger.info(f"Connecting to scheduler on its internal address: {address_to_use}")
# TODO (Declarative): pass extra_conn_args if we care about proxying through Coiled to the scheduler
security = GatewaySecurity(tls_key, tls_cert)
return security, {
"address_to_use": address_to_use,
"private_address": private_address,
"public_address": public_address,
"dashboard_address": dashboard_address,
}
@track_context
async def _requested_workers(self, cluster_id: int, account: Optional[str] = None) -> Set[str]:
raise NotImplementedError("TODO")
@overload
def requested_workers(self: Cloud[Sync], cluster_id: int, account: Optional[str] = None) -> Set[str]: ...
@track_context
async def _get_cluster_by_name(self, name: str, workspace: Optional[str] = None) -> int:
workspace, name = self._normalize_name(
name, context_workspace=workspace, allow_uppercase=True, property_name="cluster name"
)
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/name/{name}",
)
if response.status == 404:
raise DoesNotExist
elif response.status >= 400:
await handle_api_exception(response)
cluster = await response.json()
return cluster["id"]
@overload
def get_cluster_by_name(
self: Cloud[Sync],
name: str,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> int: ...
@overload
def get_cluster_by_name(
self: Cloud[Async],
name: str,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Awaitable[int]: ...
def get_cluster_by_name(
self,
name: str,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Union[int, Awaitable[int]]:
return self._sync(
self._get_cluster_by_name,
name=name,
workspace=workspace or account,
)
@track_context
async def _cluster_status(
self,
cluster_id: int,
account: Optional[str] = None,
exclude_stopped: bool = True,
) -> dict:
raise NotImplementedError("TODO?")
@track_context
async def _get_cluster_states_declarative(
self,
cluster_id: int,
workspace: Optional[str] = None,
start_time: Optional[datetime.datetime] = None,
) -> int:
workspace = workspace or self.default_workspace
params = {"start_time": start_time.isoformat()} if start_time is not None else {}
response = await self._do_request_idempotent(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/states",
params=params,
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_cluster_states(
self,
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
start_time: Optional[datetime.datetime] = None,
) -> Union[int, Awaitable[int]]:
return self._sync(
self._get_cluster_states_declarative,
cluster_id=cluster_id,
workspace=workspace or account,
start_time=start_time,
)
def get_clusters_by_name(
self,
name: str,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> List[dict]:
"""Get all clusters matching name."""
return self._sync(
self._get_clusters_by_name,
name=name,
workspace=workspace or account,
)
@track_context
async def _get_clusters_by_name(self, name: str, workspace: Optional[str] = None) -> List[dict]:
workspace, name = self._normalize_name(name, context_workspace=workspace, allow_uppercase=True)
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{workspace}/",
params={"name": name},
)
if response.status == 404:
raise DoesNotExist
elif response.status >= 400:
await handle_api_exception(response)
cluster = await response.json()
return cluster
@overload
def cluster_logs(
self: Cloud[Sync],
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Logs: ...
@overload
def cluster_logs(
self: Cloud[Async],
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Awaitable[Logs]: ...
@track_context
async def _cluster_logs(
self,
cluster_id: int,
workspace: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Logs:
def is_errored(process):
process_state, instance_state = get_process_instance_state(process)
return process_state == ProcessStateEnum.error or instance_state == InstanceStateEnum.error
workspace = workspace or self.default_workspace
# hits endpoint in order to get scheduler and worker instance names
cluster_info = await self._get_cluster_details(cluster_id=cluster_id, workspace=workspace)
try:
scheduler_name = cluster_info["scheduler"]["instance"]["name"]
except (TypeError, KeyError):
# no scheduler instance name in cluster info
logger.warning("No scheduler found when attempting to retrieve cluster logs.")
scheduler_name = None
worker_names = [
worker["instance"]["name"]
for worker in cluster_info["workers"]
if worker["instance"] and (not errors_only or is_errored(worker))
]
LabeledInstance = namedtuple("LabeledInstance", ("name", "label"))
instances = []
if scheduler and scheduler_name and (not errors_only or is_errored(cluster_info["scheduler"])):
instances.append(LabeledInstance(scheduler_name, "Scheduler"))
if workers and worker_names:
instances.extend([LabeledInstance(worker_name, worker_name) for worker_name in worker_names])
async def instance_log_with_semaphor(semaphor, **kwargs):
async with semaphor:
return await self._instance_logs(**kwargs)
# only get 100 logs at a time; the limit here is redundant since aiohttp session already limits concurrent
# connections but let's be safe just in case
semaphor = asyncio.Semaphore(value=100)
results = await asyncio.gather(*[
instance_log_with_semaphor(semaphor=semaphor, workspace=workspace, instance_name=inst.name)
for inst in instances
])
out = {
instance_label: instance_log
for (_, instance_label), instance_log in zip(instances, results)
if len(instance_log)
}
return Logs(out)
def cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Union[Logs, Awaitable[Logs]]:
return self._sync(
self._cluster_logs,
cluster_id=cluster_id,
workspace=workspace or account,
scheduler=scheduler,
workers=workers,
errors_only=errors_only,
)
async def _instance_logs(self, workspace: str, instance_name: str, safe=True) -> Log:
response = await self._do_request(
"GET",
self.server + "/api/v2/instances/{}/instance/{}/logs".format(workspace, instance_name),
)
if response.status >= 400:
if safe:
logger.warning(f"Error retrieving logs for {instance_name}")
return Log()
await handle_api_exception(response)
data = await response.json()
messages = "\n".join(logline.get("message", "") for logline in data)
return Log(messages)
@overload
def requested_workers(
self: Cloud[Async], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Awaitable[Set[str]]: ...
def requested_workers(
self, cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Union[
Set[str],
Awaitable[Set[str]],
]:
return self._sync(self._requested_workers, cluster_id, workspace or account)
@overload
def scale_up(
self: Cloud[Sync], cluster_id: int, n: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Optional[Dict]: ...
@overload
def scale_up(
self: Cloud[Async], cluster_id: int, n: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Awaitable[Optional[Dict]]: ...
def scale_up(
self, cluster_id: int, n: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Union[Optional[Dict], Awaitable[Optional[Dict]]]:
"""Scale cluster to ``n`` workers
Parameters
----------
cluster_id
Unique cluster identifier.
n
Number of workers to scale cluster size to.
account
**DEPRECATED**. Use ``workspace`` instead.
workspace
The Coiled workspace (previously "account") to use. If not specified,
will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
or will use your default workspace if those aren't set.
"""
return self._sync(self._scale_up, cluster_id, n, workspace or account)
@overload
def scale_down(
self: Cloud[Sync],
cluster_id: int,
workers: Set[str],
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> None: ...
@overload
def scale_down(
self: Cloud[Async],
cluster_id: int,
workers: Set[str],
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Awaitable[None]: ...
def scale_down(
self,
cluster_id: int,
workers: Set[str],
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> Optional[Awaitable[None]]:
"""Scale cluster to ``n`` workers
Parameters
----------
cluster_id
Unique cluster identifier.
workers
Set of workers to scale down to.
account
**DEPRECATED**. Use ``workspace`` instead.
workspace
The Coiled workspace (previously "account") to use. If not specified,
will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
or will use your default workspace if those aren't set.
"""
return self._sync(self._scale_down, cluster_id, workers, workspace or account)
@track_context
async def _better_cluster_logs(
self,
cluster_id: int,
workspace: Optional[str] = None,
instance_ids: Optional[List[int]] = None,
dask: bool = False,
system: bool = False,
since_ms: Optional[int] = None,
until_ms: Optional[int] = None,
filter: Optional[str] = None,
):
workspace = workspace or self.default_workspace
url_params = {}
if dask:
url_params["dask"] = "True"
if system:
url_params["system"] = "True"
if since_ms:
url_params["since_ms"] = f"{since_ms}"
if until_ms:
url_params["until_ms"] = f"{until_ms}"
if filter:
url_params["filter_pattern"] = f"{filter}"
if instance_ids:
url_params["instance_ids"] = ",".join(map(str, instance_ids))
url_path = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/better-logs"
response = await self._do_request(
"GET",
f"{self.server}{url_path}",
params=url_params,
)
if response.status >= 400:
await handle_api_exception(response)
data = await response.json()
return data
def better_cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
instance_ids: Optional[List[int]] = None,
dask: bool = False,
system: bool = False,
since_ms: Optional[int] = None,
until_ms: Optional[int] = None,
filter: Optional[str] = None,
) -> Logs:
return self._sync(
self._better_cluster_logs,
cluster_id=cluster_id,
workspace=workspace or account,
instance_ids=instance_ids,
dask=dask,
system=system,
since_ms=since_ms,
until_ms=until_ms,
filter=filter,
)
@track_context
async def _scale_up(
self, cluster_id: int, n: int, workspace: Optional[str] = None, reason: Optional[str] = None
) -> Dict:
"""
Increases the number of workers by ``n``.
"""
workspace = workspace or self.default_workspace
data = {"n_workers": n}
if reason:
# pyright is annoying
data["reason"] = reason # type: ignore
response = await self._do_request(
"POST", f"{self.server}/api/v2/workers/account/{workspace}/cluster/{cluster_id}/", json=data
)
if response.status >= 400:
await handle_api_exception(response)
workers_info = await response.json()
return {"workers": {w["name"] for w in workers_info}}
@track_context
async def _scale_down(
self, cluster_id: int, workers: Iterable[str], workspace: Optional[str] = None, reason: Optional[str] = None
) -> None:
workspace = workspace or self.default_workspace
workers = list(workers) # yarl, used by aiohttp, expects list (not set) of strings
reason_dict = {"reason": reason} if reason else {}
response = await self._do_request(
"DELETE",
f"{self.server}/api/v2/workers/account/{workspace}/cluster/{cluster_id}/",
params={"name": workers, **reason_dict},
)
if response.status >= 400:
await handle_api_exception(response)
@overload
def security(
self: Cloud[Sync], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Tuple[dask.distributed.Security, dict]: ...
@overload
def security(
self: Cloud[Async], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Awaitable[Tuple[dask.distributed.Security, dict]]: ...
def security(
self, cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
) -> Union[
Tuple[dask.distributed.Security, dict],
Awaitable[Tuple[dask.distributed.Security, dict]],
]:
return self._sync(self._security, cluster_id, workspace or account)
@track_context
async def _fetch_package_levels(self, workspace: Optional[str] = None) -> List[PackageLevel]:
workspace = workspace or self.default_workspace
response = await self._do_request("GET", f"{self.server}/api/v2/packages/", params={"account": workspace})
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_ssh_key(
self,
cluster_id: int,
workspace: Optional[str] = None,
worker: Optional[str] = None,
) -> dict:
workspace = workspace or self.default_workspace
return self._sync(
self._get_ssh_key,
cluster_id=cluster_id,
workspace=workspace,
worker=worker,
)
@track_context
async def _get_ssh_key(self, cluster_id: int, workspace: str, worker: Optional[str]) -> dict:
workspace = workspace or self.default_workspace
route = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/ssh-key"
url = f"{self.server}{route}"
response = await self._do_request("GET", url, params={"worker": worker} if worker else None)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_cluster_log_info(
self,
cluster_id: int,
workspace: Optional[str] = None,
) -> dict:
workspace = workspace or self.default_workspace
return self._sync(
self._get_cluster_log_info,
cluster_id=cluster_id,
workspace=workspace,
)
@track_context
async def _get_cluster_log_info(
self,
cluster_id: int,
workspace: str,
) -> dict:
workspace = workspace or self.default_workspace
route = f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/log-info"
url = f"{self.server}{route}"
response = await self._do_request("GET", url)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
@track_context
async def _get_cluster_aggregated_metric(
self,
cluster_id: int,
workspace: Optional[str],
query: str,
over_time: str,
start_ts: Optional[int],
end_ts: Optional[int],
):
workspace = workspace or self.default_workspace
route = f"/api/v2/metrics/account/{workspace}/cluster/{cluster_id}/value"
url = f"{self.server}{route}"
params = {"query": query, "over_time": over_time}
if start_ts:
params["start_ts"] = str(start_ts)
if end_ts:
params["end_ts"] = str(end_ts)
response = await self._do_request("GET", url, params=params)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
@track_context
async def _add_cluster_span(self, cluster_id: int, workspace: Optional[str], span_identifier: str, data: dict):
workspace = workspace or self.default_workspace
route = f"/api/v2/analytics/{workspace}/cluster/{cluster_id}/span/{span_identifier}"
url = f"{self.server}{route}"
response = await self._do_request("POST", url, json=data)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def _sync_request(self, route, method: str = "GET", json_result: bool = False, **kwargs):
url = f"{self.server}{route}"
response = self._sync(self._do_request, url=url, method=method, **kwargs)
if response.status >= 400:
raise ServerError(f"{url} returned {response.status}")
async def get_result(r):
return await (r.json() if json_result else r.text())
return self._sync(
get_result,
response,
)
Cloud = CloudV2
[docs]
def cluster_logs(
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
):
"""
Returns cluster logs as a dictionary, with a key for the scheduler and each worker.
.. versionchanged:: 0.2.0
``cluster_name`` is no longer accepted, use ``cluster_id`` instead.
"""
with CloudV2() as cloud:
return cloud.cluster_logs(
cluster_id=cluster_id,
workspace=workspace or account,
scheduler=scheduler,
workers=workers,
errors_only=errors_only,
)
def better_cluster_logs(
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
instance_ids: Optional[List[int]] = None,
dask: bool = False,
system: bool = False,
since_ms: Optional[int] = None,
until_ms: Optional[int] = None,
filter: Optional[str] = None,
):
"""
Pull logs for the cluster using better endpoint.
Logs for recent clusters are split between system and container (dask), you can get
either or both (or none).
since_ms and until_ms are both inclusive (you'll get logs with timestamp matching those).
"""
with Cloud() as cloud:
return cloud.better_cluster_logs(
cluster_id=cluster_id,
workspace=workspace or account,
instance_ids=instance_ids,
dask=dask,
system=system,
since_ms=since_ms,
until_ms=until_ms,
filter=filter,
)
def cluster_details(
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
) -> dict:
"""
Get details of a cluster as a dictionary.
"""
with CloudV2() as cloud:
return cloud.cluster_details(
cluster_id=cluster_id,
workspace=workspace or account,
)
def log_cluster_debug_info(
cluster_id: int,
account: Optional[str] = None,
workspace: Optional[str] = None,
):
with CloudV2() as cloud:
details = cloud.cluster_details(cluster_id, workspace or account)
logger.debug("Cluster details:")
logger.debug(json.dumps(details, indent=2))
states_by_type = cloud.get_cluster_states(cluster_id, workspace or account)
logger.debug("cluster state history:")
log_states(flatten_log_states(states_by_type), level=logging.DEBUG)
def create_cluster(
name: str,
*,
software: Optional[str] = None,
worker_options: Optional[dict] = None,
scheduler_options: Optional[dict] = None,
account: Optional[str] = None,
workspace: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
scheduler_vm_types: Optional[list] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
worker_disk_throughput: Optional[int] = None,
scheduler_disk_size: Optional[int] = None,
backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
) -> int:
"""Create a cluster
Parameters
---------
name
Name of cluster.
software
Identifier of the software environment to use, in the format (<account>/)<name>. If the software environment
is owned by the same account as that passed into "account", the (<account>/) prefix is optional.
For example, suppose your account is "wondercorp", but your friends at "friendlycorp" have an environment
named "xgboost" that you want to use; you can specify this with "friendlycorp/xgboost". If you simply
entered "xgboost", this is shorthand for "wondercorp/xgboost".
The "name" portion of (<account>/)<name> can only contain ASCII letters, hyphens and underscores.
worker_options
Mapping with keyword arguments to pass to ``worker_class``. Defaults to ``{}``.
scheduler_options
Mapping with keyword arguments to pass to the Scheduler ``__init__``. Defaults to ``{}``.
account
**DEPRECATED**. Use ``workspace`` instead.
workspace
The Coiled workspace (previously "account") to use. If not specified,
will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
or will use your default workspace if those aren't set.
workers
Number of workers we to launch.
environ
Dictionary of environment variables.
tags
Dictionary of instance tags
dask_config
Dictionary of dask config to put on cluster
See Also
--------
coiled.Cluster
"""
with CloudV2(account=workspace or account) as cloud:
cluster, existing = cloud.create_cluster(
name=name,
software=software,
worker_options=worker_options,
scheduler_options=scheduler_options,
workspace=workspace or account,
workers=workers,
environ=environ,
tags=tags,
dask_config=dask_config,
private_to_creator=private_to_creator,
backend_options=backend_options,
worker_vm_types=worker_vm_types,
worker_disk_size=worker_disk_size,
worker_disk_throughput=worker_disk_throughput,
scheduler_disk_size=scheduler_disk_size,
scheduler_vm_types=scheduler_vm_types,
)
return cluster
[docs]
@list_docstring
def list_clusters(account=None, workspace=None, max_pages: Optional[int] = None):
with CloudV2() as cloud:
return cloud.list_clusters(workspace=workspace or account, max_pages=max_pages)
[docs]
@delete_docstring
def delete_cluster(name: str, account: Optional[str] = None, workspace: Optional[str] = None):
with CloudV2() as cloud:
cluster_id = cloud.get_cluster_by_name(name=name, workspace=workspace or account)
if cluster_id is not None:
return cloud.delete_cluster(cluster_id=cluster_id, workspace=workspace or account)
def create_package_sync_software_env(
workspace=None, gpu=False, arm=False, strict=False, force_rich_widget=False, **kwargs
):
from coiled.capture_environment import scan_and_create
with Cloud(workspace=workspace) as cloud:
package_sync_env_alias = cloud._sync(
scan_and_create,
cloud=cloud,
workspace=workspace,
gpu_enabled=gpu,
architecture=ArchitectureTypesEnum.ARM64 if arm else ArchitectureTypesEnum.X86_64,
package_sync_strict=strict,
force_rich_widget=force_rich_widget,
**kwargs,
)
return package_sync_env_alias