# Source code for coiled._beta.core

from __future__ import annotations

import asyncio
import base64
import datetime
import json
import logging
import sys
import weakref
from collections import namedtuple
from hashlib import md5
from pathlib import Path
from typing import (
    Awaitable,
    BinaryIO,
    Callable,
    Dict,
    Generic,
    List,
    NoReturn,
    Optional,
    Set,
    Tuple,
    Union,
    overload,
)

import backoff
import dask
from aiohttp import ClientResponseError, ClientSession, ContentTypeError

from coiled.types import (
    ApproximatePackageRequest,
    ApproximatePackageResult,
    PackageSchema,
    ResolvedPackageInfo,
)

if sys.version_info >= (3, 8):
    from typing import TypedDict
else:
    from typing_extensions import TypedDict

from dask.utils import parse_timedelta
from distributed.utils import Log, Logs

from coiled.cli.setup.entry import do_setup_wizard
from coiled.context import track_context
from coiled.core import Async
from coiled.core import Cloud as OldCloud
from coiled.core import IsAsynchronous, Sync, delete_docstring, list_docstring
from coiled.errors import ClusterCreationError, DoesNotExist, ServerError
from coiled.types import PackageLevel
from coiled.utils import (
    COILED_LOGGER_NAME,
    GatewaySecurity,
    get_grafana_url,
    validate_type,
)

from .states import (
    InstanceStateEnum,
    ProcessStateEnum,
    flatten_log_states,
    get_process_instance_state,
    log_states,
)

logger = logging.getLogger(COILED_LOGGER_NAME)


def setup_logging(level=logging.INFO):
    # only set up logging if there's no log level specified yet on the coiled logger
    if logging.getLogger(COILED_LOGGER_NAME).level == 0:
        logging.getLogger(COILED_LOGGER_NAME).setLevel(level)
        logging.basicConfig()


async def handle_api_exception(response, exception_cls=ServerError) -> NoReturn:
    try:
        error_body = await response.json()
    except ContentTypeError:
        raise exception_cls(
            f"Unexpected status code ({response.status}) to {response.method}:{response.url}, contact support@coiled.io"
        )
    if "message" in error_body:
        raise exception_cls(error_body["message"])
    if "detail" in error_body:
        raise exception_cls(error_body["detail"])
    raise exception_cls(error_body)
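
# Typical call pattern for the helper above, as used throughout this module
# (illustrative sketch; `response` comes from `Cloud._do_request`):
#
#   if response.status >= 400:
#       await handle_api_exception(response)  # always raises, never returns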


class FirewallOptions(TypedDict):
    """

    A dictionary with the following key/value pairs

    Parameters
    ----------
    ports
        List of ports to open to cidr on the scheduler.
        For example, ``[22, 8786]`` opens port 22 for SSH and 8786 for client to Dask connection.
    cidr
        CIDR block from which to allow access. For example ``0.0.0.0/0`` allows access from any IP address.
    """

    ports: List[int]
    cidr: str
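
# Illustrative FirewallOptions value, using the ports and CIDR from the
# docstring above (placeholders, not a recommendation):
#
#   firewall: FirewallOptions = {"ports": [22, 8786], "cidr": "0.0.0.0/0"}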


class BackendOptions(TypedDict, total=False):
    """

    A dictionary with the following key/value pairs

    Parameters
    ----------
    region_name
        Region name to launch cluster in. For example: us-east-2
    zone_name
        Zone name to launch cluster in. For example: us-east-2a
    firewall
        Allows you to specify the firewall for the scheduler; see
        :py:class:`FirewallOptions` for details.
    ingress
        Allows you to specify multiple CIDR blocks (and corresponding ports) to open
        for ingress on the scheduler firewall.
    spot
        Whether to request spot instances.
    spot_on_demand_fallback
        If requesting spot, whether to request non-spot instances if we get fewer
        spot instances than desired.
    multizone
        Tell the cloud provider to pick the zone with best availability; we'll keep
        workers all in the same zone, and the scheduler may or may not be in that
        zone as well.
    """

    region_name: Optional[str]
    zone_name: Optional[str]
    firewall: Optional[FirewallOptions]
    ingress: Optional[List[FirewallOptions]]
    spot: Optional[bool]
    spot_on_demand_fallback: Optional[bool]
    multizone: Optional[bool]
    use_dashboard_public_ip: Optional[bool]
    send_prometheus_metrics: Optional[bool]  # TODO deprecate
    prometheus_write: Optional[dict]
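
# Illustrative BackendOptions value with two ingress rules on the scheduler
# firewall (the region, ports, and CIDR values are placeholders):
#
#   options: BackendOptions = {
#       "region_name": "us-east-2",
#       "ingress": [
#           {"ports": [8786], "cidr": "10.0.0.0/8"},
#           {"ports": [22], "cidr": "0.0.0.0/0"},
#       ],
#       "spot": True,
#   }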


class AWSOptions(BackendOptions, total=False):
    """

    A dictionary with the following key/value pairs plus any pairs in :py:class:`BackendOptions`

    Parameters
    ----------
    keypair_name
        AWS keypair to assign to worker/scheduler instances
    """

    keypair_name: Optional[str]
    spot_replacement: Optional[bool]
    use_placement_group: Optional[bool]


class GCPOptions(BackendOptions, total=False):
    scheduler_accelerator_count: Optional[int]
    scheduler_accelerator_type: Optional[str]
    worker_accelerator_count: Optional[int]
    worker_accelerator_type: Optional[str]


BackendOptionTypes = [AWSOptions, GCPOptions]


class CloudBeta(OldCloud, Generic[IsAsynchronous]):
    _recent_sync: list[weakref.ReferenceType[CloudBeta[Sync]]] = list()
    _recent_async: list[weakref.ReferenceType[CloudBeta[Async]]] = list()

    # just overriding to get the right signature (CloudBeta, not Cloud)
    def __enter__(self: CloudBeta[Sync]) -> CloudBeta[Sync]:
        return self

    def __exit__(self: CloudBeta[Sync], typ, value, tb) -> None:
        self.close()

    async def __aenter__(self: CloudBeta[Async]) -> CloudBeta[Async]:
        return await self._start()

    async def __aexit__(self: CloudBeta[Async], typ, value, tb) -> None:
        await self._close()

    # These overloads are necessary for the typechecker to know that we really
    # have a CloudBeta, not a Cloud. Without them, CloudBeta.current would be
    # typed to return a Cloud.
    #
    # https://www.python.org/dev/peps/pep-0673/ would remove the need for this.
    # That PEP also mentions a workaround with type vars, which doesn't work
    # for us because type vars aren't subscriptable.
    @overload
    @classmethod
    def current(cls, asynchronous: Sync) -> CloudBeta[Sync]:
        ...

    @overload
    @classmethod
    def current(cls, asynchronous: Async) -> CloudBeta[Async]:
        ...

    @overload
    @classmethod
    def current(cls, asynchronous: bool) -> CloudBeta:
        ...

    @classmethod
    def current(cls, asynchronous: bool) -> CloudBeta:
        recent: list[weakref.ReferenceType[CloudBeta]]
        if asynchronous:
            recent = cls._recent_async
        else:
            recent = cls._recent_sync
        try:
            cloud = recent[-1]()
            while cloud is None or cloud.status != "running":
                recent.pop()
                cloud = recent[-1]()
        except IndexError:
            if asynchronous:
                return cls(asynchronous=True)
            else:
                return cls(asynchronous=False)
        else:
            return cloud

    @track_context
    async def _get_default_instance_types(
        self, provider: str, gpu: bool = False
    ) -> List[str]:
        if provider == "aws":
            if gpu:
                return ["g4dn.xlarge"]
            else:
                return ["t3.xlarge"]
        elif provider == "gcp":
            if gpu:
                # n1-standard-8 with 30GB of memory might be best, but that's big for a default
                return ["n1-standard-4"]
            else:
                return ["e2-standard-4"]
        else:
            raise ValueError(
                f"unexpected provider {provider}; cannot determine default instance types"
            )

    async def _list_dask_scheduler_page(
        self,
        page: int,
        account: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = None,
    ) -> Tuple[list, bool]:
        page_size = 100
        account = account or self.default_account
        kwargs = {}
        if since:
            kwargs["since"] = parse_timedelta(since)
        if user:
            kwargs["user"] = user
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/analytics/{account}/clusters/list",
            params={
                "limit": page_size,
                "offset": page_size * page,
                **kwargs,
            },
        )
        if response.status >= 400:
            await handle_api_exception(response)
        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    @track_context
    async def _list_dask_scheduler(
        self,
        account: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = None,
    ):
        return await self._depaginate_list(
            self._list_dask_scheduler_page,
            account=account,
            since=since,
            user=user,
        )

    @overload
    def list_dask_scheduler(
        self: Cloud[Sync],
        account: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = None,
    ) -> list:
        ...

    @overload
    def list_dask_scheduler(
        self: Cloud[Async],
        account: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = "",
    ) -> Awaitable[list]:
        ...

    def list_dask_scheduler(
        self,
        account: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = "",
    ) -> Union[list, Awaitable[list]]:
        return self._sync(self._list_dask_scheduler, account, since=since, user=user)

    async def _list_computations(self, cluster_id: int, account: Optional[str] = None):
        return await self._depaginate_list(
            self._list_computations_page, cluster_id=cluster_id, account=account
        )

    async def _list_computations_page(
        self,
        page: int,
        cluster_id: int,
        account: Optional[str] = None,
    ) -> Tuple[list, bool]:
        page_size = 100
        account = account or self.default_account
        response = await self._do_request(
            "GET",
            self.server
            + f"/api/v2/analytics/{account}/{cluster_id}/computations/list",
            params={"limit": page_size, "offset": page_size * page},
        )
        if response.status >= 400:
            await handle_api_exception(response)
        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    @overload
    def list_computations(
        self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
    ) -> list:
        ...

    @overload
    def list_computations(
        self: Cloud[Async], cluster_id: int, account: Optional[str] = None
    ) -> Awaitable[list]:
        ...

    def list_computations(
        self, cluster_id: int, account: Optional[str] = None
    ) -> Union[list, Awaitable[list]]:
        return self._sync(self._list_computations, cluster_id, account)

    @overload
    def list_exceptions(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> list:
        ...

    @overload
    def list_exceptions(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> Awaitable[list]:
        ...

    def list_exceptions(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> Union[list, Awaitable[list]]:
        return self._sync(
            self._list_exceptions,
            cluster_id=cluster_id,
            scheduler_id=scheduler_id,
            account=account,
            since=since,
            user=user,
        )

    async def _list_exceptions(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ):
        return await self._depaginate_list(
            self._list_exceptions_page,
            cluster_id=cluster_id,
            scheduler_id=scheduler_id,
            account=account,
            since=since,
            user=user,
        )

    async def _list_exceptions_page(
        self,
        page: int,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> Tuple[list, bool]:
        page_size = 100
        account = account or self.default_account
        kwargs = {}
        if since:
            kwargs["since"] = parse_timedelta(since)
        if user:
            kwargs["user"] = user
        if cluster_id:
            kwargs["cluster"] = cluster_id
        if scheduler_id:
            kwargs["scheduler"] = scheduler_id
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/analytics/{account}/exceptions/list",
            params={"limit": page_size, "offset": page_size * page, **kwargs},
        )
        if response.status >= 400:
            await handle_api_exception(response)
        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    async def _list_events_page(
        self,
        page: int,
        cluster_id: int,
        account: Optional[str] = None,
    ) -> Tuple[list, bool]:
        page_size = 100
        account = account or self.default_account
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/analytics/{account}/{cluster_id}/events/list",
            params={"limit": page_size, "offset": page_size * page},
        )
        if response.status >= 400:
            await handle_api_exception(response)
        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    async def _list_events(self, cluster_id: int, account: Optional[str] = None):
        return await self._depaginate_list(
            self._list_events_page, cluster_id=cluster_id, account=account
        )

    def list_events(
        self, cluster_id: int, account: Optional[str] = None
    ) -> Union[list, Awaitable[list]]:
        return self._sync(self._list_events, cluster_id, account)

    async def _send_state(
        self, cluster_id: int, desired_status: str, account: Optional[str] = None
    ):
        account = account or self.default_account
        response = await self._do_request(
            "POST",
            self.server + f"/api/v2/analytics/{account}/{cluster_id}/desired-state",
            json={"desired_status": desired_status},
        )
        if response.status >= 400:
            await handle_api_exception(response)

    def send_state(
        self, cluster_id: int, desired_status: str, account: Optional[str] = None
    ) -> Union[None, Awaitable[None]]:
        return self._sync(self._send_state, cluster_id, desired_status, account)
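
    # The listing endpoints above and below share one pagination contract:
    # each ``_list_*_page`` coroutine returns ``(results, has_more)``, and
    # ``_depaginate_list`` (below) keeps requesting pages until a page comes
    # back empty or ``max_pages`` is reached. Illustrative composition, with
    # a placeholder cluster id:
    #
    #   events = await self._depaginate_list(
    #       self._list_events_page, cluster_id=123, account=account
    #   )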

    @track_context
    async def _list_clusters(
        self, account: Optional[str] = None, max_pages: Optional[int] = None
    ):
        return await self._depaginate_list(
            self._list_clusters_page, account=account, max_pages=max_pages
        )

    @overload
    def list_clusters(
        self: Cloud[Sync],
        account: Optional[str] = None,
        max_pages: Optional[int] = None,
    ) -> list:
        ...

    @overload
    def list_clusters(
        self: Cloud[Async],
        account: Optional[str] = None,
        max_pages: Optional[int] = None,
    ) -> Awaitable[list]:
        ...

    @list_docstring
    def list_clusters(
        self, account: Optional[str] = None, max_pages: Optional[int] = None
    ) -> Union[list, Awaitable[list]]:
        return self._sync(self._list_clusters, account, max_pages=max_pages)

    async def _list_clusters_page(
        self, page: int, account: Optional[str] = None
    ) -> Tuple[list, bool]:
        page_size = 100
        account = account or self.default_account
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/clusters/account/{account}/",
            params={"limit": page_size, "offset": page_size * page},
        )
        if response.status >= 400:
            await handle_api_exception(response)
        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    @staticmethod
    async def _depaginate_list(
        func: Callable[..., Awaitable[Tuple[list, bool]]],
        max_pages: Optional[int] = None,
        *args,
        **kwargs,
    ) -> list:
        results_all = []
        page = 0
        while True:
            kwargs["page"] = page
            # renamed from ``next`` (which shadows the builtin); page functions
            # return has_more = len(results) > 0
            results, has_more = await func(*args, **kwargs)
            results_all += results
            page += 1
            if not results or not has_more:
                break
            # page is the number of pages we've already fetched (since 0-indexed)
            if max_pages and page >= max_pages:
                break
        return results_all

    async def _create_package_sync_env(
        self, packages: List[ResolvedPackageInfo], account: Optional[str] = None
    ) -> int:
        account = account or self.default_account
        prepared_packages: List[PackageSchema] = []
        for pkg in packages:
            if pkg["sdist"]:
                file_id = await self._create_senv_package(
                    pkg["sdist"], contents_md5=pkg["md5"], account=account
                )
            else:
                file_id = None
            prepared_packages.append(
                {
                    "name": pkg["name"],
                    "source": pkg["source"],
                    "channel": pkg["channel"],
                    "conda_name": pkg["conda_name"],
                    "specifier": pkg["specifier"],
                    "include": pkg["include"],
                    "client_version": pkg["client_version"],
                    "file": file_id,
                }
            )
        return await self._create_software_environment_v2(
            senv=prepared_packages, account=account
        )

    @track_context
    async def _create_senv_package(
        self, package_file: BinaryIO, contents_md5: str, account: Optional[str] = None
    ) -> int:
        logger.info(f"Starting upload for {package_file}")
        package_data = package_file.read()
        # s3 expects the md5 to be base64 encoded
        wheel_md5 = base64.b64encode(md5(package_data).digest()).decode("utf-8")
        account = account or self.default_account
        response = await self._do_request(
            "POST",
            self.server
            + f"/api/v2/software-environment/account/{account}/package-upload",
            json={
                "name": Path(package_file.name).name,
                "md5": contents_md5,
                "wheel_md5": wheel_md5,
            },
        )
        if response.status >= 400:
            await handle_api_exception(response)  # always raises exception, no return
        data = await response.json()
        if data["should_upload"]:
            await self._put_package(
                url=data["upload_url"],
                package_data=package_data,
                file_md5=wheel_md5,
            )
        else:
            logger.info(f"{package_file} MD5 matches existing, skipping upload")
        return data["id"]

    @backoff.on_exception(
        backoff.expo,
        ClientResponseError,
        max_time=120,
        giveup=lambda error: error.status < 500,
    )
    async def _put_package(self, url: str, package_data: bytes, file_md5: str):
        # can't use the default session as it has coiled auth headers
        async with ClientSession() as session:
            async with session.put(
                url=url, data=package_data, headers={"content-md5": file_md5}
            ) as resp:
                resp.raise_for_status()

    @track_context
    async def _create_software_environment_v2(
        self,
        senv: List[PackageSchema],
        account: Optional[str] = None,
    ) -> int:
        account = account or self.default_account
        resp = await self._do_request(
            "POST",
            self.server + f"/api/v2/software-environment/account/{account}",
            json={
                "packages": senv,
                "md5": md5(
                    json.dumps(senv, sort_keys=True).encode("utf-8")
                ).hexdigest(),
            },
        )
        if resp.status >= 400:
            await handle_api_exception(resp)  # always raises exception, no return
        data = await resp.json()
        return data["id"]

    @track_context
    async def _create_cluster(
        self,
        # todo: make name optional and pick one for them, like pre-declarative?
        # https://gitlab.com/coiled/cloud/-/issues/4305
        name: str,
        *,
        software_environment: Optional[str] = None,
        senv_v2_id: Optional[int] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        worker_cpu: Optional[int] = None,
        worker_memory: Optional[Union[str, List[str]]] = None,
        scheduler_class: Optional[str] = None,
        scheduler_options: Optional[dict] = None,
        scheduler_cpu: Optional[int] = None,
        scheduler_memory: Optional[Union[str, List[str]]] = None,
        account: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        dask_config: Optional[Dict] = None,
        scheduler_vm_types: Optional[list] = None,
        gcp_worker_gpu_type: Optional[str] = None,
        gcp_worker_gpu_count: Optional[int] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        backend_options: Optional[Union[AWSOptions, GCPOptions, dict]] = None,
        use_scheduler_public_ip: Optional[bool] = None,
        private_to_creator: Optional[bool] = None,
    ) -> int:
        # TODO (Declarative): support these args, or decide not to
        # https://gitlab.com/coiled/cloud/-/issues/4305
        if scheduler_class is not None:
            raise ValueError("scheduler_class is not supported in beta/new Coiled yet")
        account = account or self.default_account
        account, name = self._normalize_name(
            name,
            context_account=account,
            allow_uppercase=True,
        )
        self._verify_account(account)

        data = {
            "name": name,
            "workers": workers,
            "worker_instance_types": worker_vm_types,
            "scheduler_instance_types": scheduler_vm_types,
            "software_environment": software_environment,
            "worker_options": worker_options,
            "worker_cpu": worker_cpu,
            "worker_class": worker_class,
            "worker_memory": worker_memory,
            "worker_disk_size": worker_disk_size,
            "scheduler_options": scheduler_options,
            "scheduler_cpu": scheduler_cpu,
            "scheduler_memory": scheduler_memory,
            "environ": environ,
            "tags": tags,
            "dask_config": dask_config,
            "private_to_creator": private_to_creator,
            "env_id": senv_v2_id,
            # "jupyter_on_scheduler": True,
        }
        if gcp_worker_gpu_type is not None:
            # for backwards compatibility with v1 options
            backend_options = backend_options if backend_options else {}
            backend_options = {
                **backend_options,
                "worker_accelerator_count": gcp_worker_gpu_count or 1,
                "worker_accelerator_type": gcp_worker_gpu_type,
            }
        elif gcp_worker_gpu_count:
            # not ideal but v1 only supported T4 and `worker_gpu=1` would give you one
            backend_options = backend_options if backend_options else {}
            backend_options = {
                **backend_options,
                "worker_accelerator_count": gcp_worker_gpu_count,
                "worker_accelerator_type": "nvidia-tesla-t4",
            }
        if use_scheduler_public_ip is False:
            backend_options = backend_options if backend_options else {}
            if "use_dashboard_public_ip" not in backend_options:
                backend_options["use_dashboard_public_ip"] = False

        if backend_options:
            # for backwards compatibility with v1 options
            if "region" in backend_options and "region_name" not in backend_options:
                backend_options["region_name"] = backend_options["region"]  # type: ignore
                del backend_options["region"]  # type: ignore
            if "zone" in backend_options and "zone_name" not in backend_options:
                backend_options["zone_name"] = backend_options["zone"]  # type: ignore
                del backend_options["zone"]  # type: ignore
            # firewall just lets you specify a single CIDR block to open for ingress;
            # we want to support a list of ingress CIDR blocks
            if "firewall" in backend_options:
                backend_options["ingress"] = [backend_options.pop("firewall")]  # type: ignore
            # validate against TypedDicts -- should be better (especially better errors)
            if not any(validate_type(t, backend_options) for t in BackendOptionTypes):
                raise ValueError(
                    "backend_options should be an instance of coiled.BackendOptions"
                )
            # convert the list of ingress rules to the FirewallSpec expected server-side
            if "ingress" in backend_options:
                fw_spec = {"ingress": backend_options.pop("ingress")}
                backend_options["firewall_spec"] = fw_spec  # type: ignore
            data["options"] = backend_options

        response = await self._do_request(
            "POST",
            self.server + f"/api/v2/clusters/account/{account}/",
            json=data,
        )
        response_json = await response.json()

        if response.status >= 400:
            from .widgets import EXECUTION_CONTEXT

            if response_json.get("code") == "NO_CLOUD_SETUP":
                server_error_message = response_json.get("message")
                error_message = (
                    f"{server_error_message} or by running `coiled setup wizard`"
                )
                if EXECUTION_CONTEXT == "terminal":
                    # maybe not interactive so just raise
                    raise ClusterCreationError(error_message)
                else:
                    # interactive session so let's try running the cloud setup wizard
                    if do_setup_wizard():
                        # the user set up their cloud backend, so let's try creating the cluster again!
                        response = await self._do_request(
                            "POST",
                            self.server + f"/api/v2/clusters/account/{account}/",
                            json=data,
                        )
                        if response.status >= 400:
                            await handle_api_exception(
                                response
                            )  # always raises exception, no return
                        response_json = await response.json()
                    else:
                        raise ClusterCreationError(error_message)
            else:
                if "message" in response_json:
                    raise ServerError(response_json["message"])
                if "detail" in response_json:
                    raise ServerError(response_json["detail"])
                raise ServerError(response_json)

        return response_json["id"]

    @overload
    def create_cluster(
        self: Cloud[Sync],
        name: Optional[str] = None,
        *,
        software: Optional[str] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        worker_cpu: Optional[int] = None,
        worker_memory: Optional[int] = None,
        scheduler_class: Optional[str] = None,
        scheduler_options: Optional[dict] = None,
        scheduler_cpu: Optional[int] = None,
        scheduler_memory: Optional[int] = None,
        account: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        dask_config: Optional[Dict] = None,
        private_to_creator: Optional[bool] = None,
        scheduler_vm_types: Optional[list] = None,
        worker_gpu_type: Optional[str] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        backend_options: Optional[dict | BackendOptions] = None,
    ) -> int:
        ...

    @overload
    def create_cluster(
        self: Cloud[Async],
        name: Optional[str] = None,
        *,
        software: Optional[str] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        worker_cpu: Optional[int] = None,
        worker_memory: Optional[int] = None,
        scheduler_class: Optional[str] = None,
        scheduler_options: Optional[dict] = None,
        scheduler_cpu: Optional[int] = None,
        scheduler_memory: Optional[int] = None,
        account: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        dask_config: Optional[Dict] = None,
        private_to_creator: Optional[bool] = None,
        scheduler_vm_types: Optional[list] = None,
        worker_gpu_type: Optional[str] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        backend_options: Optional[dict | BackendOptions] = None,
    ) -> Awaitable[int]:
        ...

    def create_cluster(
        self,
        name: Optional[str] = None,
        *,
        software: Optional[str] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        worker_cpu: Optional[int] = None,
        worker_memory: Optional[int] = None,
        scheduler_class: Optional[str] = None,
        scheduler_options: Optional[dict] = None,
        scheduler_cpu: Optional[int] = None,
        scheduler_memory: Optional[int] = None,
        account: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        private_to_creator: Optional[bool] = None,
        dask_config: Optional[Dict] = None,
        scheduler_vm_types: Optional[list] = None,
        worker_gpu_type: Optional[str] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        backend_options: Optional[dict | BackendOptions] = None,
    ) -> Union[int, Awaitable[int]]:
        return self._sync(
            self._create_cluster,
            name=name,
            software_environment=software,
            worker_class=worker_class,
            worker_options=worker_options,
            worker_cpu=worker_cpu,
            worker_memory=worker_memory,
            scheduler_options=scheduler_options,
            scheduler_cpu=scheduler_cpu,
            scheduler_memory=scheduler_memory,
            account=account,
            workers=workers,
            environ=environ,
            tags=tags,
            dask_config=dask_config,
            private_to_creator=private_to_creator,
            scheduler_vm_types=scheduler_vm_types,
            worker_vm_types=worker_vm_types,
            gcp_worker_gpu_type=worker_gpu_type,
            worker_disk_size=worker_disk_size,
            backend_options=backend_options,
        )

    @track_context
    async def _delete_cluster(
        self, cluster_id: int, account: Optional[str] = None
    ) -> None:
        account = account or self.default_account
        route = f"/api/v2/clusters/account/{account}/id/{cluster_id}"
        response = await self._do_request(
            "DELETE",
            self.server + route,
        )
        if response.status >= 400:
            await handle_api_exception(response)
        else:
            # multiple deletes sometimes fail if we don't await response here
            await response.json()
            logger.info(f"Cluster {cluster_id} deleted successfully.")

    @overload
    def delete_cluster(
        self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
    ) -> None:
        ...

    @overload
    def delete_cluster(
        self: Cloud[Async], cluster_id: int, account: Optional[str] = None
    ) -> Awaitable[None]:
        ...

    @delete_docstring  # TODO: this docstring erroneously says "Name of cluster" when it really accepts an ID
    def delete_cluster(
        self, cluster_id: int, account: Optional[str] = None
    ) -> Optional[Awaitable[None]]:
        return self._sync(self._delete_cluster, cluster_id, account)

    async def _get_cluster_details(
        self, cluster_id: int, account: Optional[str] = None
    ):
        account = account or self.default_account
        r = await self._do_request_idempotent(
            "GET", self.server + f"/api/v2/clusters/account/{account}/id/{cluster_id}"
        )
        if r.status >= 400:
            await handle_api_exception(r)
        return await r.json()

    def _get_cluster_details_synced(
        self, cluster_id: int, account: Optional[str] = None
    ):
        return self._sync(
            self._get_cluster_details,
            cluster_id=cluster_id,
            account=account,
        )

    def _cluster_grafana_url(self, cluster_id: int, account: Optional[str] = None):
        """for internal Coiled use"""
        account = account or self.default_account
        details = self._sync(
            self._get_cluster_details,
            cluster_id=cluster_id,
            account=account,
        )
        return get_grafana_url(details, account=account, cluster_id=cluster_id)

    def cluster_details(self, cluster_id: int, account: Optional[str] = None):
        details = self._sync(
            self._get_cluster_details,
            cluster_id=cluster_id,
            account=account,
        )
        state_keys = ["state", "reason", "updated"]

        def get_state(state: dict):
            return {k: v for k, v in state.items() if k in state_keys}

        def get_instance(instance):
            if instance is None:
                return None
            else:
                return {
                    "id": instance["id"],
                    "created": instance["created"],
                    "name": instance["name"],
                    "public_ip_address": instance["public_ip_address"],
                    "private_ip_address": instance["private_ip_address"],
                    "current_state": get_state(instance["current_state"]),
                }

        def get_process(process: dict):
            if process is None:
                return None
            else:
                return {
                    "created": process["created"],
                    "name": process["name"],
                    "current_state": get_state(process["current_state"]),
                    "instance": get_instance(process["instance"]),
                }

        return {
            "id": details["id"],
            "name": details["name"],
            "workers": [get_process(w) for w in details["workers"]],
            "scheduler": get_process(details["scheduler"]),
            "current_state": get_state(details["current_state"]),
            "created": details["created"],
        }

    async def _get_workers_page(
        self, cluster_id: int, page: int, account: Optional[str] = None
    ) -> Tuple[list, bool]:
        page_size = 100
        account = account or self.default_account
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/workers/account/{account}/cluster/{cluster_id}/",
            params={"limit": page_size, "offset": page_size * page},
        )
        if response.status >= 400:
            await handle_api_exception(response)
        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    @track_context
    async def _get_worker_names(
        self,
        account: str,
        cluster_id: int,
        statuses: Optional[List[ProcessStateEnum]] = None,
    ) -> Set[str]:
        worker_infos = await self._depaginate_list(
            self._get_workers_page, cluster_id=cluster_id, account=account
        )
        logger.debug(f"workers: {worker_infos}")
        return {
            w["name"]
            for w in worker_infos
            if statuses is None or w["current_state"]["state"] in statuses
        }

    @track_context
    async def _security(self, cluster_id: int, account: Optional[str] = None):
        cluster = await self._get_cluster_details(
            cluster_id=cluster_id, account=account
        )
        if (
            ProcessStateEnum(cluster["scheduler"]["current_state"]["state"])
            != ProcessStateEnum.started
        ):
            raise RuntimeError(
                f"Cannot get security info for cluster {cluster_id}: scheduler is not ready"
            )

        public_ip = cluster["scheduler"]["instance"]["public_ip_address"]
        private_ip = cluster["scheduler"]["instance"]["private_ip_address"]
        tls_cert = cluster["cluster_options"]["tls_cert"]
        tls_key = cluster["cluster_options"]["tls_key"]
        scheduler_port = cluster["scheduler_port"]
        dashboard_address = cluster["scheduler"]["dashboard_address"]

        # TODO (Declarative): pass extra_conn_args if we care about proxying through Coiled to the scheduler
        security = GatewaySecurity(tls_key, tls_cert)

        return security, {
            "private_address": f"tls://{private_ip}:{scheduler_port}",
            "public_address": f"tls://{public_ip}:{scheduler_port}",
            "dashboard_address": dashboard_address,
        }

    @track_context
    async def _requested_workers(
        self, cluster_id: int, account: Optional[str] = None
    ) -> Set[str]:
        raise NotImplementedError("TODO")

    @overload
    def requested_workers(
        self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
    ) -> Set[str]:
        ...

    @track_context
    async def _get_cluster_by_name(
        self, name: str, account: Optional[str] = None
    ) -> int:
        account, name = self._normalize_name(
            name, context_account=account, allow_uppercase=True
        )
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/clusters/account/{account}/name/{name}",
        )
        if response.status == 404:
            raise DoesNotExist
        elif response.status >= 400:
            await handle_api_exception(response)
        cluster = await response.json()
        return cluster["id"]

    @overload
    def get_cluster_by_name(
        self: Cloud[Sync],
        name: str,
        account: Optional[str] = None,
    ) -> int:
        ...

    @overload
    def get_cluster_by_name(
        self: Cloud[Async],
        name: str,
        account: Optional[str] = None,
    ) -> Awaitable[int]:
        ...

    def get_cluster_by_name(
        self,
        name: str,
        account: Optional[str] = None,
    ) -> Union[int, Awaitable[int]]:
        return self._sync(
            self._get_cluster_by_name,
            name=name,
            account=account,
        )

    @track_context
    async def _cluster_status(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        exclude_stopped: bool = True,
    ) -> dict:
        raise NotImplementedError("TODO?")

    @track_context
    async def _get_cluster_states_declarative(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        start_time: Optional[datetime.datetime] = None,
    ) -> int:
        account = account or self.default_account
        params = (
            {"start_time": start_time.isoformat()} if start_time is not None else {}
        )
        response = await self._do_request_idempotent(
            "GET",
            self.server + f"/api/v2/clusters/account/{account}/id/{cluster_id}/states",
            params=params,
        )
        if response.status >= 400:
            await handle_api_exception(response)
        return await response.json()

    def get_cluster_states(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        start_time: Optional[datetime.datetime] = None,
    ) -> Union[int, Awaitable[int]]:
        return self._sync(
            self._get_cluster_states_declarative,
            cluster_id=cluster_id,
            account=account,
            start_time=start_time,
        )

    def get_clusters_by_name(
        self,
        name: str,
        account: Optional[str] = None,
    ) -> List[dict]:
        """Get all clusters matching name."""
        return self._sync(
            self._get_clusters_by_name,
            name=name,
            account=account,
        )

    @track_context
    async def _get_clusters_by_name(
        self, name: str, account: Optional[str] = None
    ) -> List[dict]:
        account, name = self._normalize_name(
            name, context_account=account, allow_uppercase=True
        )
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/clusters/account/{account}",
            params={"name": name},
        )
        if response.status == 404:
            raise DoesNotExist
        elif response.status >= 400:
            await handle_api_exception(response)
        cluster = await response.json()
        return cluster

    @overload
    def cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Logs:
        ...

    @overload
    def cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Awaitable[Logs]:
        ...

    @track_context
    async def _cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Logs:
        def is_errored(process):
            process_state, instance_state = get_process_instance_state(process)
            return (
                process_state == ProcessStateEnum.error
                or instance_state == InstanceStateEnum.error
            )

        account = account or self.default_account

        # hits endpoint in order to get scheduler and worker instance names
        cluster_info = await self._get_cluster_details(
            cluster_id=cluster_id, account=account
        )
        try:
            scheduler_name = cluster_info["scheduler"]["instance"]["name"]
        except (TypeError, KeyError):
            # no scheduler instance name in cluster info
            logger.warning(
                "No scheduler found when attempting to retrieve cluster logs."
            )
            scheduler_name = None
        worker_names = [
            worker["instance"]["name"]
            for worker in cluster_info["workers"]
            if worker["instance"] and (not errors_only or is_errored(worker))
        ]

        LabeledInstance = namedtuple("LabeledInstance", ("name", "label"))

        instances = []
        if (
            scheduler
            and scheduler_name
            and (not errors_only or is_errored(cluster_info["scheduler"]))
        ):
            instances.append(LabeledInstance(scheduler_name, "Scheduler"))
        if workers and worker_names:
            instances.extend(
                [
                    LabeledInstance(worker_name, worker_name)
                    for worker_name in worker_names
                ]
            )

        async def instance_log_with_semaphor(semaphor, **kwargs):
            async with semaphor:
                return await self._instance_logs(**kwargs)

        # only get 100 logs at a time; the limit here is redundant since the aiohttp
        # session already limits concurrent connections, but let's be safe just in case
        semaphor = asyncio.Semaphore(value=100)
        results = await asyncio.gather(
            *[
                instance_log_with_semaphor(
                    semaphor=semaphor, account=account, instance_name=inst.name
                )
                for inst in instances
            ]
        )

        out = {
            instance_label: instance_log
            for (_, instance_label), instance_log in zip(instances, results)
            if len(instance_log)
        }
        return Logs(out)

    def cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Union[Logs, Awaitable[Logs]]:
        return self._sync(
            self._cluster_logs,
            cluster_id=cluster_id,
            account=account,
            scheduler=scheduler,
            workers=workers,
            errors_only=errors_only,
        )

    async def _instance_logs(self, account: str, instance_name: str, safe=True) -> Log:
        response = await self._do_request(
            "GET",
            self.server
            + "/api/v2/instances/{}/instance/{}/logs".format(account, instance_name),
        )
        if response.status >= 400:
            if safe:
                logger.warning(f"Error retrieving logs for {instance_name}")
                return Log()
            await handle_api_exception(response)

        data = await response.json()
        messages = "\n".join(logline.get("message", "") for logline in data)

        return Log(messages)

    @overload
    def requested_workers(
        self: Cloud[Async], cluster_id: int, account: Optional[str] = None
    ) -> Awaitable[Set[str]]:
        ...

    def requested_workers(
        self, cluster_id: int, account: Optional[str] = None
    ) -> Union[Set[str], Awaitable[Set[str]]]:
        return self._sync(self._requested_workers, cluster_id, account)

    @overload
    def scale_up(
        self: Cloud[Sync], cluster_id: int, n: int, account: Optional[str] = None
    ) -> Optional[Dict]:
        ...

    @overload
    def scale_up(
        self: Cloud[Async], cluster_id: int, n: int, account: Optional[str] = None
    ) -> Awaitable[Optional[Dict]]:
        ...

    def scale_up(
        self, cluster_id: int, n: int, account: Optional[str] = None
    ) -> Union[Optional[Dict], Awaitable[Optional[Dict]]]:
        """Scale cluster to ``n`` workers

        Parameters
        ----------
        cluster_id
            Unique cluster identifier.
        n
            Number of workers to scale cluster size to.
        account
            Name of Coiled account which the cluster belongs to.
            If not provided, will default to ``Cloud.default_account``.
        """
        return self._sync(self._scale_up, cluster_id, n, account)

    @overload
    def scale_down(
        self: Cloud[Sync],
        cluster_id: int,
        workers: Set[str],
        account: Optional[str] = None,
    ) -> None:
        ...

    @overload
    def scale_down(
        self: Cloud[Async],
        cluster_id: int,
        workers: Set[str],
        account: Optional[str] = None,
    ) -> Awaitable[None]:
        ...

    def scale_down(
        self, cluster_id: int, workers: Set[str], account: Optional[str] = None
    ) -> Optional[Awaitable[None]]:
        """Scale cluster down to the given set of workers

        Parameters
        ----------
        cluster_id
            Unique cluster identifier.
        workers
            Set of workers to scale down to.
        account
            Name of Coiled account which the cluster belongs to.
            If not provided, will default to ``Cloud.default_account``.
        """
        return self._sync(self._scale_down, cluster_id, workers, account)

    @track_context
    async def _better_cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        instance_ids: Optional[List[int]] = None,
        dask: bool = False,
        system: bool = False,
        since_ms: Optional[int] = None,
        until_ms: Optional[int] = None,
    ):
        account = account or self.default_account

        url_params = []
        if dask:
            url_params.append("dask=True")
        if system:
            url_params.append("system=True")
        if since_ms:
            url_params.append(f"since_ms={since_ms}")
        if until_ms:
            url_params.append(f"until_ms={until_ms}")
        if instance_ids:
            id_list = ",".join(map(str, instance_ids))
            url_params.append(f"instance_ids={id_list}")

        url_path = f"/api/v2/clusters/account/{account}/id/{cluster_id}/better-logs"
        url_param_string = f"?{'&'.join(url_params)}" if url_params else ""

        response = await self._do_request(
            "GET",
            f"{self.server}{url_path}{url_param_string}",
        )
        if response.status >= 400:
            await handle_api_exception(response)

        data = await response.json()
        return data

    def better_cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        instance_ids: Optional[List[int]] = None,
        dask: bool = False,
        system: bool = False,
        since_ms: Optional[int] = None,
        until_ms: Optional[int] = None,
    ) -> Logs:
        return self._sync(
            self._better_cluster_logs,
            cluster_id=cluster_id,
            account=account,
            instance_ids=instance_ids,
            dask=dask,
            system=system,
            since_ms=since_ms,
            until_ms=until_ms,
        )

    @track_context
    async def _scale_up(
        self, cluster_id: int, n: int, account: Optional[str] = None
    ) -> Dict:
        """
        Increases the number of workers by ``n``.
        """
        account = account or self.default_account
        response = await self._do_request(
            "POST",
            f"{self.server}/api/v2/workers/account/{account}/cluster/{cluster_id}/",
            json={"n_workers": n},
        )
        if response.status >= 400:
            await handle_api_exception(response)

        workers_info = await response.json()
        return {"workers": {w["name"] for w in workers_info}}

    @track_context
    async def _scale_down(
        self, cluster_id: int, workers: Set[str], account: Optional[str] = None
    ) -> None:
        account = account or self.default_account
        response = await self._do_request(
            "DELETE",
            f"{self.server}/api/v2/workers/account/{account}/cluster/{cluster_id}/",
            params={"name": workers},
        )
        if response.status >= 400:
            await handle_api_exception(response)

    @overload
    def security(
        self: Cloud[Sync], cluster_id: int, account: Optional[str] = None
    ) -> Tuple[dask.distributed.Security, dict]:
        ...

    @overload
    def security(
        self: Cloud[Async], cluster_id: int, account: Optional[str] = None
    ) -> Awaitable[Tuple[dask.distributed.Security, dict]]:
        ...

    def security(
        self, cluster_id: int, account: Optional[str] = None
    ) -> Union[
        Tuple[dask.distributed.Security, dict],
        Awaitable[Tuple[dask.distributed.Security, dict]],
    ]:
        return self._sync(self._security, cluster_id, account)

    @track_context
    async def _fetch_package_levels(self) -> List[PackageLevel]:
        response = await self._do_request(
            "GET",
            f"{self.server}/api/v2/packages/",
        )
        if response.status >= 400:
            await handle_api_exception(response)
        return await response.json()

    def get_ssh_key(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        worker: Optional[str] = None,
    ) -> dict:
        return self._sync(
            self._get_ssh_key,
            cluster_id=cluster_id,
            account=account,
            worker=worker,
        )

    @track_context
    async def _get_ssh_key(
        self, cluster_id: int, account: str, worker: Optional[str]
    ) -> dict:
        account = account or self.default_account
        route = f"/api/v2/clusters/account/{account}/id/{cluster_id}/ssh-key"
        url = f"{self.server}{route}"
        response = await self._do_request(
            "GET", url, params={"worker": worker} if worker else None
        )
        if response.status >= 400:
            await handle_api_exception(response)
        return await response.json()

    def get_cluster_log_info(
        self,
        cluster_id: int,
        account: Optional[str] = None,
    ) -> dict:
        return self._sync(
            self._get_cluster_log_info,
            cluster_id=cluster_id,
            account=account,
        )

    @track_context
    async def _get_cluster_log_info(
        self,
        cluster_id: int,
        account: str,
    ) -> dict:
        account = account or self.default_account
        route = f"/api/v2/clusters/account/{account}/id/{cluster_id}/log-info"
        url = f"{self.server}{route}"
        response = await self._do_request("GET", url)
        if response.status >= 400:
            await handle_api_exception(response)
        return await response.json()

    def approximate_packages(self, package: List[ApproximatePackageRequest]):
        return self._sync(self._approximate_packages, package)

    @track_context
    async def _approximate_packages(
        self, packages: List[ApproximatePackageRequest]
    ) -> List[ApproximatePackageResult]:
        response = await self._do_request(
            "POST",
            f"{self.server}/api/v2/software-environment/approximate-packages",
            json=packages,
        )
        if response.status >= 400:
            await handle_api_exception(response)
        return await response.json()


Cloud = CloudBeta
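
# Illustrative use of the class's sync/async duality (the cluster id is a
# placeholder; in async mode, methods return awaitables via ``_sync``):
#
#   with CloudBeta(asynchronous=False) as cloud:
#       details = cloud.cluster_details(cluster_id=123)
#
#   async with CloudBeta(asynchronous=True) as cloud:
#       details = await cloud.cluster_details(cluster_id=123)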


def cluster_logs(
    cluster_id: int,
    account: Optional[str] = None,
    scheduler: bool = True,
    workers: bool = True,
    errors_only: bool = False,
):
    """
    Returns cluster logs as a dictionary, with a key for the scheduler and each worker.

    .. versionchanged:: 0.2.0
       ``cluster_name`` is no longer accepted, use ``cluster_id`` instead.
    """
    with Cloud() as cloud:
        return cloud.cluster_logs(cluster_id, account, scheduler, workers, errors_only)
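
# Illustrative call (the cluster id is a placeholder); ``Logs`` is dict-like:
#
#   logs = cluster_logs(123, errors_only=True)
#   print(logs.get("Scheduler"))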


def better_cluster_logs(
    cluster_id: int,
    account: Optional[str] = None,
    instance_ids: Optional[List[int]] = None,
    dask: bool = False,
    system: bool = False,
    since_ms: Optional[int] = None,
    until_ms: Optional[int] = None,
):
    """
    Pull logs for the cluster using the better-logs endpoint.

    Logs for recent clusters are split between system and container (dask);
    you can get either or both (or none).

    ``since_ms`` and ``until_ms`` are both inclusive (you'll get logs with
    timestamps matching those).
    """
    with Cloud() as cloud:
        return cloud.better_cluster_logs(
            cluster_id,
            account,
            instance_ids=instance_ids,
            dask=dask,
            system=system,
            since_ms=since_ms,
            until_ms=until_ms,
        )


def cluster_details(
    cluster_id: int,
    account: Optional[str] = None,
) -> dict:
    """
    Get details of a cluster as a dictionary.
    """
    with CloudBeta() as cloud:
        return cloud.cluster_details(
            cluster_id=cluster_id,
            account=account,
        )


def log_cluster_debug_info(
    cluster_id: int,
    account: Optional[str] = None,
):
    with CloudBeta() as cloud:
        details = cloud.cluster_details(cluster_id, account)
        logger.debug("Cluster details:")
        logger.debug(json.dumps(details, indent=2))

        states_by_type = cloud.get_cluster_states(cluster_id, account)

        logger.debug("cluster state history:")
        log_states(flatten_log_states(states_by_type), level=logging.DEBUG)

        # log the scheduler logs (if errored), and up to 1 errored worker
        instance_logs = cloud.cluster_logs(cluster_id, account, errors_only=True)

        logger.debug("Finding errored scheduler instance log:")
        try:
            logger.debug(instance_logs.pop("Scheduler"))
        except KeyError:
            logger.debug("Did not find any errored scheduler instance logs.")

        logger.debug("Finding errored worker instance log:")
        try:
            worker_log = next(iter(instance_logs.values()))
            logger.debug(worker_log)
        except StopIteration:
            logger.debug("Did not find any errored worker instance logs.")


def create_cluster(
    name: Optional[str] = None,
    *,
    software: Optional[str] = None,
    worker_options: Optional[dict] = None,
    worker_cpu: Optional[int] = None,
    worker_memory: Optional[int] = None,
    scheduler_options: Optional[dict] = None,
    scheduler_cpu: Optional[int] = None,
    scheduler_memory: Optional[int] = None,
    account: Optional[str] = None,
    workers: int = 0,
    environ: Optional[Dict] = None,
    tags: Optional[Dict] = None,
    dask_config: Optional[Dict] = None,
    private_to_creator: Optional[bool] = None,
    scheduler_vm_types: Optional[list] = None,
    worker_vm_types: Optional[list] = None,
    worker_disk_size: Optional[int] = None,
    backend_options: Optional[dict | BackendOptions] = None,
) -> int:
    """Create a cluster

    Parameters
    ----------
    name
        Name of cluster.
    software
        Identifier of the software environment to use, in the format (<account>/)<name>. If the software
        environment is owned by the same account as that passed into "account", the (<account>/) prefix is
        optional.

        For example, suppose your account is "wondercorp", but your friends at "friendlycorp" have an
        environment named "xgboost" that you want to use; you can specify this with "friendlycorp/xgboost".
        If you simply entered "xgboost", this is shorthand for "wondercorp/xgboost".

        The "name" portion of (<account>/)<name> can only contain ASCII letters, hyphens and underscores.
    worker_cpu
        Number of CPUs allocated for each worker. Defaults to 2.
    worker_memory
        Amount of memory to allocate for each worker. Defaults to 8 GiB.
    worker_options
        Mapping with keyword arguments to pass to ``worker_class``. Defaults to ``{}``.
    scheduler_cpu
        Number of CPUs allocated for the scheduler. Defaults to 1.
    scheduler_memory
        Amount of memory to allocate for the scheduler.
        Defaults to 4 GiB.
    scheduler_options
        Mapping with keyword arguments to pass to ``scheduler_class``. Defaults to ``{}``.
    account
        Name of the Coiled account to create the cluster in.
        If not provided, will default to ``Cloud.default_account``.
    workers
        Number of workers to launch.
    environ
        Dictionary of environment variables.
    tags
        Dictionary of instance tags.
    dask_config
        Dictionary of dask config to put on cluster.

    See Also
    --------
    coiled.Cluster
    """
    with CloudBeta(account=account) as cloud:
        return cloud.create_cluster(
            name=name,
            software=software,
            worker_options=worker_options,
            worker_cpu=worker_cpu,
            worker_memory=worker_memory,
            scheduler_options=scheduler_options,
            scheduler_cpu=scheduler_cpu,
            scheduler_memory=scheduler_memory,
            account=account,
            workers=workers,
            environ=environ,
            tags=tags,
            dask_config=dask_config,
            private_to_creator=private_to_creator,
            backend_options=backend_options,
            worker_vm_types=worker_vm_types,
            worker_disk_size=worker_disk_size,
            scheduler_vm_types=scheduler_vm_types,
        )
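
# Illustrative call (the names, sizes, and region are placeholders):
#
#   cluster_id = create_cluster(
#       name="my-cluster",
#       software="my-account/my-senv",
#       workers=4,
#       worker_cpu=2,
#       worker_memory=8,
#       backend_options={"region_name": "us-east-2", "spot": True},
#   )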


@list_docstring
def list_clusters(account=None, max_pages: Optional[int] = None):
    with CloudBeta() as cloud:
        return cloud.list_clusters(account=account, max_pages=max_pages)


@delete_docstring
def delete_cluster(name: str, account: Optional[str] = None):
    with CloudBeta() as cloud:
        cluster_id = cloud.get_cluster_by_name(name=name, account=account)
        if cluster_id is not None:
            return cloud.delete_cluster(cluster_id=cluster_id, account=account)
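
# Illustrative teardown by name (the cluster and account names are
# placeholders); this looks up the cluster id and then deletes it, mirroring
# the function above:
#
#   delete_cluster("my-cluster", account="my-account")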