from __future__ import annotations
import asyncio
import base64
import datetime
import json
import logging
import sys
import weakref
from collections import namedtuple
from hashlib import md5
from pathlib import Path
from typing import (
Awaitable,
BinaryIO,
Callable,
Dict,
Generic,
List,
NoReturn,
Optional,
Set,
Tuple,
Union,
overload,
)
import backoff
import dask
from aiohttp import ClientResponseError, ClientSession, ContentTypeError
from coiled.types import (
ApproximatePackageRequest,
ApproximatePackageResult,
PackageSchema,
ResolvedPackageInfo,
)
if sys.version_info >= (3, 8):
from typing import TypedDict
else:
from typing_extensions import TypedDict
from dask.utils import parse_timedelta
from distributed.utils import Log, Logs
from coiled.cli.setup.entry import do_setup_wizard
from coiled.context import track_context
from coiled.core import Async
from coiled.core import Cloud as OldCloud
from coiled.core import IsAsynchronous, Sync, delete_docstring, list_docstring
from coiled.errors import ClusterCreationError, DoesNotExist, ServerError
from coiled.types import PackageLevel
from coiled.utils import (
COILED_LOGGER_NAME,
GatewaySecurity,
get_grafana_url,
validate_type,
)
from .states import (
InstanceStateEnum,
ProcessStateEnum,
flatten_log_states,
get_process_instance_state,
log_states,
)
logger = logging.getLogger(COILED_LOGGER_NAME)
def setup_logging(level=logging.INFO):
# only set up logging if there's no log level specified yet on the coiled logger
if logging.getLogger(COILED_LOGGER_NAME).level == 0:
logging.getLogger(COILED_LOGGER_NAME).setLevel(level)
logging.basicConfig()
async def handle_api_exception(response, exception_cls=ServerError) -> NoReturn:
try:
error_body = await response.json()
except ContentTypeError:
raise exception_cls(
f"Unexpected status code ({response.status}) to {response.method}:{response.url}, contact support@coiled.io"
)
if "message" in error_body:
raise exception_cls(error_body["message"])
if "detail" in error_body:
raise exception_cls(error_body["detail"])
raise exception_cls(error_body)
class FirewallOptions(TypedDict):
"""
A dictionary with the following key/value pairs
Parameters
----------
ports
List of ports to open to cidr on the scheduler.
For example, ``[22, 8786]`` opens port 22 for SSH and 8786 for client to Dask connection.
cidr
CIDR block from which to allow access. For example ``0.0.0.0/0`` allows access from any IP address.
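
    Examples
    --------
    An illustrative value (the ports and CIDR block shown are just examples)::

        firewall: FirewallOptions = {"ports": [22, 8786], "cidr": "0.0.0.0/0"}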
"""
ports: List[int]
cidr: str
class BackendOptions(TypedDict, total=False):
"""
A dictionary with the following key/value pairs
Parameters
----------
region_name
Region name to launch cluster in. For example: us-east-2
zone_name
Zone name to launch cluster in. For example: us-east-2a
firewall
Allows you to specify firewall for scheduler; see :py:class:`FirewallOptions` for details.
ingress
Allows you to specify multiple CIDR blocks (and corresponding ports) to open for ingress
on the scheduler firewall.
spot
Whether to request spot instances.
spot_on_demand_fallback
If requesting spot, whether to request non-spot instances if we get fewer spot instances
than desired.
multizone
        Tell the cloud provider to pick the zone with the best availability; we'll keep all workers
        in the same zone, though the scheduler may or may not be in that zone as well.
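
    Examples
    --------
    An illustrative set of options (region and zone names are placeholders)::

        options: BackendOptions = {
            "region_name": "us-east-2",
            "zone_name": "us-east-2a",
            "spot": True,
            "spot_on_demand_fallback": True,
            "firewall": {"ports": [22, 8786], "cidr": "0.0.0.0/0"},
        }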
"""
region_name: Optional[str]
zone_name: Optional[str]
firewall: Optional[FirewallOptions]
ingress: Optional[List[FirewallOptions]]
spot: Optional[bool]
spot_on_demand_fallback: Optional[bool]
multizone: Optional[bool]
use_dashboard_public_ip: Optional[bool]
send_prometheus_metrics: Optional[bool] # TODO deprecate
prometheus_write: Optional[dict]
class AWSOptions(BackendOptions, total=False):
"""
A dictionary with the following key/value pairs plus any pairs in :py:class:`BackendOptions`
Parameters
----------
keypair_name
AWS Keypair to assign worker/scheduler instances
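
    Examples
    --------
    An illustrative value (the keypair name is a placeholder)::

        options: AWSOptions = {
            "region_name": "us-east-2",
            "keypair_name": "my-keypair",
            "spot": True,
            "spot_replacement": True,
        }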
"""
keypair_name: Optional[str]
spot_replacement: Optional[bool]
use_placement_group: Optional[bool]
class GCPOptions(BackendOptions, total=False):
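    """
    A dictionary with GCP-specific key/value pairs plus any pairs in :py:class:`BackendOptions`

    The accelerator fields set GPU type and count for scheduler and worker instances; for example,
    ``worker_accelerator_type="nvidia-tesla-t4"`` with ``worker_accelerator_count=1`` requests a
    single T4 GPU per worker (an illustrative value; available types depend on your GCP setup).
    """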
scheduler_accelerator_count: Optional[int]
scheduler_accelerator_type: Optional[str]
worker_accelerator_count: Optional[int]
worker_accelerator_type: Optional[str]
BackendOptionTypes = [AWSOptions, GCPOptions]
class CloudBeta(OldCloud, Generic[IsAsynchronous]):
_recent_sync: list[weakref.ReferenceType[CloudBeta[Sync]]] = list()
_recent_async: list[weakref.ReferenceType[CloudBeta[Async]]] = list()
# just overriding to get the right signature (CloudBeta, not Cloud)
def __enter__(self: CloudBeta[Sync]) -> CloudBeta[Sync]:
return self
def __exit__(self: CloudBeta[Sync], typ, value, tb) -> None:
self.close()
async def __aenter__(self: CloudBeta[Async]) -> CloudBeta[Async]:
return await self._start()
async def __aexit__(self: CloudBeta[Async], typ, value, tb) -> None:
await self._close()
# these overloads are necessary for the typechecker to know that we really have a CloudBeta, not a Cloud
# without them, CloudBeta.current would be typed to return a Cloud
#
# https://www.python.org/dev/peps/pep-0673/ would remove the need for this.
    # That PEP also mentions a workaround with type vars, which doesn't work for us
    # because type vars aren't subscriptable.
@overload
@classmethod
def current(cls, asynchronous: Sync) -> CloudBeta[Sync]:
...
@overload
@classmethod
def current(cls, asynchronous: Async) -> CloudBeta[Async]:
...
@overload
@classmethod
def current(cls, asynchronous: bool) -> CloudBeta:
...
@classmethod
def current(cls, asynchronous: bool) -> CloudBeta:
recent: list[weakref.ReferenceType[CloudBeta]]
if asynchronous:
recent = cls._recent_async
else:
recent = cls._recent_sync
try:
cloud = recent[-1]()
while cloud is None or cloud.status != "running":
recent.pop()
cloud = recent[-1]()
except IndexError:
if asynchronous:
return cls(asynchronous=True)
else:
return cls(asynchronous=False)
else:
return cloud
@track_context
async def _get_default_instance_types(
self, provider: str, gpu: bool = False
) -> List[str]:
if provider == "aws":
if gpu:
return ["g4dn.xlarge"]
else:
return ["t3.xlarge"]
elif provider == "gcp":
if gpu:
# n1-standard-8 with 30GB of memory might be best, but that's big for a default
return ["n1-standard-4"]
else:
return ["e2-standard-4"]
else:
raise ValueError(
f"unexpected provider {provider}; cannot determine default instance types"
)
async def _list_dask_scheduler_page(
self,
page: int,
account: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
account = account or self.default_account
kwargs = {}
if since:
kwargs["since"] = parse_timedelta(since)
if user:
kwargs["user"] = user
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{account}/clusters/list",
params={
"limit": page_size,
"offset": page_size * page,
**kwargs,
},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@track_context
async def _list_dask_scheduler(
self,
account: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = None,
):
return await self._depaginate_list(
self._list_dask_scheduler_page,
account=account,
since=since,
user=user,
)
@overload
def list_dask_scheduler(
self: Cloud[Sync],
account: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = None,
) -> list:
...
@overload
def list_dask_scheduler(
self: Cloud[Async],
account: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = "",
) -> Awaitable[list]:
...
def list_dask_scheduler(
self,
account: Optional[str] = None,
since: Optional[str] = "7 days",
user: Optional[str] = "",
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_dask_scheduler, account, since=since, user=user)
async def _list_computations(self, cluster_id: int, account: Optional[str] = None):
return await self._depaginate_list(
self._list_computations_page, cluster_id=cluster_id, account=account
)
async def _list_computations_page(
self,
page: int,
cluster_id: int,
account: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
account = account or self.default_account
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{account}/{cluster_id}/computations/list",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@overload
def list_computations(
self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
) -> list:
...
@overload
def list_computations(
self: Cloud[Async], cluster_id: int, account: Optional[str] = None
) -> Awaitable[list]:
...
def list_computations(
self, cluster_id: int, account: Optional[str] = None
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_computations, cluster_id, account)
@overload
def list_exceptions(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
) -> list:
...
@overload
def list_exceptions(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
) -> Awaitable[list]:
...
def list_exceptions(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
) -> Union[list, Awaitable[list]]:
return self._sync(
self._list_exceptions,
cluster_id=cluster_id,
scheduler_id=scheduler_id,
account=account,
since=since,
user=user,
)
async def _list_exceptions(
self,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
):
return await self._depaginate_list(
self._list_exceptions_page,
cluster_id=cluster_id,
scheduler_id=scheduler_id,
account=account,
since=since,
user=user,
)
async def _list_exceptions_page(
self,
page: int,
cluster_id: Optional[int] = None,
scheduler_id: Optional[int] = None,
account: Optional[str] = None,
since: Optional[str] = None,
user: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
account = account or self.default_account
kwargs = {}
if since:
kwargs["since"] = parse_timedelta(since)
if user:
kwargs["user"] = user
if cluster_id:
kwargs["cluster"] = cluster_id
if scheduler_id:
kwargs["scheduler"] = scheduler_id
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{account}/exceptions/list",
params={"limit": page_size, "offset": page_size * page, **kwargs},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
async def _list_events_page(
self,
page: int,
cluster_id: int,
account: Optional[str] = None,
) -> Tuple[list, bool]:
page_size = 100
account = account or self.default_account
response = await self._do_request(
"GET",
self.server + f"/api/v2/analytics/{account}/{cluster_id}/events/list",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
async def _list_events(self, cluster_id: int, account: Optional[str] = None):
return await self._depaginate_list(
self._list_events_page, cluster_id=cluster_id, account=account
)
def list_events(
self, cluster_id: int, account: Optional[str] = None
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_events, cluster_id, account)
async def _send_state(
self, cluster_id: int, desired_status: str, account: Optional[str] = None
):
account = account or self.default_account
response = await self._do_request(
"POST",
self.server + f"/api/v2/analytics/{account}/{cluster_id}/desired-state",
json={"desired_status": desired_status},
)
if response.status >= 400:
await handle_api_exception(response)
def send_state(
self, cluster_id: int, desired_status: str, account: Optional[str] = None
) -> Union[None, Awaitable[None]]:
return self._sync(self._send_state, cluster_id, desired_status, account)
@track_context
async def _list_clusters(
self, account: Optional[str] = None, max_pages: Optional[int] = None
):
return await self._depaginate_list(
self._list_clusters_page, account=account, max_pages=max_pages
)
@overload
def list_clusters(
self: Cloud[Sync],
account: Optional[str] = None,
max_pages: Optional[int] = None,
) -> list:
...
@overload
def list_clusters(
self: Cloud[Async],
account: Optional[str] = None,
max_pages: Optional[int] = None,
) -> Awaitable[list]:
...
@list_docstring
def list_clusters(
self, account: Optional[str] = None, max_pages: Optional[int] = None
) -> Union[list, Awaitable[list]]:
return self._sync(self._list_clusters, account, max_pages=max_pages)
async def _list_clusters_page(
self, page: int, account: Optional[str] = None
) -> Tuple[list, bool]:
page_size = 100
account = account or self.default_account
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{account}/",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@staticmethod
async def _depaginate_list(
func: Callable[..., Awaitable[Tuple[list, bool]]],
max_pages: Optional[int] = None,
*args,
**kwargs,
) -> list:
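        """Collect all results from a paginated list endpoint.

        Calls ``func`` with an increasing ``page`` keyword argument and concatenates the
        per-page results until a page comes back empty (or ``max_pages`` pages have been fetched).
        """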
results_all = []
page = 0
while True:
kwargs["page"] = page
results, next = await func(*args, **kwargs)
results_all += results
page += 1
if (not results) or next is None:
break
# page is the number of pages we've already fetched (since 0-indexed)
if max_pages and page >= max_pages:
break
return results_all
async def _create_package_sync_env(
self, packages: List[ResolvedPackageInfo], account: Optional[str] = None
) -> int:
account = account or self.default_account
prepared_packages: List[PackageSchema] = []
for pkg in packages:
if pkg["sdist"]:
file_id = await self._create_senv_package(
pkg["sdist"], contents_md5=pkg["md5"], account=account
)
else:
file_id = None
prepared_packages.append(
{
"name": pkg["name"],
"source": pkg["source"],
"channel": pkg["channel"],
"conda_name": pkg["conda_name"],
"specifier": pkg["specifier"],
"include": pkg["include"],
"client_version": pkg["client_version"],
"file": file_id,
}
)
return await self._create_software_environment_v2(
senv=prepared_packages, account=account
)
@track_context
async def _create_senv_package(
self, package_file: BinaryIO, contents_md5: str, account: Optional[str] = None
) -> int:
logger.info(f"Starting upload for {package_file}")
package_data = package_file.read()
# s3 expects the md5 to be base64 encoded
wheel_md5 = base64.b64encode(md5(package_data).digest()).decode("utf-8")
account = account or self.default_account
response = await self._do_request(
"POST",
self.server
+ f"/api/v2/software-environment/account/{account}/package-upload",
json={
"name": Path(package_file.name).name,
"md5": contents_md5,
"wheel_md5": wheel_md5,
},
)
if response.status >= 400:
await handle_api_exception(response) # always raises exception, no return
data = await response.json()
if data["should_upload"]:
await self._put_package(
url=data["upload_url"],
package_data=package_data,
file_md5=wheel_md5,
)
else:
logger.info(f"{package_file} MD5 matches existing, skipping upload")
return data["id"]
@backoff.on_exception(
backoff.expo,
ClientResponseError,
max_time=120,
giveup=lambda error: error.status < 500,
)
async def _put_package(self, url: str, package_data: bytes, file_md5: str):
# can't use the default session as it has coiled auth headers
async with ClientSession() as session:
async with session.put(
url=url, data=package_data, headers={"content-md5": file_md5}
) as resp:
resp.raise_for_status()
@track_context
async def _create_software_environment_v2(
self,
senv: List[PackageSchema],
account: Optional[str] = None,
) -> int:
account = account or self.default_account
resp = await self._do_request(
"POST",
self.server + f"/api/v2/software-environment/account/{account}",
json={
"packages": senv,
"md5": md5(
json.dumps(senv, sort_keys=True).encode("utf-8")
).hexdigest(),
},
)
if resp.status >= 400:
await handle_api_exception(resp) # always raises exception, no return
data = await resp.json()
return data["id"]
@track_context
async def _create_cluster(
self,
# todo: make name optional and pick one for them, like pre-declarative?
# https://gitlab.com/coiled/cloud/-/issues/4305
name: str,
*,
software_environment: Optional[str] = None,
senv_v2_id: Optional[int] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
worker_cpu: Optional[int] = None,
worker_memory: Optional[Union[str, List[str]]] = None,
scheduler_class: Optional[str] = None,
scheduler_options: Optional[dict] = None,
scheduler_cpu: Optional[int] = None,
scheduler_memory: Optional[Union[str, List[str]]] = None,
account: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
scheduler_vm_types: Optional[list] = None,
gcp_worker_gpu_type: Optional[str] = None,
gcp_worker_gpu_count: Optional[int] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
backend_options: Optional[Union[AWSOptions, GCPOptions, dict]] = None,
use_scheduler_public_ip: Optional[bool] = None,
private_to_creator: Optional[bool] = None,
) -> int:
# TODO (Declarative): support these args, or decide not to
# https://gitlab.com/coiled/cloud/-/issues/4305
if scheduler_class is not None:
raise ValueError("scheduler_class is not supported in beta/new Coiled yet")
account = account or self.default_account
account, name = self._normalize_name(
name,
context_account=account,
allow_uppercase=True,
)
self._verify_account(account)
data = {
"name": name,
"workers": workers,
"worker_instance_types": worker_vm_types,
"scheduler_instance_types": scheduler_vm_types,
"software_environment": software_environment,
"worker_options": worker_options,
"worker_cpu": worker_cpu,
"worker_class": worker_class,
"worker_memory": worker_memory,
"worker_disk_size": worker_disk_size,
"scheduler_options": scheduler_options,
"scheduler_cpu": scheduler_cpu,
"scheduler_memory": scheduler_memory,
"environ": environ,
"tags": tags,
"dask_config": dask_config,
"private_to_creator": private_to_creator,
"env_id": senv_v2_id
# "jupyter_on_scheduler": True,
}
if gcp_worker_gpu_type is not None:
# for backwards compatibility with v1 options
backend_options = backend_options if backend_options else {}
backend_options = {
**backend_options,
"worker_accelerator_count": gcp_worker_gpu_count or 1,
"worker_accelerator_type": gcp_worker_gpu_type,
}
elif gcp_worker_gpu_count:
# not ideal but v1 only supported T4 and `worker_gpu=1` would give you one
backend_options = backend_options if backend_options else {}
backend_options = {
**backend_options,
"worker_accelerator_count": gcp_worker_gpu_count,
"worker_accelerator_type": "nvidia-tesla-t4",
}
if use_scheduler_public_ip is False:
backend_options = backend_options if backend_options else {}
if "use_dashboard_public_ip" not in backend_options:
backend_options["use_dashboard_public_ip"] = False
if backend_options:
# for backwards compatibility with v1 options
if "region" in backend_options and "region_name" not in backend_options:
backend_options["region_name"] = backend_options["region"] # type: ignore
del backend_options["region"] # type: ignore
if "zone" in backend_options and "zone_name" not in backend_options:
backend_options["zone_name"] = backend_options["zone"] # type: ignore
del backend_options["zone"] # type: ignore
# firewall just lets you specify a single CIDR block to open for ingress
# we want to support a list of ingress CIDR blocks
if "firewall" in backend_options:
backend_options["ingress"] = [backend_options.pop("firewall")] # type: ignore
# validate against TypedDicts -- should be better (especially better errors)
if not any((validate_type(t, backend_options) for t in BackendOptionTypes)):
raise ValueError(
"backend_options should be an instance of coiled.BackendOptions"
)
# convert the list of ingress rules to the FirewallSpec expected server-side
if "ingress" in backend_options:
fw_spec = {"ingress": backend_options.pop("ingress")}
backend_options["firewall_spec"] = fw_spec # type: ignore
data["options"] = backend_options
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{account}/",
json=data,
)
response_json = await response.json()
if response.status >= 400:
from .widgets import EXECUTION_CONTEXT
if response_json.get("code") == "NO_CLOUD_SETUP":
server_error_message = response_json.get("message")
error_message = (
f"{server_error_message} or by running `coiled setup wizard`"
)
if EXECUTION_CONTEXT == "terminal":
# maybe not interactive so just raise
raise ClusterCreationError(error_message)
else:
# interactive session so let's try running the cloud setup wizard
if do_setup_wizard():
                        # the user set up their cloud backend, so let's try creating the cluster again!
response = await self._do_request(
"POST",
self.server + f"/api/v2/clusters/account/{account}/",
json=data,
)
if response.status >= 400:
await handle_api_exception(
response
) # always raises exception, no return
response_json = await response.json()
else:
raise ClusterCreationError(error_message)
else:
if "message" in response_json:
raise ServerError(response_json["message"])
if "detail" in response_json:
raise ServerError(response_json["detail"])
raise ServerError(response_json)
return response_json["id"]
@overload
def create_cluster(
self: Cloud[Sync],
name: Optional[str] = None,
*,
software: Optional[str] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
worker_cpu: Optional[int] = None,
worker_memory: Optional[int] = None,
scheduler_class: Optional[str] = None,
scheduler_options: Optional[dict] = None,
scheduler_cpu: Optional[int] = None,
scheduler_memory: Optional[int] = None,
account: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
scheduler_vm_types: Optional[list] = None,
worker_gpu_type: Optional[str] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
backend_options: Optional[dict | BackendOptions] = None,
) -> int:
...
@overload
def create_cluster(
self: Cloud[Async],
name: Optional[str] = None,
*,
software: Optional[str] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
worker_cpu: Optional[int] = None,
worker_memory: Optional[int] = None,
scheduler_class: Optional[str] = None,
scheduler_options: Optional[dict] = None,
scheduler_cpu: Optional[int] = None,
scheduler_memory: Optional[int] = None,
account: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
scheduler_vm_types: Optional[list] = None,
worker_gpu_type: Optional[str] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
backend_options: Optional[dict | BackendOptions] = None,
) -> Awaitable[int]:
...
def create_cluster(
self,
name: Optional[str] = None,
*,
software: Optional[str] = None,
worker_class: Optional[str] = None,
worker_options: Optional[dict] = None,
worker_cpu: Optional[int] = None,
worker_memory: Optional[int] = None,
scheduler_class: Optional[str] = None,
scheduler_options: Optional[dict] = None,
scheduler_cpu: Optional[int] = None,
scheduler_memory: Optional[int] = None,
account: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
dask_config: Optional[Dict] = None,
scheduler_vm_types: Optional[list] = None,
worker_gpu_type: Optional[str] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
backend_options: Optional[dict | BackendOptions] = None,
) -> Union[int, Awaitable[int]]:
return self._sync(
self._create_cluster,
name=name,
software_environment=software,
worker_class=worker_class,
worker_options=worker_options,
worker_cpu=worker_cpu,
worker_memory=worker_memory,
scheduler_options=scheduler_options,
scheduler_cpu=scheduler_cpu,
scheduler_memory=scheduler_memory,
account=account,
workers=workers,
environ=environ,
tags=tags,
dask_config=dask_config,
private_to_creator=private_to_creator,
scheduler_vm_types=scheduler_vm_types,
worker_vm_types=worker_vm_types,
gcp_worker_gpu_type=worker_gpu_type,
worker_disk_size=worker_disk_size,
backend_options=backend_options,
)
@track_context
async def _delete_cluster(
self, cluster_id: int, account: Optional[str] = None
) -> None:
account = account or self.default_account
route = f"/api/v2/clusters/account/{account}/id/{cluster_id}"
response = await self._do_request(
"DELETE",
self.server + route,
)
if response.status >= 400:
await handle_api_exception(response)
else:
# multiple deletes sometimes fail if we don't await response here
await response.json()
logger.info(f"Cluster {cluster_id} deleted successfully.")
@overload
def delete_cluster(
self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
) -> None:
...
@overload
def delete_cluster(
self: Cloud[Async], cluster_id: int, account: Optional[str] = None
) -> Awaitable[None]:
...
@delete_docstring # TODO: this docstring erroneously says "Name of cluster" when it really accepts an ID
def delete_cluster(
self, cluster_id: int, account: Optional[str] = None
) -> Optional[Awaitable[None]]:
return self._sync(self._delete_cluster, cluster_id, account)
async def _get_cluster_details(
self, cluster_id: int, account: Optional[str] = None
):
account = account or self.default_account
r = await self._do_request_idempotent(
"GET", self.server + f"/api/v2/clusters/account/{account}/id/{cluster_id}"
)
if r.status >= 400:
await handle_api_exception(r)
return await r.json()
def _get_cluster_details_synced(
self, cluster_id: int, account: Optional[str] = None
):
return self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
account=account,
)
def _cluster_grafana_url(self, cluster_id: int, account: Optional[str] = None):
"""for internal Coiled use"""
account = account or self.default_account
details = self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
account=account,
)
return get_grafana_url(details, account=account, cluster_id=cluster_id)
def cluster_details(self, cluster_id: int, account: Optional[str] = None):
details = self._sync(
self._get_cluster_details,
cluster_id=cluster_id,
account=account,
)
state_keys = ["state", "reason", "updated"]
def get_state(state: dict):
return {k: v for k, v in state.items() if k in state_keys}
def get_instance(instance):
if instance is None:
return None
else:
return {
"id": instance["id"],
"created": instance["created"],
"name": instance["name"],
"public_ip_address": instance["public_ip_address"],
"private_ip_address": instance["private_ip_address"],
"current_state": get_state(instance["current_state"]),
}
def get_process(process: dict):
if process is None:
return None
else:
return {
"created": process["created"],
"name": process["name"],
"current_state": get_state(process["current_state"]),
"instance": get_instance(process["instance"]),
}
return {
"id": details["id"],
"name": details["name"],
"workers": [get_process(w) for w in details["workers"]],
"scheduler": get_process(details["scheduler"]),
"current_state": get_state(details["current_state"]),
"created": details["created"],
}
async def _get_workers_page(
self, cluster_id: int, page: int, account: Optional[str] = None
) -> Tuple[list, bool]:
page_size = 100
account = account or self.default_account
response = await self._do_request(
"GET",
self.server + f"/api/v2/workers/account/{account}/cluster/{cluster_id}/",
params={"limit": page_size, "offset": page_size * page},
)
if response.status >= 400:
await handle_api_exception(response)
results = await response.json()
has_more_pages = len(results) > 0
return results, has_more_pages
@track_context
async def _get_worker_names(
self,
account: str,
cluster_id: int,
statuses: Optional[List[ProcessStateEnum]] = None,
) -> Set[str]:
worker_infos = await self._depaginate_list(
self._get_workers_page, cluster_id=cluster_id, account=account
)
logger.debug(f"workers: {worker_infos}")
return {
w["name"]
for w in worker_infos
if statuses is None or w["current_state"]["state"] in statuses
}
@track_context
async def _security(self, cluster_id: int, account: Optional[str] = None):
cluster = await self._get_cluster_details(
cluster_id=cluster_id, account=account
)
if (
ProcessStateEnum(cluster["scheduler"]["current_state"]["state"])
!= ProcessStateEnum.started
):
raise RuntimeError(
f"Cannot get security info for cluster {cluster_id} scheduler is ready"
)
public_ip = cluster["scheduler"]["instance"]["public_ip_address"]
private_ip = cluster["scheduler"]["instance"]["private_ip_address"]
tls_cert = cluster["cluster_options"]["tls_cert"]
tls_key = cluster["cluster_options"]["tls_key"]
scheduler_port = cluster["scheduler_port"]
dashboard_address = cluster["scheduler"]["dashboard_address"]
# TODO (Declarative): pass extra_conn_args if we care about proxying through Coiled to the scheduler
security = GatewaySecurity(tls_key, tls_cert)
return security, {
"private_address": f"tls://{private_ip}:{scheduler_port}",
"public_address": f"tls://{public_ip}:{scheduler_port}",
"dashboard_address": dashboard_address,
}
@track_context
async def _requested_workers(
self, cluster_id: int, account: Optional[str] = None
) -> Set[str]:
raise NotImplementedError("TODO")
@overload
def requested_workers(
self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
) -> Set[str]:
...
@track_context
async def _get_cluster_by_name(
self, name: str, account: Optional[str] = None
) -> int:
account, name = self._normalize_name(
name, context_account=account, allow_uppercase=True
)
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{account}/name/{name}",
)
if response.status == 404:
raise DoesNotExist
elif response.status >= 400:
await handle_api_exception(response)
cluster = await response.json()
return cluster["id"]
@overload
def get_cluster_by_name(
self: Cloud[Sync],
name: str,
account: Optional[str] = None,
) -> int:
...
@overload
def get_cluster_by_name(
self: Cloud[Async],
name: str,
account: Optional[str] = None,
) -> Awaitable[int]:
...
def get_cluster_by_name(
self,
name: str,
account: Optional[str] = None,
) -> Union[int, Awaitable[int]]:
return self._sync(
self._get_cluster_by_name,
name=name,
account=account,
)
@track_context
async def _cluster_status(
self,
cluster_id: int,
account: Optional[str] = None,
exclude_stopped: bool = True,
) -> dict:
raise NotImplementedError("TODO?")
@track_context
async def _get_cluster_states_declarative(
self,
cluster_id: int,
account: Optional[str] = None,
start_time: Optional[datetime.datetime] = None,
) -> int:
account = account or self.default_account
params = (
{"start_time": start_time.isoformat()} if start_time is not None else {}
)
response = await self._do_request_idempotent(
"GET",
self.server + f"/api/v2/clusters/account/{account}/id/{cluster_id}/states",
params=params,
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_cluster_states(
self,
cluster_id: int,
account: Optional[str] = None,
start_time: Optional[datetime.datetime] = None,
) -> Union[int, Awaitable[int]]:
return self._sync(
self._get_cluster_states_declarative,
cluster_id=cluster_id,
account=account,
start_time=start_time,
)
def get_clusters_by_name(
self,
name: str,
account: Optional[str] = None,
) -> List[dict]:
"""Get all clusters matching name."""
return self._sync(
self._get_clusters_by_name,
name=name,
account=account,
)
@track_context
async def _get_clusters_by_name(
self, name: str, account: Optional[str] = None
) -> List[dict]:
account, name = self._normalize_name(
name, context_account=account, allow_uppercase=True
)
response = await self._do_request(
"GET",
self.server + f"/api/v2/clusters/account/{account}",
params={"name": name},
)
if response.status == 404:
raise DoesNotExist
elif response.status >= 400:
await handle_api_exception(response)
cluster = await response.json()
return cluster
@overload
def cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Logs:
...
@overload
def cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Awaitable[Logs]:
...
@track_context
async def _cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Logs:
def is_errored(process):
process_state, instance_state = get_process_instance_state(process)
return (
process_state == ProcessStateEnum.error
or instance_state == InstanceStateEnum.error
)
account = account or self.default_account
# hits endpoint in order to get scheduler and worker instance names
cluster_info = await self._get_cluster_details(
cluster_id=cluster_id, account=account
)
try:
scheduler_name = cluster_info["scheduler"]["instance"]["name"]
except (TypeError, KeyError):
# no scheduler instance name in cluster info
logger.warning(
"No scheduler found when attempting to retrieve cluster logs."
)
scheduler_name = None
worker_names = [
worker["instance"]["name"]
for worker in cluster_info["workers"]
if worker["instance"] and (not errors_only or is_errored(worker))
]
LabeledInstance = namedtuple("LabeledInstance", ("name", "label"))
instances = []
if (
scheduler
and scheduler_name
and (not errors_only or is_errored(cluster_info["scheduler"]))
):
instances.append(LabeledInstance(scheduler_name, "Scheduler"))
if workers and worker_names:
instances.extend(
[
LabeledInstance(worker_name, worker_name)
for worker_name in worker_names
]
)
        async def instance_log_with_semaphore(semaphore, **kwargs):
            async with semaphore:
                return await self._instance_logs(**kwargs)

        # only fetch 100 logs at a time; the limit here is redundant since the aiohttp session already
        # limits concurrent connections, but let's be safe just in case
        semaphore = asyncio.Semaphore(value=100)
        results = await asyncio.gather(
            *[
                instance_log_with_semaphore(
                    semaphore=semaphore, account=account, instance_name=inst.name
                )
                for inst in instances
            ]
)
out = {
instance_label: instance_log
for (_, instance_label), instance_log in zip(instances, results)
if len(instance_log)
}
return Logs(out)
def cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
) -> Union[Logs, Awaitable[Logs]]:
return self._sync(
self._cluster_logs,
cluster_id=cluster_id,
account=account,
scheduler=scheduler,
workers=workers,
errors_only=errors_only,
)
async def _instance_logs(self, account: str, instance_name: str, safe=True) -> Log:
response = await self._do_request(
"GET",
self.server
+ "/api/v2/instances/{}/instance/{}/logs".format(account, instance_name),
)
if response.status >= 400:
if safe:
logger.warning(f"Error retrieving logs for {instance_name}")
return Log()
await handle_api_exception(response)
data = await response.json()
messages = "\n".join(logline.get("message", "") for logline in data)
return Log(messages)
@overload
def requested_workers(
self: Cloud[Async], cluster_id: int, account: Optional[str] = None
) -> Awaitable[Set[str]]:
...
def requested_workers(
self, cluster_id: int, account: Optional[str] = None
    ) -> Union[Set[str], Awaitable[Set[str]]]:
return self._sync(self._requested_workers, cluster_id, account)
@overload
def scale_up(
self: Cloud[Sync], cluster_id: int, n: int, account: Optional[str] = None
) -> Optional[Dict]:
...
@overload
def scale_up(
self: Cloud[Async], cluster_id: int, n: int, account: Optional[str] = None
) -> Awaitable[Optional[Dict]]:
...
def scale_up(
self, cluster_id: int, n: int, account: Optional[str] = None
) -> Union[Optional[Dict], Awaitable[Optional[Dict]]]:
"""Scale cluster to ``n`` workers
Parameters
----------
cluster_id
Unique cluster identifier.
n
            Number of workers to add to the cluster.
account
Name of Coiled account which the cluster belongs to.
If not provided, will default to ``Cloud.default_account``.
"""
return self._sync(self._scale_up, cluster_id, n, account)
@overload
def scale_down(
self: Cloud[Sync],
cluster_id: int,
workers: Set[str],
account: Optional[str] = None,
) -> None:
...
@overload
def scale_down(
self: Cloud[Async],
cluster_id: int,
workers: Set[str],
account: Optional[str] = None,
) -> Awaitable[None]:
...
def scale_down(
self, cluster_id: int, workers: Set[str], account: Optional[str] = None
) -> Optional[Awaitable[None]]:
"""Scale cluster to ``n`` workers
Parameters
----------
cluster_id
Unique cluster identifier.
workers
            Names of the workers to remove from the cluster.
account
Name of Coiled account which the cluster belongs to.
If not provided, will default to ``Cloud.default_account``.
"""
return self._sync(self._scale_down, cluster_id, workers, account)
@track_context
async def _better_cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
instance_ids: Optional[List[int]] = None,
dask: bool = False,
system: bool = False,
since_ms: Optional[int] = None,
until_ms: Optional[int] = None,
):
account = account or self.default_account
url_params = []
if dask:
url_params.append("dask=True")
if system:
url_params.append("system=True")
if since_ms:
url_params.append(f"since_ms={since_ms}")
if until_ms:
url_params.append(f"until_ms={until_ms}")
if instance_ids:
id_list = ",".join(map(str, instance_ids))
url_params.append(f"instance_ids={id_list}")
url_path = f"/api/v2/clusters/account/{account}/id/{cluster_id}/better-logs"
url_param_string = f"?{'&'.join(url_params)}" if url_params else ""
response = await self._do_request(
"GET",
f"{self.server}{url_path}{url_param_string}",
)
if response.status >= 400:
await handle_api_exception(response)
data = await response.json()
return data
def better_cluster_logs(
self,
cluster_id: int,
account: Optional[str] = None,
instance_ids: Optional[List[int]] = None,
dask: bool = False,
system: bool = False,
since_ms: Optional[int] = None,
until_ms: Optional[int] = None,
) -> Logs:
return self._sync(
self._better_cluster_logs,
cluster_id=cluster_id,
account=account,
instance_ids=instance_ids,
dask=dask,
system=system,
since_ms=since_ms,
until_ms=until_ms,
)
@track_context
async def _scale_up(
self, cluster_id: int, n: int, account: Optional[str] = None
) -> Dict:
"""
Increases the number of workers by ``n``.
"""
account = account or self.default_account
response = await self._do_request(
"POST",
f"{self.server}/api/v2/workers/account/{account}/cluster/{cluster_id}/",
json={"n_workers": n},
)
if response.status >= 400:
await handle_api_exception(response)
workers_info = await response.json()
return {"workers": {w["name"] for w in workers_info}}
@track_context
async def _scale_down(
self, cluster_id: int, workers: Set[str], account: Optional[str] = None
) -> None:
account = account or self.default_account
response = await self._do_request(
"DELETE",
f"{self.server}/api/v2/workers/account/{account}/cluster/{cluster_id}/",
params={"name": workers},
)
if response.status >= 400:
await handle_api_exception(response)
@overload
def security(
self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
) -> Tuple[dask.distributed.Security, dict]:
...
@overload
def security(
self: Cloud[Async], cluster_id: int, account: Optional[str] = None
) -> Awaitable[Tuple[dask.distributed.Security, dict]]:
...
def security(
self, cluster_id: int, account: Optional[str] = None
) -> Union[
Tuple[dask.distributed.Security, dict],
        Awaitable[Tuple[dask.distributed.Security, dict]],
]:
return self._sync(self._security, cluster_id, account)
@track_context
async def _fetch_package_levels(self) -> List[PackageLevel]:
response = await self._do_request(
"GET",
f"{self.server}/api/v2/packages/",
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_ssh_key(
self,
cluster_id: int,
account: Optional[str] = None,
worker: Optional[str] = None,
) -> dict:
return self._sync(
self._get_ssh_key,
cluster_id=cluster_id,
account=account,
worker=worker,
)
@track_context
async def _get_ssh_key(
self, cluster_id: int, account: str, worker: Optional[str]
) -> dict:
account = account or self.default_account
route = f"/api/v2/clusters/account/{account}/id/{cluster_id}/ssh-key"
url = f"{self.server}{route}"
response = await self._do_request(
"GET", url, params={"worker": worker} if worker else None
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def get_cluster_log_info(
self,
cluster_id: int,
account: Optional[str] = None,
) -> dict:
return self._sync(
self._get_cluster_log_info,
cluster_id=cluster_id,
account=account,
)
@track_context
async def _get_cluster_log_info(
self,
cluster_id: int,
account: str,
) -> dict:
account = account or self.default_account
route = f"/api/v2/clusters/account/{account}/id/{cluster_id}/log-info"
url = f"{self.server}{route}"
response = await self._do_request("GET", url)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
def approximate_packages(self, package: List[ApproximatePackageRequest]):
return self._sync(self._approximate_packages, package)
@track_context
async def _approximate_packages(
self, packages: List[ApproximatePackageRequest]
) -> List[ApproximatePackageResult]:
response = await self._do_request(
"POST",
f"{self.server}/api/v2/software-environment/approximate-packages",
json=packages,
)
if response.status >= 400:
await handle_api_exception(response)
return await response.json()
Cloud = CloudBeta
def cluster_logs(
cluster_id: int,
account: Optional[str] = None,
scheduler: bool = True,
workers: bool = True,
errors_only: bool = False,
):
"""
Returns cluster logs as a dictionary, with a key for the scheduler and each worker.
.. versionchanged:: 0.2.0
``cluster_name`` is no longer accepted, use ``cluster_id`` instead.
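
    Examples
    --------
    A sketch of typical usage (the cluster ID is a placeholder)::

        logs = cluster_logs(1234)
        print(logs.get("Scheduler", ""))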
"""
with Cloud() as cloud:
return cloud.cluster_logs(cluster_id, account, scheduler, workers, errors_only)
def better_cluster_logs(
cluster_id: int,
account: Optional[str] = None,
instance_ids: Optional[List[int]] = None,
dask: bool = False,
system: bool = False,
since_ms: Optional[int] = None,
until_ms: Optional[int] = None,
):
"""
    Pull logs for the cluster using the newer ("better") logs endpoint.

    Logs for recent clusters are split between system and container (dask) logs; you can
    request either or both (or neither).

    ``since_ms`` and ``until_ms`` are both inclusive (you'll get logs whose timestamps match those bounds).
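
    A sketch of typical usage (the cluster ID is a placeholder)::

        logs = better_cluster_logs(1234, dask=True, system=True)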
"""
with Cloud() as cloud:
return cloud.better_cluster_logs(
cluster_id,
account,
instance_ids=instance_ids,
dask=dask,
system=system,
since_ms=since_ms,
until_ms=until_ms,
)
def cluster_details(
cluster_id: int,
account: Optional[str] = None,
) -> dict:
"""
Get details of a cluster as a dictionary.
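
    A sketch of typical usage (the cluster ID is a placeholder)::

        details = cluster_details(1234)
        print(details["current_state"]["state"])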
"""
with CloudBeta() as cloud:
return cloud.cluster_details(
cluster_id=cluster_id,
account=account,
)
def log_cluster_debug_info(
cluster_id: int,
account: Optional[str] = None,
):
with CloudBeta() as cloud:
details = cloud.cluster_details(cluster_id, account)
logger.debug("Cluster details:")
logger.debug(json.dumps(details, indent=2))
states_by_type = cloud.get_cluster_states(cluster_id, account)
logger.debug("cluster state history:")
log_states(flatten_log_states(states_by_type), level=logging.DEBUG)
# log the scheduler logs (if errored), and up to 1 errored worker
instance_logs = cloud.cluster_logs(cluster_id, account, errors_only=True)
logger.debug("Finding errored scheduler instance log:")
try:
logger.debug(instance_logs.pop("Scheduler"))
except KeyError:
logger.debug("Did not find any errored scheduler instance logs.")
logger.debug("Finding errored worker instance log:")
try:
worker_log = next(iter(instance_logs.values()))
logger.debug(worker_log)
except StopIteration:
logger.debug("Did not find any errored worker instance logs.")
def create_cluster(
name: Optional[str] = None,
*,
software: Optional[str] = None,
worker_options: Optional[dict] = None,
worker_cpu: Optional[int] = None,
worker_memory: Optional[int] = None,
scheduler_options: Optional[dict] = None,
scheduler_cpu: Optional[int] = None,
scheduler_memory: Optional[int] = None,
account: Optional[str] = None,
workers: int = 0,
environ: Optional[Dict] = None,
tags: Optional[Dict] = None,
dask_config: Optional[Dict] = None,
private_to_creator: Optional[bool] = None,
scheduler_vm_types: Optional[list] = None,
worker_vm_types: Optional[list] = None,
worker_disk_size: Optional[int] = None,
backend_options: Optional[dict | BackendOptions] = None,
) -> int:
"""Create a cluster
Parameters
    ----------
name
Name of cluster.
software
Identifier of the software environment to use, in the format (<account>/)<name>. If the software environment
is owned by the same account as that passed into "account", the (<account>/) prefix is optional.
For example, suppose your account is "wondercorp", but your friends at "friendlycorp" have an environment
named "xgboost" that you want to use; you can specify this with "friendlycorp/xgboost". If you simply
entered "xgboost", this is shorthand for "wondercorp/xgboost".
The "name" portion of (<account>/)<name> can only contain ASCII letters, hyphens and underscores.
worker_cpu
Number of CPUs allocated for each worker. Defaults to 2.
worker_memory
Amount of memory to allocate for each worker. Defaults to 8 GiB.
worker_options
Mapping with keyword arguments to pass to ``worker_class``. Defaults to ``{}``.
scheduler_cpu
Number of CPUs allocated for the scheduler. Defaults to 1.
scheduler_memory
Amount of memory to allocate for the scheduler. Defaults to 4 GiB.
scheduler_options
Mapping with keyword arguments to pass to ``scheduler_class``. Defaults to ``{}``.
account
Name of the Coiled account to create the cluster in.
If not provided, will default to ``Cloud.default_account``.
workers
        Number of workers to launch.
environ
Dictionary of environment variables.
    tags
        Dictionary of instance tags.
    dask_config
        Dictionary of Dask configuration to set on the cluster.
See Also
--------
coiled.Cluster
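
    Examples
    --------
    A minimal sketch (the cluster name and software environment are placeholders)::

        cluster_id = create_cluster(
            name="my-cluster",
            software="wondercorp/xgboost",
            workers=4,
        )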
"""
with CloudBeta(account=account) as cloud:
return cloud.create_cluster(
name=name,
software=software,
worker_options=worker_options,
worker_cpu=worker_cpu,
worker_memory=worker_memory,
scheduler_options=scheduler_options,
scheduler_cpu=scheduler_cpu,
scheduler_memory=scheduler_memory,
account=account,
workers=workers,
environ=environ,
tags=tags,
dask_config=dask_config,
private_to_creator=private_to_creator,
backend_options=backend_options,
worker_vm_types=worker_vm_types,
worker_disk_size=worker_disk_size,
scheduler_vm_types=scheduler_vm_types,
)
@list_docstring
def list_clusters(account=None, max_pages: Optional[int] = None):
with CloudBeta() as cloud:
return cloud.list_clusters(account=account, max_pages=max_pages)
@delete_docstring
def delete_cluster(name: str, account: Optional[str] = None):
with CloudBeta() as cloud:
cluster_id = cloud.get_cluster_by_name(name=name, account=account)
if cluster_id is not None:
return cloud.delete_cluster(cluster_id=cluster_id, account=account)