# Source code for coiled.v2.core

from __future__ import annotations

import asyncio
import datetime
import json
import logging
import weakref
from collections import namedtuple
from typing import (
    Awaitable,
    Callable,
    Dict,
    Generic,
    List,
    NoReturn,
    Optional,
    Set,
    Tuple,
    Union,
    overload,
)

import dask.config
import dask.distributed
from aiohttp import ContentTypeError
from dask.utils import parse_timedelta
from distributed.utils import Log, Logs
from rich.progress import Progress
from typing_extensions import TypeAlias

from coiled.cli.setup.entry import do_setup_wizard
from coiled.context import track_context
from coiled.core import Async, IsAsynchronous, Sync, delete_docstring, list_docstring
from coiled.core import Cloud as OldCloud
from coiled.errors import ClusterCreationError, DoesNotExist, ServerError
from coiled.types import (
    ArchitectureTypesEnum,
    AWSOptions,
    GCPOptions,
    PackageLevel,
    PackageSchema,
    ResolvedPackageInfo,
    SoftwareEnvironmentAlias,
)
from coiled.utils import (
    COILED_LOGGER_NAME,
    GatewaySecurity,
    get_grafana_url,
    validate_backend_options,
)

from .states import (
    InstanceStateEnum,
    ProcessStateEnum,
    flatten_log_states,
    get_process_instance_state,
    log_states,
)
from .widgets.util import simple_progress

logger = logging.getLogger(COILED_LOGGER_NAME)


def setup_logging(level=logging.INFO):
    """Give the Coiled logger a default stream handler when logging is entirely unconfigured.

    Without any handler, records fall through to ``logging.lastResort``, whose level is WARNING,
    so Coiled's INFO-level messages would never reach the user. This is deliberately conservative:
    if either the root logger or the Coiled logger already has a handler, logging is assumed to be
    configured the way the user wants and nothing is touched.

    Parameters
    ----------
    level:
        Level set on the Coiled logger when it has no explicit level yet.
    """
    coiled_logger = logging.getLogger(COILED_LOGGER_NAME)
    already_configured = bool(coiled_logger.handlers) or bool(logging.getLogger().handlers)
    if already_configured:
        return
    # Conservatively, only touch the Coiled logger if no log level has been set on it yet
    # (level 0 is logging.NOTSET).
    # NOTE(review): the handler is attached only in this case too, so a user who set a level but
    # no handler still falls back to logging.lastResort — confirm that is intended.
    if coiled_logger.level != 0:
        return
    handler = logging.StreamHandler()
    handler.setFormatter(logging.Formatter(fmt="[%(asctime)s][%(levelname)-8s][%(name)s] %(message)s"))
    coiled_logger.setLevel(level)
    coiled_logger.addHandler(handler)


async def handle_api_exception(response, exception_cls=ServerError) -> NoReturn:
    """Raise an exception describing a failed (status >= 400) API response.

    The message is taken from the JSON body's ``"message"`` key, else its ``"detail"`` key, else the
    whole decoded body. When the body is not usable JSON (wrong content type, malformed, empty or
    ``null``), a generic message naming the status code, method and URL is used instead.

    Parameters
    ----------
    response:
        The HTTP response to report on.
    exception_cls:
        Exception type to raise; defaults to ``ServerError``.

    Raises
    ------
    exception_cls
        Always; this coroutine never returns normally.
    """
    try:
        error_body = await response.json()
    except (ContentTypeError, json.JSONDecodeError):
        # Not a JSON body (e.g. an HTML error page from a proxy or load balancer).
        error_body = None

    if error_body is None:
        # Nothing usable in the body, so describe the request itself.
        raise exception_cls(
            f"Unexpected status code ({response.status}) to {response.method}:{response.url}, contact support@coiled.io"
        )
    # Only a dict can carry "message"/"detail". For a JSON string or list, `in` would test
    # substring/element membership and the key lookup would then fail with TypeError.
    if isinstance(error_body, dict):
        for key in ("message", "detail"):
            if key in error_body:
                raise exception_cls(error_body[key])
    raise exception_cls(error_body)


# Either flavour of the v2 client: a CloudV2 parametrized with Async or one parametrized with Sync.
CloudV2SyncAsync: TypeAlias = Union["CloudV2[Async]", "CloudV2[Sync]"]


class CloudV2(OldCloud, Generic[IsAsynchronous]):
    # Class-level registries, shared by every CloudV2: weak references to recently seen clients,
    # one list per flavour. `current()` reads them newest-first (from the end) and pops references
    # whose client was garbage-collected or is not "running". Where entries are appended is not
    # visible in this section (presumably when a client is created/started) — TODO confirm.
    _recent_sync: List[weakref.ReferenceType[CloudV2[Sync]]] = list()
    _recent_async: List[weakref.ReferenceType[CloudV2[Async]]] = list()

    # just overriding to get the right signature (CloudV2, not Cloud) for type checkers
    def __enter__(self: CloudV2[Sync]) -> CloudV2[Sync]:
        """Enter a ``with`` block; returns this synchronous client unchanged."""
        return self

    def __exit__(self: CloudV2[Sync], typ, value, tb) -> None:
        """Leave a ``with`` block by closing the client; exceptions are not suppressed."""
        self.close()

    async def __aenter__(self: CloudV2[Async]) -> CloudV2[Async]:
        """Enter an ``async with`` block by awaiting ``_start()`` and returning its result."""
        return await self._start()

    async def __aexit__(self: CloudV2[Async], typ, value, tb) -> None:
        """Leave an ``async with`` block by awaiting ``_close()``; exceptions are not suppressed."""
        await self._close()

    # These overloads are necessary for the type checker to know that we really have a CloudV2, not a Cloud;
    # without them, CloudV2.current would be typed to return a Cloud.
    #
    # https://www.python.org/dev/peps/pep-0673/ would remove the need for this.
    # That PEP also mentions a workaround with type vars, which doesn't work for us because type vars aren't
    # subscriptable.
    @overload
    @classmethod
    def current(cls, asynchronous: Sync) -> CloudV2[Sync]: ...

    @overload
    @classmethod
    def current(cls, asynchronous: Async) -> CloudV2[Async]: ...

    @overload
    @classmethod
    def current(cls, asynchronous: bool) -> CloudV2: ...

    @classmethod
    def current(cls, asynchronous: bool) -> CloudV2:
        """Return the newest running client of the requested flavour, creating one if there is none.

        Entries at the end of the registry whose client has been garbage-collected, or whose
        ``status`` is not ``"running"``, are discarded as they are encountered.
        """
        registry: List[weakref.ReferenceType[CloudV2]] = cls._recent_async if asynchronous else cls._recent_sync
        while registry:
            candidate = registry[-1]()
            if candidate is not None and candidate.status == "running":
                return candidate
            registry.pop()
        # Nothing usable is left, so build a fresh client of the right flavour.
        return cls(asynchronous=True) if asynchronous else cls(asynchronous=False)

    @track_context
    async def _get_default_instance_types(self, provider: str, gpu: bool = False, arch: str = "x86_64") -> List[str]:
        """Return the default VM instance types for ``provider``, most preferred first.

        Parameters
        ----------
        provider:
            ``"aws"``, ``"gcp"`` or ``"azure"``.
        gpu:
            Whether a GPU instance type is wanted.
        arch:
            ``"x86_64"`` (default) or ``"arm64"``.

        Raises
        ------
        ValueError
            For an unsupported ``arch``, an unknown ``provider``, or an Azure request with arm64 or GPU
            (no default exists for those).
        """
        if arch not in ("arm64", "x86_64"):
            raise ValueError(f"arch '{arch}' is not supported for default instance types")
        is_arm = arch == "arm64"

        if provider == "aws":
            if is_arm:
                # g5g has an NVIDIA T4G
                return ["g5g.xlarge"] if gpu else ["m7g.xlarge", "m6g.xlarge"]
            return ["g4dn.xlarge"] if gpu else ["m6i.xlarge", "m5.xlarge"]

        if provider == "gcp":
            if is_arm:
                # NOTE: the gpu flag is not consulted here; arm64 always gets this type
                return ["t2a-standard-4"]
            # for GPU, n1-standard-8 with 30GB of memory might be best, but that's big for a default
            return ["n1-standard-4"] if gpu else ["e2-standard-4"]

        if provider == "azure":
            if is_arm:
                raise ValueError(f"no default instance type for Azure with {arch} architecture")
            if gpu:
                raise ValueError("no default GPU instance type for Azure")
            return ["Standard_D4_v5"]

        raise ValueError(f"unexpected provider {provider}; cannot determine default instance types")

    async def _list_dask_scheduler_page(
        self,
        page: int,
        workspace: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = None,
    ) -> Tuple[list, bool]:
        """Fetch one page (100 rows) of Dask scheduler analytics.

        ``workspace`` defaults to the client's default workspace. ``since`` is parsed with
        ``parse_timedelta``; it and ``user`` are only sent when truthy. Returns ``(rows, has_more)``,
        where ``has_more`` just means the page was non-empty.
        """
        page_size = 100
        target = workspace or self.default_workspace
        query: dict = {"limit": page_size, "offset": page * page_size}
        if since:
            query["since"] = parse_timedelta(since)
        if user:
            query["user"] = user
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/analytics/{target}/clusters/list",
            params=query,
        )
        if response.status >= 400:
            await handle_api_exception(response)

        rows = await response.json()
        return rows, len(rows) > 0

    @track_context
    async def _list_dask_scheduler(
        self,
        workspace: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = None,
    ):
        """Collect every page of Dask scheduler analytics into a single list."""
        filters = {"workspace": workspace, "since": since, "user": user}
        return await self._depaginate_list(self._list_dask_scheduler_page, **filters)

    @overload
    def list_dask_scheduler(
        self: Cloud[Sync],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = None,
    ) -> list: ...

    @overload
    def list_dask_scheduler(
        self: Cloud[Async],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = "",
    ) -> Awaitable[list]: ...

    def list_dask_scheduler(
        self,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = "7 days",
        user: Optional[str] = "",
    ) -> Union[list, Awaitable[list]]:
        """List Dask schedulers recorded in analytics.

        ``account`` is used as the workspace only when ``workspace`` is not given. Returns a list,
        or an awaitable of one for asynchronous clients (per the overloads).
        """
        target_workspace = workspace or account
        return self._sync(self._list_dask_scheduler, target_workspace, since=since, user=user)

    async def _list_computations(
        self, cluster_id: Optional[int] = None, scheduler_id: Optional[int] = None, workspace: Optional[str] = None
    ):
        """Collect every page of computations for a cluster or a scheduler into one list."""
        page_fetcher = self._list_computations_page
        return await self._depaginate_list(
            page_fetcher,
            workspace=workspace,
            cluster_id=cluster_id,
            scheduler_id=scheduler_id,
        )

    async def _list_computations_page(
        self,
        page: int,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        workspace: Optional[str] = None,
    ) -> Tuple[list, bool]:
        """Fetch one page (100 rows) of computations.

        The scheduler-scoped endpoint is used when ``scheduler_id`` is truthy, otherwise the
        cluster-scoped one. Returns ``(rows, has_more)``, where ``has_more`` just means the page
        was non-empty.

        Raises
        ------
        ValueError
            If neither ``cluster_id`` nor ``scheduler_id`` is truthy.
        """
        page_size = 100
        workspace = workspace or self.default_workspace

        if not (scheduler_id or cluster_id):
            raise ValueError("either cluster_id or scheduler_id must be specified")

        if scheduler_id:
            path = f"/api/v2/analytics/{workspace}/{scheduler_id}/computations/list"
        else:
            path = f"/api/v2/analytics/{workspace}/cluster/{cluster_id}/computations/list"

        response = await self._do_request(
            "GET",
            self.server + path,
            params={"limit": page_size, "offset": page * page_size},
        )
        if response.status >= 400:
            await handle_api_exception(response)

        rows = await response.json()
        return rows, len(rows) > 0

    @overload
    def list_computations(
        self: Cloud[Sync],
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> list: ...

    @overload
    def list_computations(
        self: Cloud[Async],
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Awaitable[list]: ...

    def list_computations(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Union[list, Awaitable[list]]:
        """List computations for a cluster (``cluster_id``) or a Dask scheduler (``scheduler_id``).

        ``account`` is used as the workspace only when ``workspace`` is not given.
        """
        return self._sync(
            self._list_computations,
            cluster_id=cluster_id,
            scheduler_id=scheduler_id,
            workspace=workspace if workspace else account,
        )

    # The overloads annotate `self` (as the sibling list_* methods do) so a type checker can pick the
    # sync vs async return type. Previously both overloads had identical parameters, so the
    # Awaitable overload could never be selected.
    # NOTE(review): `Cloud` is not imported in the visible part of this file (only `OldCloud`);
    # confirm it resolves for type checkers, or switch all these overloads to `CloudV2[...]`.
    @overload
    def list_exceptions(
        self: Cloud[Sync],
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> list: ...

    @overload
    def list_exceptions(
        self: Cloud[Async],
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> Awaitable[list]: ...

    def list_exceptions(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> Union[list, Awaitable[list]]:
        """List exceptions recorded in analytics, optionally filtered.

        Parameters
        ----------
        cluster_id, scheduler_id:
            Restrict to one cluster / one Dask scheduler (each sent only when truthy).
        account:
            Used as the workspace only when ``workspace`` is not given.
        workspace:
            Workspace to query; falls back to the client's default workspace.
        since:
            Look-back window such as ``"7 days"``, parsed with ``dask.utils.parse_timedelta``.
        user:
            Restrict to one user.

        Returns
        -------
        A list of exception records, or an awaitable of one for asynchronous clients.
        """
        return self._sync(
            self._list_exceptions,
            cluster_id=cluster_id,
            scheduler_id=scheduler_id,
            workspace=workspace or account,
            since=since,
            user=user,
        )

    async def _list_exceptions(
        self,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ):
        """Collect every page of (optionally filtered) exception analytics into one list."""
        filters = dict(
            cluster_id=cluster_id,
            scheduler_id=scheduler_id,
            workspace=workspace,
            since=since,
            user=user,
        )
        return await self._depaginate_list(self._list_exceptions_page, **filters)

    async def _list_exceptions_page(
        self,
        page: int,
        cluster_id: Optional[int] = None,
        scheduler_id: Optional[int] = None,
        workspace: Optional[str] = None,
        since: Optional[str] = None,
        user: Optional[str] = None,
    ) -> Tuple[list, bool]:
        """Fetch one page (100 rows) of exception analytics.

        Each filter is sent only when truthy; ``since`` is parsed with ``parse_timedelta``.
        Returns ``(rows, has_more)``, where ``has_more`` just means the page was non-empty.
        """
        page_size = 100
        workspace = workspace or self.default_workspace
        query: dict = {"limit": page_size, "offset": page * page_size}
        if since:
            query["since"] = parse_timedelta(since)
        for param, value in (("user", user), ("cluster", cluster_id), ("scheduler", scheduler_id)):
            if value:
                query[param] = value

        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/analytics/{workspace}/exceptions/list",
            params=query,
        )
        if response.status >= 400:
            await handle_api_exception(response)

        rows = await response.json()
        return rows, len(rows) > 0

    async def _list_events_page(
        self,
        page: int,
        cluster_id: int,
        workspace: Optional[str] = None,
    ) -> Tuple[list, bool]:
        """Fetch one page (100 rows) of analytics events for ``cluster_id``.

        Returns ``(rows, has_more)``, where ``has_more`` just means the page was non-empty.
        """
        page_size = 100
        target = workspace or self.default_workspace
        endpoint = f"/api/v2/analytics/{target}/{cluster_id}/events/list"
        response = await self._do_request(
            "GET",
            self.server + endpoint,
            params={"limit": page_size, "offset": page * page_size},
        )
        if response.status >= 400:
            await handle_api_exception(response)

        rows = await response.json()
        return rows, len(rows) > 0

    async def _list_events(self, cluster_id: int, workspace: Optional[str] = None):
        """Collect every page of analytics events for ``cluster_id`` into one list."""
        fetch_page = self._list_events_page
        return await self._depaginate_list(fetch_page, workspace=workspace, cluster_id=cluster_id)

    def list_events(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Union[list, Awaitable[list]]:
        """List analytics events for ``cluster_id``.

        ``account`` is used as the workspace only when ``workspace`` is not given. Per the return
        annotation, the result is a list or an awaitable of one.
        """
        target_workspace = workspace or account
        return self._sync(self._list_events, cluster_id, target_workspace)

    async def _send_state(self, cluster_id: int, desired_status: str, workspace: Optional[str] = None):
        """POST ``desired_status`` as the desired state of ``cluster_id``.

        Raises via ``handle_api_exception`` on an error response; on success the body is not read.
        """
        target = workspace or self.default_workspace
        endpoint = f"/api/v2/analytics/{target}/{cluster_id}/desired-state"
        payload = {"desired_status": desired_status}
        response = await self._do_request("POST", self.server + endpoint, json=payload)
        if response.status >= 400:
            await handle_api_exception(response)

    def send_state(
        self,
        cluster_id: int,
        desired_status: str,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Union[None, Awaitable[None]]:
        """Set the desired status of ``cluster_id`` via ``_send_state``.

        ``account`` is used as the workspace only when ``workspace`` is not given.
        """
        target_workspace = workspace or account
        return self._sync(self._send_state, cluster_id, desired_status, target_workspace)

    @track_context
    async def _list_clusters(self, workspace: Optional[str] = None, max_pages: Optional[int] = None):
        """Collect pages of clusters for ``workspace`` into one list, fetching at most ``max_pages`` if set."""
        fetch_page = self._list_clusters_page
        return await self._depaginate_list(fetch_page, max_pages=max_pages, workspace=workspace)

    @overload
    def list_clusters(
        self: Cloud[Sync],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        max_pages: Optional[int] = None,
    ) -> list: ...

    @overload
    def list_clusters(
        self: Cloud[Async],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        max_pages: Optional[int] = None,
    ) -> Awaitable[list]: ...

    @list_docstring
    def list_clusters(
        self, account: Optional[str] = None, workspace: Optional[str] = None, max_pages: Optional[int] = None
    ) -> Union[list, Awaitable[list]]:
        # No docstring written here: `@list_docstring` presumably supplies this method's documentation.
        # `account` is used as the workspace only when `workspace` is not given.
        target_workspace = workspace or account
        return self._sync(self._list_clusters, workspace=target_workspace, max_pages=max_pages)

    async def _list_clusters_page(self, page: int, workspace: Optional[str] = None) -> Tuple[list, bool]:
        """Fetch one page (100 rows) of clusters for ``workspace``.

        Returns ``(rows, has_more)``, where ``has_more`` just means the page was non-empty.
        """
        page_size = 100
        target = workspace or self.default_workspace
        endpoint = f"/api/v2/clusters/account/{target}/"
        response = await self._do_request(
            "GET",
            self.server + endpoint,
            params={"limit": page_size, "offset": page * page_size},
        )
        if response.status >= 400:
            await handle_api_exception(response)

        rows = await response.json()
        return rows, len(rows) > 0

    @staticmethod
    async def _depaginate_list(
        func: Callable[..., Awaitable[Tuple[list, bool]]],
        max_pages: Optional[int] = None,
        *args,
        **kwargs,
    ) -> list:
        """Call a page-fetching coroutine repeatedly and concatenate the pages.

        ``func`` is awaited as ``func(*args, **kwargs)`` with ``kwargs["page"]`` set to 0, 1, 2, ...
        and must return ``(results, has_more)``. Fetching stops after the first empty page, when
        ``has_more`` is falsy (``False`` or ``None``), or once ``max_pages`` pages have been fetched.
        A falsy ``max_pages`` means no limit.

        Returns
        -------
        list
            The results of every fetched page, in page order.
        """
        results_all: list = []
        page = 0
        while True:
            kwargs["page"] = page
            results, has_more = await func(*args, **kwargs)
            results_all += results
            page += 1
            # Page fetchers return a bool, so test for falsiness: comparing against None alone
            # would never honour an explicit `False`.
            if not results or not has_more:
                break
            # `page` now equals the number of pages fetched so far.
            if max_pages and page >= max_pages:
                break
        return results_all

    @track_context
    async def _create_package_sync_env(
        self,
        packages: List[ResolvedPackageInfo],
        progress: Optional[Progress] = None,
        workspace: Optional[str] = None,
        gpu_enabled: bool = False,
        architecture: ArchitectureTypesEnum = ArchitectureTypesEnum.X86_64,
        region_name: Optional[str] = None,
    ) -> SoftwareEnvironmentAlias:
        """Upload local package sources where needed, then request a package-sync environment build.

        A package with both an ``sdist`` and an ``md5`` is uploaded first (shown as its own progress
        step) and referenced by the returned file id. Every other package is sent with ``file=None``.

        Returns
        -------
        SoftwareEnvironmentAlias
            Whatever ``_create_software_environment_v2`` returns for the build request.
        """
        workspace = workspace or self.default_workspace

        prepared_packages: List[PackageSchema] = []
        for pkg in packages:
            file_id = None
            if pkg["sdist"] and pkg["md5"]:
                with simple_progress(f"Uploading {pkg['name']}", progress=progress):
                    file_id = await self._create_senv_package(
                        pkg["sdist"],
                        contents_md5=pkg["md5"],
                        workspace=workspace,
                        region_name=region_name,
                    )
            prepared_packages.append({
                "name": pkg["name"],
                "source": pkg["source"],
                "channel": pkg["channel"],
                "conda_name": pkg["conda_name"],
                "specifier": pkg["specifier"],
                "include": pkg["include"],
                "client_version": pkg["client_version"],
                "file": file_id,
            })

        with simple_progress("Requesting package sync build", progress=progress):
            return await self._create_software_environment_v2(
                senv={
                    "packages": prepared_packages,
                    "raw_pip": None,
                    "raw_conda": None,
                },
                workspace=workspace,
                architecture=architecture,
                gpu_enabled=gpu_enabled,
                region_name=region_name,
            )

    @track_context
    async def _create_custom_certificate(self, subdomain: str, workspace: Optional[str] = None):
        """Request an HTTPS certificate for ``subdomain``; ``workspace`` defaults to the client's default.

        Raises via ``handle_api_exception`` on an error response; on success the body is not read.
        """
        target = workspace or self.default_workspace
        endpoint = f"/api/v2/clusters/account/{target}/https-certificate"
        response = await self._do_request("POST", self.server + endpoint, json={"subdomain": subdomain})
        if response.status >= 400:
            await handle_api_exception(response)

    async def _check_custom_certificate(self, subdomain: str, workspace: Optional[str] = None):
        """Return the server-reported status of the HTTPS certificate for ``subdomain``.

        ``workspace`` defaults to the client's default workspace. The result is the ``"status"``
        field of the JSON response, or ``None`` if that field is absent.

        Raises via ``handle_api_exception`` on an error response.
        """
        # Fall back to the default workspace like every other endpoint helper does. Previously a
        # missing workspace was interpolated into the URL as the literal string "None".
        workspace = workspace or self.default_workspace
        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/clusters/account/{workspace}/https-certificate/{subdomain}",
        )
        if response.status >= 400:
            await handle_api_exception(response)
        response_json = await response.json()
        return response_json.get("status")

    @track_context
    async def _create_cluster(
        self,
        # todo: make name optional and pick one for them, like pre-declarative?
        # https://gitlab.com/coiled/cloud/-/issues/4305
        name: str,
        *,
        software_environment: Optional[str] = None,
        senv_v2_id: Optional[int] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        scheduler_options: Optional[dict] = None,
        workspace: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        dask_config: Optional[Dict] = None,
        scheduler_vm_types: Optional[list] = None,
        gcp_worker_gpu_type: Optional[str] = None,
        gcp_worker_gpu_count: Optional[int] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        worker_disk_throughput: Optional[int] = None,
        scheduler_disk_size: Optional[int] = None,
        backend_options: Optional[Union[AWSOptions, GCPOptions, dict]] = None,
        use_scheduler_public_ip: Optional[bool] = None,
        use_dashboard_https: Optional[bool] = None,
        private_to_creator: Optional[bool] = None,
        extra_worker_on_scheduler: Optional[bool] = None,
        n_worker_specs_per_host: Optional[int] = None,
        custom_subdomain: Optional[str] = None,
    ) -> Tuple[int, bool]:
        """POST a cluster-creation request for ``name`` and return the server's answer.

        Backend options are normalized by ``_normalize_cluster_backend_options``; the caller's
        ``backend_options`` dict is never modified. Error responses go through
        ``_handle_create_cluster_error``. That may run the cloud setup wizard and retry once.

        Returns
        -------
        tuple[int, bool]
            The response's ``"id"`` and ``"existing"`` fields.

        Raises
        ------
        ClusterCreationError
            The workspace has no cloud backend set up, and it was not set up through the wizard.
        ServerError
            Any other error response.
        """
        # TODO (Declarative): support these args, or decide not to
        # https://gitlab.com/coiled/cloud/-/issues/4305

        workspace = workspace or self.default_workspace
        account, name = self._normalize_name(name, context_workspace=workspace, allow_uppercase=True)

        await self._verify_workspace(account)

        data = {
            "name": name,
            "workers": workers,
            "worker_instance_types": worker_vm_types,
            "scheduler_instance_types": scheduler_vm_types,
            "worker_options": worker_options,
            "worker_class": worker_class,
            "worker_disk_size": worker_disk_size,
            "worker_disk_throughput": worker_disk_throughput,
            "scheduler_disk_size": scheduler_disk_size,
            "scheduler_options": scheduler_options,
            "environ": environ,
            "tags": tags,
            "dask_config": dask_config,
            "private_to_creator": private_to_creator,
            "env_id": senv_v2_id,
            "env_name": software_environment,
            "extra_worker_on_scheduler": extra_worker_on_scheduler,
            "n_worker_specs_per_host": n_worker_specs_per_host,
            "use_aws_creds_endpoint": dask.config.get("coiled.use_aws_creds_endpoint", False),
            "custom_subdomain": custom_subdomain,
        }

        options = self._normalize_cluster_backend_options(
            backend_options,
            gcp_worker_gpu_type=gcp_worker_gpu_type,
            gcp_worker_gpu_count=gcp_worker_gpu_count,
            use_scheduler_public_ip=use_scheduler_public_ip,
            use_dashboard_https=use_dashboard_https,
        )
        if options:
            data["options"] = options

        url = self.server + f"/api/v2/clusters/account/{account}/"
        response = await self._do_request("POST", url, json=data)

        # Check the status before decoding: an error response may not have a JSON body at all.
        if response.status >= 400:
            response_json = await self._handle_create_cluster_error(response, url=url, data=data)
        else:
            response_json = await response.json()

        return response_json["id"], response_json["existing"]

    @staticmethod
    def _normalize_cluster_backend_options(
        backend_options: Optional[Union[AWSOptions, GCPOptions, dict]],
        *,
        gcp_worker_gpu_type: Optional[str],
        gcp_worker_gpu_count: Optional[int],
        use_scheduler_public_ip: Optional[bool],
        use_dashboard_https: Optional[bool],
    ) -> dict:
        """Return a validated copy of ``backend_options`` shaped for the clusters API.

        It folds in the v1-style GPU arguments and the public-IP/HTTPS flags. It renames the v1 keys
        ``region``/``zone`` and converts ``firewall``/``ingress`` into ``firewall_spec``. The input is
        never modified. An empty result is returned without validation.
        """
        # Shallow copy, so the renames and pops below don't leak back into the caller's dict.
        options: dict = dict(backend_options) if backend_options else {}

        if gcp_worker_gpu_type is not None:
            # for backwards compatibility with v1 options
            options["worker_accelerator_count"] = gcp_worker_gpu_count or 1
            options["worker_accelerator_type"] = gcp_worker_gpu_type
        elif gcp_worker_gpu_count:
            # not ideal but v1 only supported T4 and `worker_gpu=1` would give you one
            options["worker_accelerator_count"] = gcp_worker_gpu_count
            options["worker_accelerator_type"] = "nvidia-tesla-t4"

        if use_scheduler_public_ip is False:
            if "use_dashboard_public_ip" not in options and not use_dashboard_https:
                options["use_dashboard_public_ip"] = False

        if use_dashboard_https is False:
            if "use_dashboard_https" not in options:
                options["use_dashboard_https"] = False

        if not options:
            return options

        # for backwards compatibility with v1 options
        if "region" in options and "region_name" not in options:
            options["region_name"] = options.pop("region")
        if "zone" in options and "zone_name" not in options:
            options["zone_name"] = options.pop("zone")
        # firewall just lets you specify a single CIDR block to open for ingress;
        # we want to support a list of ingress CIDR blocks
        if "firewall" in options:
            options["ingress"] = [options.pop("firewall")]
        # convert the list of ingress rules to the FirewallSpec expected server-side
        if "ingress" in options:
            options["firewall_spec"] = {"ingress": options.pop("ingress")}

        validate_backend_options(options)
        return options

    async def _handle_create_cluster_error(self, response, *, url: str, data: dict) -> dict:
        """Handle an error response to the create-cluster POST: raise, or return a retried response's JSON.

        Only a ``NO_CLOUD_SETUP`` error can lead to a return. Outside a terminal execution context,
        the cloud setup wizard is run. If it succeeds, ``data`` is POSTed to ``url`` again.
        """
        try:
            error_body = await response.json()
        except ContentTypeError:
            # Non-JSON error body (e.g. an HTML proxy page). handle_api_exception turns this into a
            # ServerError naming the status code, method and URL.
            await handle_api_exception(response)

        if isinstance(error_body, dict) and error_body.get("code") == "NO_CLOUD_SETUP":
            from .widgets import EXECUTION_CONTEXT

            server_error_message = error_body.get("message")
            error_message = f"{server_error_message} or by running `coiled setup`"

            if EXECUTION_CONTEXT == "terminal":
                # maybe not interactive so just raise
                raise ClusterCreationError(error_message)
            # interactive session, so let's try running the cloud setup wizard
            if not await do_setup_wizard():
                raise ClusterCreationError(error_message)
            # the user set up their cloud backend, so let's try creating the cluster again
            retry = await self._do_request("POST", url, json=data)
            if retry.status >= 400:
                await handle_api_exception(retry)  # always raises exception, no return
            return await retry.json()

        # Only a dict can carry "message"/"detail"; anything else is reported as-is.
        if isinstance(error_body, dict):
            if "message" in error_body:
                raise ServerError(error_body["message"])
            if "detail" in error_body:
                raise ServerError(error_body["detail"])
        raise ServerError(error_body)

    @overload
    def create_cluster(
        self: Cloud[Sync],
        name: str,
        *,
        software: Optional[str] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        scheduler_options: Optional[dict] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        dask_config: Optional[Dict] = None,
        private_to_creator: Optional[bool] = None,
        scheduler_vm_types: Optional[list] = None,
        worker_gpu_type: Optional[str] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        worker_disk_throughput: Optional[int] = None,
        scheduler_disk_size: Optional[int] = None,
        backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
    ) -> Tuple[int, bool]: ...

    @overload
    def create_cluster(
        self: Cloud[Async],
        name: str,
        *,
        software: Optional[str] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        scheduler_options: Optional[dict] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        dask_config: Optional[Dict] = None,
        private_to_creator: Optional[bool] = None,
        scheduler_vm_types: Optional[list] = None,
        worker_gpu_type: Optional[str] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        worker_disk_throughput: Optional[int] = None,
        scheduler_disk_size: Optional[int] = None,
        backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
    ) -> Awaitable[Tuple[int, bool]]: ...

    def create_cluster(
        self,
        name: str,
        *,
        software: Optional[str] = None,
        worker_class: Optional[str] = None,
        worker_options: Optional[dict] = None,
        scheduler_options: Optional[dict] = None,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        workers: int = 0,
        environ: Optional[Dict] = None,
        tags: Optional[Dict] = None,
        private_to_creator: Optional[bool] = None,
        dask_config: Optional[Dict] = None,
        scheduler_vm_types: Optional[list] = None,
        worker_gpu_type: Optional[str] = None,
        worker_vm_types: Optional[list] = None,
        worker_disk_size: Optional[int] = None,
        worker_disk_throughput: Optional[int] = None,
        scheduler_disk_size: Optional[int] = None,
        backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
    ) -> Union[Tuple[int, bool], Awaitable[Tuple[int, bool]]]:
        """Create a cluster named ``name`` (a thin wrapper around ``_create_cluster``).

        ``software`` is forwarded as ``software_environment`` and ``worker_gpu_type`` as
        ``gcp_worker_gpu_type``. ``account`` is used as the workspace only when ``workspace`` is not
        given. Returns ``(cluster_id, existing)``, or an awaitable of it for asynchronous clients.
        """
        return self._sync(
            self._create_cluster,
            name=name,
            workspace=workspace or account,
            software_environment=software,
            workers=workers,
            worker_class=worker_class,
            worker_options=worker_options,
            worker_vm_types=worker_vm_types,
            worker_disk_size=worker_disk_size,
            worker_disk_throughput=worker_disk_throughput,
            gcp_worker_gpu_type=worker_gpu_type,
            scheduler_options=scheduler_options,
            scheduler_vm_types=scheduler_vm_types,
            scheduler_disk_size=scheduler_disk_size,
            environ=environ,
            tags=tags,
            dask_config=dask_config,
            private_to_creator=private_to_creator,
            backend_options=backend_options,
        )

    @track_context
    async def _delete_cluster(
        self, cluster_id: int, workspace: Optional[str] = None, reason: Optional[str] = None
    ) -> None:
        """Delete cluster ``cluster_id``, passing ``reason`` as a query parameter when given.

        Uses the idempotent request helper. Raises via ``handle_api_exception`` on an error
        response; on success it reads the body and logs an INFO message.
        """
        workspace = workspace or self.default_workspace
        params = {"reason": reason} if reason else None
        response = await self._do_request_idempotent(
            "DELETE",
            self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}",
            params=params,
        )
        if response.status < 400:
            # multiple deletes sometimes fail if we don't await the response here
            await response.json()
            logger.info(f"Cluster {cluster_id} deleted successfully.")
        else:
            await handle_api_exception(response)

    @overload
    def delete_cluster(
        self: Cloud[Sync],
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        reason: Optional[str] = None,
    ) -> None: ...

    @overload
    def delete_cluster(
        self: Cloud[Async],
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        reason: Optional[str] = None,
    ) -> Awaitable[None]: ...

    @delete_docstring  # TODO: this docstring erroneously says "Name of cluster" when it really accepts an ID
    def delete_cluster(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        reason: Optional[str] = None,
    ) -> Optional[Awaitable[None]]:
        # No docstring written here: `@delete_docstring` supplies this method's documentation.
        # `account` is used as the workspace only when `workspace` is not given.
        target_workspace = workspace or account
        return self._sync(self._delete_cluster, cluster_id=cluster_id, workspace=target_workspace, reason=reason)

    async def _get_cluster_state(self, cluster_id: int, workspace: Optional[str] = None) -> dict:
        workspace = workspace or self.default_workspace
        response = await self._do_request(
            "GET", self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}/state"
        )
        if response.status >= 400:
            await handle_api_exception(response)
        return await response.json()

    async def _get_cluster_details(self, cluster_id: int, workspace: Optional[str] = None):
        workspace = workspace or self.default_workspace
        r = await self._do_request_idempotent(
            "GET", self.server + f"/api/v2/clusters/account/{workspace}/id/{cluster_id}"
        )
        if r.status >= 400:
            await handle_api_exception(r)
        return await r.json()

    def _get_cluster_details_synced(self, cluster_id: int, workspace: Optional[str] = None):
        return self._sync(
            self._get_cluster_details,
            cluster_id=cluster_id,
            workspace=workspace,
        )

    def _cluster_grafana_url(self, cluster_id: int, workspace: Optional[str] = None):
        """Return the Grafana URL for a cluster (for internal Coiled use).

        The workspace falls back to ``self.default_workspace``. Cluster details are
        fetched synchronously and handed to ``get_grafana_url`` with that workspace as
        ``account``.
        """
        ws = workspace or self.default_workspace
        cluster_info = self._sync(self._get_cluster_details, cluster_id=cluster_id, workspace=ws)
        return get_grafana_url(cluster_info, account=ws, cluster_id=cluster_id)

    def cluster_details(self, cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None):
        """Summarize a cluster as a plain dict.

        Fetches the full cluster details and keeps only a subset of fields:

        - the cluster's ``id``, ``name``, ``created`` and ``current_state``;
        - a per-process summary for the scheduler and every worker (``created``,
          ``name``, ``current_state`` and the backing ``instance``).

        Every ``current_state`` is reduced to its ``state``, ``reason`` and ``updated``
        entries. A ``None`` process or instance stays ``None``. ``account`` is the
        deprecated spelling of ``workspace``.
        """
        raw = self._sync(
            self._get_cluster_details,
            cluster_id=cluster_id,
            workspace=workspace or account,
        )
        kept_state_fields = ("state", "reason", "updated")
        instance_fields = ("id", "created", "name", "public_ip_address", "private_ip_address")

        def trim_state(state: dict) -> dict:
            return {field: value for field, value in state.items() if field in kept_state_fields}

        def summarize_instance(instance):
            if instance is None:
                return None
            summary = {field: instance[field] for field in instance_fields}
            summary["current_state"] = trim_state(instance["current_state"])
            return summary

        def summarize_process(process):
            if process is None:
                return None
            return {
                "created": process["created"],
                "name": process["name"],
                "current_state": trim_state(process["current_state"]),
                "instance": summarize_instance(process["instance"]),
            }

        return {
            "id": raw["id"],
            "name": raw["name"],
            "workers": list(map(summarize_process, raw["workers"])),
            "scheduler": summarize_process(raw["scheduler"]),
            "current_state": trim_state(raw["current_state"]),
            "created": raw["created"],
        }

    async def _get_workers_page(self, cluster_id: int, page: int, workspace: Optional[str] = None) -> Tuple[list, bool]:
        page_size = 100
        workspace = workspace or self.default_workspace

        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/workers/account/{workspace}/cluster/{cluster_id}/",
            params={"limit": page_size, "offset": page_size * page},
        )
        if response.status >= 400:
            await handle_api_exception(response)

        results = await response.json()
        has_more_pages = len(results) > 0
        return results, has_more_pages

    @track_context
    async def _get_worker_names(
        self,
        workspace: str,
        cluster_id: int,
        statuses: Optional[List[ProcessStateEnum]] = None,
    ) -> Set[str]:
        """Return the names of all workers of a cluster, across every page.

        When ``statuses`` is given, only workers whose ``current_state["state"]`` is in
        ``statuses`` are included.
        """
        infos = await self._depaginate_list(self._get_workers_page, cluster_id=cluster_id, workspace=workspace)
        logger.debug(f"workers: {infos}")
        if statuses is None:
            return {info["name"] for info in infos}
        return {info["name"] for info in infos if info["current_state"]["state"] in statuses}

    @track_context
    async def _security(self, cluster_id: int, workspace: Optional[str] = None):
        """Build TLS security settings and scheduler addresses for a running cluster.

        Returns a ``(GatewaySecurity, dict)`` pair. The dict holds the scheduler's
        ``private_address`` and ``public_address`` (both ``tls://<ip>:<scheduler_port>``)
        and its ``dashboard_address``.

        Raises
        ------
        RuntimeError
            If the scheduler process is not in the ``started`` state.
        """
        cluster = await self._get_cluster_details(cluster_id=cluster_id, workspace=workspace)
        if ProcessStateEnum(cluster["scheduler"]["current_state"]["state"]) != ProcessStateEnum.started:
            # The previous message claimed the scheduler "is ready", the opposite of this condition.
            raise RuntimeError(f"Cannot get security info for cluster {cluster_id}: scheduler is not ready")

        public_ip = cluster["scheduler"]["instance"]["public_ip_address"]
        private_ip = cluster["scheduler"]["instance"]["private_ip_address"]
        tls_cert = cluster["cluster_options"]["tls_cert"]
        tls_key = cluster["cluster_options"]["tls_key"]
        scheduler_port = cluster["scheduler_port"]
        dashboard_address = cluster["scheduler"]["dashboard_address"]

        # TODO (Declarative): pass extra_conn_args if we care about proxying through Coiled to the scheduler
        security = GatewaySecurity(tls_key, tls_cert)

        return security, {
            "private_address": f"tls://{private_ip}:{scheduler_port}",
            "public_address": f"tls://{public_ip}:{scheduler_port}",
            "dashboard_address": dashboard_address,
        }

    @track_context
    async def _requested_workers(self, cluster_id: int, account: Optional[str] = None) -> Set[str]:
        """Placeholder, not implemented: always raises ``NotImplementedError``.

        Presumably intended to return the requested worker names (per the ``Set[str]``
        annotation). Note that ``requested_workers`` passes the workspace positionally
        into ``account``.
        """
        raise NotImplementedError("TODO")

    # Sync variant. Accepts ``workspace`` like the Async overload and the implementation,
    # so type checkers allow ``workspace=`` on a synchronous Cloud too.
    @overload
    def requested_workers(
        self: Cloud[Sync], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Set[str]: ...

    @track_context
    async def _get_cluster_by_name(self, name: str, workspace: Optional[str] = None) -> int:
        """Resolve a cluster name to its numeric id.

        ``name`` may carry a workspace prefix, which ``_normalize_name`` splits off.
        Raises ``DoesNotExist`` on a 404; other status >= 400 responses go to
        ``handle_api_exception``.
        """
        workspace, name = self._normalize_name(
            name, context_workspace=workspace, allow_uppercase=True, property_name="cluster name"
        )

        resp = await self._do_request("GET", f"{self.server}/api/v2/clusters/account/{workspace}/name/{name}")
        if resp.status == 404:
            raise DoesNotExist
        if resp.status >= 400:
            await handle_api_exception(resp)

        found = await resp.json()
        return found["id"]

    @overload
    def get_cluster_by_name(
        self: Cloud[Sync],
        name: str,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> int: ...

    @overload
    def get_cluster_by_name(
        self: Cloud[Async],
        name: str,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Awaitable[int]: ...

    def get_cluster_by_name(
        self,
        name: str,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Union[int, Awaitable[int]]:
        """Look up a cluster's numeric id from its name.

        ``account`` is the deprecated spelling of ``workspace``; ``workspace`` wins when
        both are given. A synchronous Cloud returns the id and an asynchronous one
        returns an awaitable.
        """
        lookup_kwargs = {"name": name, "workspace": workspace or account}
        return self._sync(self._get_cluster_by_name, **lookup_kwargs)

    @track_context
    async def _cluster_status(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        exclude_stopped: bool = True,
    ) -> dict:
        """Placeholder, not implemented for the v2 API: always raises ``NotImplementedError``."""
        raise NotImplementedError("TODO?")

    @track_context
    async def _get_cluster_states_declarative(
        self,
        cluster_id: int,
        workspace: Optional[str] = None,
        start_time: Optional[datetime.datetime] = None,
    ) -> int:
        """Fetch the state history of a cluster from the ``/states`` endpoint.

        When ``start_time`` is given it is sent in ISO-8601 form as the ``start_time``
        query parameter. Returns the decoded JSON body.

        NOTE(review): the ``-> int`` annotation does not match the returned JSON (callers
        pass it to ``flatten_log_states``); confirm the real shape and fix the annotation.
        """
        ws = workspace or self.default_workspace
        query = {} if start_time is None else {"start_time": start_time.isoformat()}

        resp = await self._do_request_idempotent(
            "GET",
            f"{self.server}/api/v2/clusters/account/{ws}/id/{cluster_id}/states",
            params=query,
        )
        if resp.status >= 400:
            await handle_api_exception(resp)

        return await resp.json()

    def get_cluster_states(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        start_time: Optional[datetime.datetime] = None,
    ) -> Union[int, Awaitable[int]]:
        """Return a cluster's state history via ``_get_cluster_states_declarative``.

        ``account`` is the deprecated spelling of ``workspace``. ``start_time`` is
        forwarded unchanged and becomes the API's ``start_time`` filter when set.
        """
        request = dict(
            cluster_id=cluster_id,
            workspace=workspace or account,
            start_time=start_time,
        )
        return self._sync(self._get_cluster_states_declarative, **request)

    def get_clusters_by_name(
        self,
        name: str,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> List[dict]:
        """Get all clusters matching name."""
        # ``account`` is the deprecated spelling of ``workspace``.
        target_workspace = workspace or account
        return self._sync(self._get_clusters_by_name, name=name, workspace=target_workspace)

    @track_context
    async def _get_clusters_by_name(self, name: str, workspace: Optional[str] = None) -> List[dict]:
        """Return every cluster in the workspace whose name matches ``name``.

        ``name`` is normalized exactly as in ``_get_cluster_by_name``, including the same
        ``property_name``, so validation errors read the same for both lookups.
        Raises ``DoesNotExist`` on a 404; other status >= 400 responses go to
        ``handle_api_exception``.
        """
        workspace, name = self._normalize_name(
            name, context_workspace=workspace, allow_uppercase=True, property_name="cluster name"
        )

        response = await self._do_request(
            "GET",
            self.server + f"/api/v2/clusters/account/{workspace}/",
            params={"name": name},
        )
        if response.status == 404:
            raise DoesNotExist
        elif response.status >= 400:
            await handle_api_exception(response)

        clusters = await response.json()
        return clusters

    @overload
    def cluster_logs(
        self: Cloud[Sync],
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Logs:
        # Typing-only signature: a synchronous Cloud returns the logs directly.
        ...

    @overload
    def cluster_logs(
        self: Cloud[Async],
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Awaitable[Logs]:
        # Typing-only signature: an asynchronous Cloud returns an awaitable.
        ...

    @track_context
    async def _cluster_logs(
        self,
        cluster_id: int,
        workspace: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Logs:
        """Collect per-instance logs for a cluster's scheduler and/or workers.

        The scheduler's logs are keyed ``"Scheduler"`` and each worker's logs are keyed
        by its instance name. With ``errors_only``, only processes whose process or
        instance state is ``error`` are included. Instances whose log is empty are
        dropped from the result.
        """

        def in_error(proc) -> bool:
            proc_state, inst_state = get_process_instance_state(proc)
            return proc_state == ProcessStateEnum.error or inst_state == InstanceStateEnum.error

        workspace = workspace or self.default_workspace

        # hits endpoint in order to get scheduler and worker instance names
        info = await self._get_cluster_details(cluster_id=cluster_id, workspace=workspace)

        try:
            scheduler_instance_name = info["scheduler"]["instance"]["name"]
        except (TypeError, KeyError):
            # no scheduler instance name in cluster info
            logger.warning("No scheduler found when attempting to retrieve cluster logs.")
            scheduler_instance_name = None

        worker_instance_names = []
        for worker in info["workers"]:
            if worker["instance"] and (not errors_only or in_error(worker)):
                worker_instance_names.append(worker["instance"]["name"])

        # (instance name, label in the returned Logs) pairs to fetch
        targets: List[Tuple[str, str]] = []
        if scheduler and scheduler_instance_name and (not errors_only or in_error(info["scheduler"])):
            targets.append((scheduler_instance_name, "Scheduler"))
        if workers and worker_instance_names:
            targets.extend((worker_name, worker_name) for worker_name in worker_instance_names)

        # only get 100 logs at a time; the limit here is redundant since aiohttp session already limits concurrent
        # connections but let's be safe just in case
        limiter = asyncio.Semaphore(value=100)

        async def fetch(instance_name: str) -> Log:
            async with limiter:
                return await self._instance_logs(workspace=workspace, instance_name=instance_name)

        fetched = await asyncio.gather(*(fetch(instance_name) for instance_name, _ in targets))

        return Logs({label: log for (_, label), log in zip(targets, fetched) if len(log)})

    def cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        scheduler: bool = True,
        workers: bool = True,
        errors_only: bool = False,
    ) -> Union[Logs, Awaitable[Logs]]:
        """Return cluster logs (scheduler and/or workers) via ``_cluster_logs``.

        ``account`` is the deprecated spelling of ``workspace``. The remaining flags are
        forwarded unchanged.
        """
        options = {
            "cluster_id": cluster_id,
            "workspace": workspace or account,
            "scheduler": scheduler,
            "workers": workers,
            "errors_only": errors_only,
        }
        return self._sync(self._cluster_logs, **options)

    async def _instance_logs(self, workspace: str, instance_name: str, safe=True) -> Log:
        """Fetch the log lines of one instance and join their messages with newlines.

        Entries without a ``message`` contribute an empty string. On a status >= 400,
        ``safe=True`` (the default) logs a warning and returns an empty ``Log``;
        otherwise the response is handed to ``handle_api_exception``.
        """
        logs_url = f"{self.server}/api/v2/instances/{workspace}/instance/{instance_name}/logs"
        resp = await self._do_request("GET", logs_url)
        if resp.status >= 400:
            if safe:
                logger.warning(f"Error retrieving logs for {instance_name}")
                return Log()
            await handle_api_exception(resp)

        entries = await resp.json()
        return Log("\n".join(entry.get("message", "") for entry in entries))

    @overload
    def requested_workers(
        self: Cloud[Async], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Awaitable[Set[str]]: ...

    def requested_workers(
        self, cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Union[
        Set[str],
        Awaitable[Set[str]],
    ]:
        """Return the requested workers of a cluster via ``_requested_workers``."""
        # Passed positionally on purpose: ``_requested_workers`` names this parameter ``account``.
        resolved = workspace or account
        return self._sync(self._requested_workers, cluster_id, resolved)

    @overload
    def scale_up(
        self: Cloud[Sync], cluster_id: int, n: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Optional[Dict]: ...

    @overload
    def scale_up(
        self: Cloud[Async], cluster_id: int, n: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Awaitable[Optional[Dict]]: ...

    def scale_up(
        self, cluster_id: int, n: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Union[Optional[Dict], Awaitable[Optional[Dict]]]:
        """Request ``n`` additional workers for a cluster

        Delegates to ``_scale_up``, which POSTs ``{"n_workers": n}`` to the cluster's
        workers endpoint and documents this as increasing the number of workers *by* ``n``.

        Parameters
        ----------
        cluster_id
            Unique cluster identifier.
        n
            Number of workers to add. NOTE(review): this was previously documented as
            the total size to scale the cluster *to*; ``_scale_up`` documents "increases
            by n". Confirm against the server endpoint.
        account
            **DEPRECATED**. Use ``workspace`` instead.
        workspace
            The Coiled workspace (previously "account") to use. If not specified,
            will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
            or will use your default workspace if those aren't set.

        Returns
        -------
        The dict built by ``_scale_up`` (``{"workers": <set of worker names from the API
        response>}``), or an awaitable of it for an asynchronous Cloud.
        """
        return self._sync(self._scale_up, cluster_id, n, workspace or account)

    @overload
    def scale_down(
        self: Cloud[Sync],
        cluster_id: int,
        workers: Set[str],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> None: ...

    @overload
    def scale_down(
        self: Cloud[Async],
        cluster_id: int,
        workers: Set[str],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Awaitable[None]: ...

    def scale_down(
        self,
        cluster_id: int,
        workers: Set[str],
        account: Optional[str] = None,
        workspace: Optional[str] = None,
    ) -> Optional[Awaitable[None]]:
        """Remove specific workers from a cluster

        Parameters
        ----------
        cluster_id
            Unique cluster identifier.
        workers
            Names of the workers to remove. ``_scale_down`` sends them as the ``name``
            query parameter of a ``DELETE`` request to the cluster's workers endpoint.
        account
            **DEPRECATED**. Use ``workspace`` instead.
        workspace
            The Coiled workspace (previously "account") to use. If not specified,
            will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
            or will use your default workspace if those aren't set.

        """
        return self._sync(self._scale_down, cluster_id, workers, workspace or account)

    @track_context
    async def _better_cluster_logs(
        self,
        cluster_id: int,
        workspace: Optional[str] = None,
        instance_ids: Optional[List[int]] = None,
        dask: bool = False,
        system: bool = False,
        since_ms: Optional[int] = None,
        until_ms: Optional[int] = None,
        filter: Optional[str] = None,
    ):
        """Fetch logs from the cluster's ``better-logs`` endpoint and return the decoded JSON.

        Each option is sent as a query parameter only when it is truthy:

        - ``dask`` and ``system`` are sent as ``"True"``;
        - ``since_ms`` and ``until_ms`` are sent as strings;
        - ``filter`` is sent as ``filter_pattern``;
        - ``instance_ids`` is sent as a comma-separated list.
        """
        ws = workspace or self.default_workspace

        candidates = [
            ("dask", "True" if dask else None),
            ("system", "True" if system else None),
            ("since_ms", f"{since_ms}" if since_ms else None),
            ("until_ms", f"{until_ms}" if until_ms else None),
            ("filter_pattern", f"{filter}" if filter else None),
            ("instance_ids", ",".join(map(str, instance_ids)) if instance_ids else None),
        ]
        query = {key: value for key, value in candidates if value is not None}

        resp = await self._do_request(
            "GET",
            f"{self.server}/api/v2/clusters/account/{ws}/id/{cluster_id}/better-logs",
            params=query,
        )
        if resp.status >= 400:
            await handle_api_exception(resp)

        return await resp.json()

    def better_cluster_logs(
        self,
        cluster_id: int,
        account: Optional[str] = None,
        workspace: Optional[str] = None,
        instance_ids: Optional[List[int]] = None,
        dask: bool = False,
        system: bool = False,
        since_ms: Optional[int] = None,
        until_ms: Optional[int] = None,
        filter: Optional[str] = None,
    ) -> Logs:
        """Pull cluster logs from the ``better-logs`` endpoint via ``_better_cluster_logs``.

        ``account`` is the deprecated spelling of ``workspace``. All other options are
        forwarded unchanged. NOTE(review): the value returned is the endpoint's decoded
        JSON, not a ``Logs`` instance as the annotation suggests.
        """
        forwarded = dict(
            cluster_id=cluster_id,
            workspace=workspace or account,
            instance_ids=instance_ids,
            dask=dask,
            system=system,
            since_ms=since_ms,
            until_ms=until_ms,
            filter=filter,
        )
        return self._sync(self._better_cluster_logs, **forwarded)

    @track_context
    async def _scale_up(
        self, cluster_id: int, n: int, workspace: Optional[str] = None, reason: Optional[str] = None
    ) -> Dict:
        """
        Increases the number of workers by ``n``.

        POSTs ``{"n_workers": n}`` (plus ``reason`` when given) to the cluster's workers
        endpoint and returns ``{"workers": <set of worker names from the response>}``.
        """
        ws = workspace or self.default_workspace
        payload: Dict[str, Union[int, str]] = {"n_workers": n}
        if reason:
            payload["reason"] = reason

        resp = await self._do_request(
            "POST", f"{self.server}/api/v2/workers/account/{ws}/cluster/{cluster_id}/", json=payload
        )
        if resp.status >= 400:
            await handle_api_exception(resp)

        returned_workers = await resp.json()
        return {"workers": {worker["name"] for worker in returned_workers}}

    @track_context
    async def _scale_down(
        self, cluster_id: int, workers: Set[str], workspace: Optional[str] = None, reason: Optional[str] = None
    ) -> None:
        """Remove the named ``workers`` from a cluster.

        Sends a ``DELETE`` to the cluster's workers endpoint with the names as the
        ``name`` query parameter, plus ``reason`` when given.
        """
        ws = workspace or self.default_workspace

        query: dict = {"name": workers}
        if reason:
            query["reason"] = reason

        resp = await self._do_request(
            "DELETE",
            f"{self.server}/api/v2/workers/account/{ws}/cluster/{cluster_id}/",
            params=query,
        )
        if resp.status >= 400:
            await handle_api_exception(resp)

    @overload
    def security(
        self: Cloud[Sync], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Tuple[dask.distributed.Security, dict]: ...

    @overload
    def security(
        self: Cloud[Async], cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Awaitable[Tuple[dask.distributed.Security, dict]]: ...

    def security(
        self, cluster_id: int, account: Optional[str] = None, workspace: Optional[str] = None
    ) -> Union[
        Tuple[dask.distributed.Security, dict],
        Awaitable[Tuple[dask.distributed.Security, dict]],
    ]:
        """Return ``(security, addresses)`` for a cluster via ``_security``.

        ``account`` is the deprecated spelling of ``workspace``.
        """
        resolved = workspace or account
        return self._sync(self._security, cluster_id, resolved)

    @track_context
    async def _fetch_package_levels(self, workspace: Optional[str] = None) -> List[PackageLevel]:
        """Fetch the package-level list for a workspace and return the decoded JSON."""
        query = {"account": workspace or self.default_workspace}
        resp = await self._do_request("GET", f"{self.server}/api/v2/packages/", params=query)
        if resp.status >= 400:
            await handle_api_exception(resp)
        return await resp.json()

    def get_ssh_key(
        self,
        cluster_id: int,
        workspace: Optional[str] = None,
        worker: Optional[str] = None,
    ) -> dict:
        """Fetch a cluster's SSH key via ``_get_ssh_key``.

        The workspace defaults to ``self.default_workspace``. ``worker`` is forwarded
        unchanged.
        """
        resolved_workspace = workspace or self.default_workspace
        return self._sync(self._get_ssh_key, cluster_id=cluster_id, workspace=resolved_workspace, worker=worker)

    @track_context
    async def _get_ssh_key(self, cluster_id: int, workspace: str, worker: Optional[str]) -> dict:
        """GET the cluster's ``ssh-key`` endpoint and return the decoded JSON.

        ``worker`` is sent as a query parameter only when it is truthy.
        """
        ws = workspace or self.default_workspace
        query = {"worker": worker} if worker else None

        resp = await self._do_request(
            "GET", f"{self.server}/api/v2/clusters/account/{ws}/id/{cluster_id}/ssh-key", params=query
        )
        if resp.status >= 400:
            await handle_api_exception(resp)
        return await resp.json()

    def get_cluster_log_info(
        self,
        cluster_id: int,
        workspace: Optional[str] = None,
    ) -> dict:
        """Fetch a cluster's log-info document via ``_get_cluster_log_info``.

        The workspace defaults to ``self.default_workspace``.
        """
        resolved_workspace = workspace or self.default_workspace
        return self._sync(self._get_cluster_log_info, cluster_id=cluster_id, workspace=resolved_workspace)

    @track_context
    async def _get_cluster_log_info(
        self,
        cluster_id: int,
        workspace: str,
    ) -> dict:
        """GET the cluster's ``log-info`` endpoint and return the decoded JSON."""
        ws = workspace or self.default_workspace
        resp = await self._do_request("GET", f"{self.server}/api/v2/clusters/account/{ws}/id/{cluster_id}/log-info")
        if resp.status >= 400:
            await handle_api_exception(resp)
        return await resp.json()

    @track_context
    async def _get_cluster_aggregated_metric(
        self,
        cluster_id: int,
        workspace: Optional[str],
        query: str,
        over_time: str,
        start_ts: Optional[int],
        end_ts: Optional[int],
    ):
        """Query an aggregated metric value for a cluster and return the decoded JSON.

        ``start_ts`` and ``end_ts`` are stringified and sent only when truthy, so a
        value of ``0`` is treated the same as ``None``.
        """
        ws = workspace or self.default_workspace
        params = {"query": query, "over_time": over_time}
        for key, timestamp in (("start_ts", start_ts), ("end_ts", end_ts)):
            if timestamp:
                params[key] = str(timestamp)

        resp = await self._do_request(
            "GET", f"{self.server}/api/v2/metrics/account/{ws}/cluster/{cluster_id}/value", params=params
        )
        if resp.status >= 400:
            await handle_api_exception(resp)
        return await resp.json()

    @track_context
    async def _add_cluster_span(self, cluster_id: int, workspace: Optional[str], span_identifier: str, data: dict):
        """POST ``data`` as the analytics span ``span_identifier`` of a cluster and return the decoded JSON."""
        ws = workspace or self.default_workspace
        span_url = f"{self.server}/api/v2/analytics/{ws}/cluster/{cluster_id}/span/{span_identifier}"

        resp = await self._do_request("POST", span_url, json=data)
        if resp.status >= 400:
            await handle_api_exception(resp)
        return await resp.json()

    def _sync_request(self, route, method: str = "GET", json_result: bool = False, **kwargs):
        url = f"{self.server}{route}"
        response = self._sync(self._do_request, url=url, method=method, **kwargs)
        if response.status >= 400:
            raise ServerError(f"{url} returned {response.status}")

        async def get_result(r):
            return await (r.json() if json_result else r.text())

        return self._sync(
            get_result,
            response,
        )


# Public alias for the v2 client class. The ``Cloud[Sync]`` / ``Cloud[Async]``
# annotations used throughout this module refer to this name.
Cloud = CloudV2


def cluster_logs(
    cluster_id: int,
    account: Optional[str] = None,
    workspace: Optional[str] = None,
    scheduler: bool = True,
    workers: bool = True,
    errors_only: bool = False,
):
    """
    Returns cluster logs as a dictionary, with a key for the scheduler and each worker.

    .. versionchanged:: 0.2.0
       ``cluster_name`` is no longer accepted, use ``cluster_id`` instead.

    Parameters
    ----------
    cluster_id
        Unique cluster identifier.
    account
        **DEPRECATED**. Use ``workspace`` instead.
    workspace
        The Coiled workspace to use.
    scheduler
        Include the scheduler's logs.
    workers
        Include the workers' logs.
    errors_only
        Only include processes whose process or instance state is ``error``.
    """
    with CloudV2() as cloud:
        return cloud.cluster_logs(
            cluster_id=cluster_id,
            workspace=workspace or account,
            scheduler=scheduler,
            workers=workers,
            errors_only=errors_only,
        )
def better_cluster_logs(
    cluster_id: int,
    account: Optional[str] = None,
    workspace: Optional[str] = None,
    instance_ids: Optional[List[int]] = None,
    dask: bool = False,
    system: bool = False,
    since_ms: Optional[int] = None,
    until_ms: Optional[int] = None,
    filter: Optional[str] = None,
):
    """
    Pull logs for the cluster using better endpoint.

    Logs for recent clusters are split between system and container (dask), you can get
    either or both (or none).

    since_ms and until_ms are both inclusive (you'll get logs with timestamp matching those).
    """
    with CloudV2() as cloud:
        return cloud.better_cluster_logs(
            cluster_id=cluster_id,
            workspace=workspace or account,
            instance_ids=instance_ids,
            dask=dask,
            system=system,
            since_ms=since_ms,
            until_ms=until_ms,
            filter=filter,
        )


def cluster_details(
    cluster_id: int,
    account: Optional[str] = None,
    workspace: Optional[str] = None,
) -> dict:
    """
    Get details of a cluster as a dictionary.
    """
    with CloudV2() as cloud:
        return cloud.cluster_details(
            cluster_id=cluster_id,
            workspace=workspace or account,
        )


def log_cluster_debug_info(
    cluster_id: int,
    account: Optional[str] = None,
    workspace: Optional[str] = None,
):
    """Log a cluster's details (as indented JSON) and its state history at DEBUG level."""
    # Pass the workspace by keyword rather than positionally into the deprecated ``account`` slot.
    resolved_workspace = workspace or account
    with CloudV2() as cloud:
        details = cloud.cluster_details(cluster_id, workspace=resolved_workspace)
        logger.debug("Cluster details:")
        logger.debug(json.dumps(details, indent=2))

        states_by_type = cloud.get_cluster_states(cluster_id, workspace=resolved_workspace)
        logger.debug("cluster state history:")
        log_states(flatten_log_states(states_by_type), level=logging.DEBUG)


def create_cluster(
    name: str,
    *,
    software: Optional[str] = None,
    worker_options: Optional[dict] = None,
    scheduler_options: Optional[dict] = None,
    account: Optional[str] = None,
    workspace: Optional[str] = None,
    workers: int = 0,
    environ: Optional[Dict] = None,
    tags: Optional[Dict] = None,
    dask_config: Optional[Dict] = None,
    private_to_creator: Optional[bool] = None,
    scheduler_vm_types: Optional[list] = None,
    worker_vm_types: Optional[list] = None,
    worker_disk_size: Optional[int] = None,
    worker_disk_throughput: Optional[int] = None,
    scheduler_disk_size: Optional[int] = None,
    backend_options: Optional[Union[dict, AWSOptions, GCPOptions]] = None,
) -> int:
    """Create a cluster

    Parameters
    ----------
    name
        Name of cluster.
    software
        Identifier of the software environment to use, in the format (<account>/)<name>. If the software
        environment is owned by the same account as that passed into "account", the (<account>/) prefix
        is optional.

        For example, suppose your account is "wondercorp", but your friends at "friendlycorp" have an
        environment named "xgboost" that you want to use; you can specify this with "friendlycorp/xgboost".
        If you simply entered "xgboost", this is shorthand for "wondercorp/xgboost".

        The "name" portion of (<account>/)<name> can only contain ASCII letters, hyphens and underscores.
    worker_options
        Mapping with keyword arguments to pass to ``worker_class``. Defaults to ``{}``.
    scheduler_options
        Mapping with keyword arguments to pass to the Scheduler ``__init__``. Defaults to ``{}``.
    account
        **DEPRECATED**. Use ``workspace`` instead.
    workspace
        The Coiled workspace (previously "account") to use. If not specified,
        will check the ``coiled.workspace`` or ``coiled.account`` configuration values,
        or will use your default workspace if those aren't set.
    workers
        Number of workers to launch.
    environ
        Dictionary of environment variables.
    tags
        Dictionary of instance tags.
    dask_config
        Dictionary of dask config to put on cluster.
    private_to_creator, scheduler_vm_types, worker_vm_types, worker_disk_size
        Forwarded unchanged to ``CloudV2.create_cluster``.
    worker_disk_throughput, scheduler_disk_size, backend_options
        Forwarded unchanged to ``CloudV2.create_cluster``.

    Returns
    -------
    int
        The cluster id returned by ``CloudV2.create_cluster``.

    See Also
    --------
    coiled.Cluster
    """
    with CloudV2(account=workspace or account) as cloud:
        # ``CloudV2.create_cluster`` returns ``(cluster_id, existing)``; only the id is exposed here.
        cluster_id, _existing = cloud.create_cluster(
            name=name,
            software=software,
            worker_options=worker_options,
            scheduler_options=scheduler_options,
            workspace=workspace or account,
            workers=workers,
            environ=environ,
            tags=tags,
            dask_config=dask_config,
            private_to_creator=private_to_creator,
            backend_options=backend_options,
            worker_vm_types=worker_vm_types,
            worker_disk_size=worker_disk_size,
            worker_disk_throughput=worker_disk_throughput,
            scheduler_disk_size=scheduler_disk_size,
            scheduler_vm_types=scheduler_vm_types,
        )
        return cluster_id
@list_docstring
def list_clusters(account: Optional[str] = None, workspace: Optional[str] = None, max_pages: Optional[int] = None):
    # No docstring here: ``@list_docstring`` is expected to supply it.
    # ``account`` is the deprecated spelling of ``workspace``.
    with CloudV2() as cloud:
        return cloud.list_clusters(workspace=workspace or account, max_pages=max_pages)
@delete_docstring
def delete_cluster(name: str, account: Optional[str] = None, workspace: Optional[str] = None):
    # No docstring here: ``@delete_docstring`` is expected to supply it.
    with CloudV2() as cloud:
        # Resolve the numeric id from the name (a 404 raises ``DoesNotExist``), then delete by id.
        cluster_id = cloud.get_cluster_by_name(name=name, workspace=workspace or account)
        if cluster_id is not None:
            return cloud.delete_cluster(cluster_id=cluster_id, workspace=workspace or account)