diff --git a/docs/docs/concepts/backends.md b/docs/docs/concepts/backends.md index 9ad59ff92..4c5606206 100644 --- a/docs/docs/concepts/backends.md +++ b/docs/docs/concepts/backends.md @@ -1102,8 +1102,11 @@ projects: - apiGroups: [""] resources: ["nodes"] verbs: ["list", "get"] + - apiGroups: [""] + resources: ["persistentvolumeclaims"] + verbs: ["get", "create", "delete"] ``` - + Ensure you've created a ClusterRoleBinding to grant the role to the user or the service account you're using. ??? info "Resources and offers" diff --git a/docs/docs/reference/dstack.yml/volume.md b/docs/docs/reference/dstack.yml/volume.md index af34a166a..2675b684e 100644 --- a/docs/docs/reference/dstack.yml/volume.md +++ b/docs/docs/reference/dstack.yml/volume.md @@ -2,10 +2,61 @@ The `volume` configuration type allows creating, registering, and updating [volumes](../../concepts/volumes.md). -## Root reference +=== "AWS" -#SCHEMA# dstack._internal.core.models.volumes.VolumeConfiguration - overrides: - show_root_heading: false - type: - required: true + #SCHEMA# dstack._internal.core.models.volumes.AWSVolumeConfiguration + overrides: + show_root_heading: false + backend: + required: true + +=== "GCP" + + #SCHEMA# dstack._internal.core.models.volumes.GCPVolumeConfiguration + overrides: + show_root_heading: false + backend: + required: true + +=== "Runpod" + + #SCHEMA# dstack._internal.core.models.volumes.RunpodVolumeConfiguration + overrides: + show_root_heading: false + backend: + required: true + +=== "Kubernetes" + + Kubernetes backend volumes are mapped to [`PersistentVolumeClaim`](https://kubernetes.io/docs/concepts/storage/persistent-volumes/#persistentvolumeclaims) objects. 
+ + To create a new claim, specify `size` and optionally `storage_class_name` and/or `access_modes`: + + ```yaml + type: volume + backend: kubernetes + name: new-volume + size: 100GB + # By default, storage_class_name is not set, and the decision is delegated to + # the DefaultStorageClass admission controller (if it is enabled) + storage_class_name: test-nfs + # access_modes defaults to [ReadWriteOnce]. For multi-attach-capable volumes + # use ReadWriteMany and/or ReadOnlyMany + access_modes: + - ReadWriteMany + ``` + + To reuse an existing claim, specify `claim_name`: + + ```yaml + type: volume + backend: kubernetes + name: existing-volume + claim_name: existing-pvc + ``` + + #SCHEMA# dstack._internal.core.models.volumes.KubernetesVolumeConfiguration + overrides: + show_root_heading: false + backend: + required: true diff --git a/scripts/docs/gen_schema_reference.py b/scripts/docs/gen_schema_reference.py index 01514b34d..f141200cc 100644 --- a/scripts/docs/gen_schema_reference.py +++ b/scripts/docs/gen_schema_reference.py @@ -8,6 +8,7 @@ import re from enum import Enum from fnmatch import fnmatch +from typing import Optional import mkdocs_gen_files import yaml @@ -85,7 +86,7 @@ def get_friendly_type(annotation: Type) -> str: # Handle Literal — list values if get_origin(annotation) is Literal: - values = get_args(annotation) + values = [v.value if isinstance(v, Enum) else v for v in get_args(annotation)] return " | ".join(f'"{v}"' for v in values) # Handle list @@ -207,11 +208,12 @@ def _enrich_type_from_schema(friendly_type: str, prop_schema: Dict[str, Any]) -> def generate_schema_reference( model_path: str, *, - overrides: Dict[str, Dict[str, Any]] = None, + overrides: Optional[dict[str, dict[str, Any]]] = None, prefix: str = "", ) -> str: module, model_name = model_path.rsplit(".", maxsplit=1) cls = getattr(importlib.import_module(module), model_name) + assert issubclass(cls, BaseModel) rows = [] if ( not overrides diff --git 
a/src/dstack/_internal/cli/services/configurators/volume.py b/src/dstack/_internal/cli/services/configurators/volume.py index 624c2080c..2449de0da 100644 --- a/src/dstack/_internal/cli/services/configurators/volume.py +++ b/src/dstack/_internal/cli/services/configurators/volume.py @@ -14,8 +14,9 @@ from dstack._internal.core.errors import ResourceNotExistsError from dstack._internal.core.models.configurations import ApplyConfigurationType from dstack._internal.core.models.volumes import ( + AnyVolumeConfiguration, Volume, - VolumeConfiguration, + VolumeConfigurationWithRegion, VolumePlan, VolumeSpec, VolumeStatus, @@ -24,12 +25,12 @@ from dstack.api._public import Client -class VolumeConfigurator(BaseApplyConfigurator[VolumeConfiguration]): +class VolumeConfigurator(BaseApplyConfigurator[AnyVolumeConfiguration]): TYPE = ApplyConfigurationType.VOLUME def apply_configuration( self, - conf: VolumeConfiguration, + conf: AnyVolumeConfiguration, configuration_path: str, command_args: argparse.Namespace, configurator_args: argparse.Namespace, @@ -129,7 +130,7 @@ def apply_configuration( def delete_configuration( self, - conf: VolumeConfiguration, + conf: AnyVolumeConfiguration, configuration_path: str, command_args: argparse.Namespace, ): @@ -165,7 +166,7 @@ def register_args(cls, parser: argparse.ArgumentParser): help="The volume name", ) - def apply_args(self, conf: VolumeConfiguration, args: argparse.Namespace): + def apply_args(self, conf: AnyVolumeConfiguration, args: argparse.Namespace): if args.name: conf.name = args.name @@ -206,12 +207,13 @@ def th(s: str) -> str: size = "-" if plan.spec.configuration.size is not None: size = str(plan.spec.configuration.size) - if plan.spec.configuration.volume_id is not None: + if plan.spec.configuration.is_external: volume_type = "external" configuration_table.add_row(th("Volume type"), volume_type) configuration_table.add_row(th("Backend"), plan.spec.configuration.backend.value) - configuration_table.add_row(th("Region"), 
plan.spec.configuration.region) + if isinstance(plan.spec.configuration, VolumeConfigurationWithRegion): + configuration_table.add_row(th("Region"), plan.spec.configuration.region) configuration_table.add_row(th("Size"), size) console.print(configuration_table) diff --git a/src/dstack/_internal/cli/utils/run.py b/src/dstack/_internal/cli/utils/run.py index 0095feae2..351bea9c0 100644 --- a/src/dstack/_internal/cli/utils/run.py +++ b/src/dstack/_internal/cli/utils/run.py @@ -186,9 +186,12 @@ def th(s: str) -> str: instance = offer.instance.name if offer.total_blocks > 1: instance += f" ({offer.blocks}/{offer.total_blocks})" + offer_backend = offer.backend.replace("remote", "ssh") + if offer.region: + offer_backend = f"{offer_backend} ({offer.region})" offers.add_row( f"{i}", - offer.backend.replace("remote", "ssh") + " (" + offer.region + ")", + offer_backend, r.pretty_format(include_spot=True), instance, f"${offer.price:.4f}".rstrip("0").rstrip("."), @@ -394,6 +397,8 @@ def _format_backend(backend_type: BackendType, region: str) -> str: backend_str = backend_type.value if backend_type == BackendType.REMOTE: backend_str = "ssh" + if not region: + return backend_str return f"{backend_str} ({region})" diff --git a/src/dstack/_internal/cli/utils/volume.py b/src/dstack/_internal/cli/utils/volume.py index a3f652100..d9cec1e57 100644 --- a/src/dstack/_internal/cli/utils/volume.py +++ b/src/dstack/_internal/cli/utils/volume.py @@ -29,16 +29,17 @@ def get_volumes_table( table.add_column("ERROR") for volume in volumes: - backend = f"{volume.configuration.backend.value} ({volume.configuration.region})" - region = volume.configuration.region + backend = volume.get_backend().value + region = volume.get_region() if verbose: - backend = volume.configuration.backend.value - if ( - verbose - and volume.provisioning_data is not None - and volume.provisioning_data.availability_zone is not None - ): - region += f" ({volume.provisioning_data.availability_zone})" + # In verbose mode, 
BACKEND displays `backend` only, and REGION displays nothing or + # `region` or `region (az)` + if availability_zone := volume.get_availability_zone(): + region = f"{region} ({availability_zone})" + elif region: + # In non-verbose mode, BACKEND displays `backend` or `backend (region)`, and REGION + # is hidden + backend = f"{backend} ({region})" attached = "-" if volume.attachments is not None: attached = ", ".join( diff --git a/src/dstack/_internal/core/backends/aws/compute.py b/src/dstack/_internal/core/backends/aws/compute.py index cd472ae11..91899995a 100644 --- a/src/dstack/_internal/core/backends/aws/compute.py +++ b/src/dstack/_internal/core/backends/aws/compute.py @@ -72,6 +72,7 @@ from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.core.models.volumes import ( + AWSVolumeConfiguration, Volume, VolumeAttachmentData, VolumeProvisioningData, @@ -688,6 +689,7 @@ def terminate_gateway( logger.debug("Deleted ALB resources for gateway %s", configuration.instance_name) def register_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, AWSVolumeConfiguration) ec2_client = self.session.client("ec2", region_name=volume.configuration.region) logger.debug("Requesting EBS volume %s", volume.configuration.volume_id) @@ -715,6 +717,7 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData: ) def create_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, AWSVolumeConfiguration) ec2_client = self.session.client("ec2", region_name=volume.configuration.region) volume_name = generate_unique_volume_name(volume) @@ -773,6 +776,7 @@ def create_volume(self, volume: Volume) -> VolumeProvisioningData: ) def delete_volume(self, volume: Volume): + assert isinstance(volume.configuration, AWSVolumeConfiguration) ec2_client = self.session.client("ec2", 
region_name=volume.configuration.region) logger.debug("Deleting EBS volume %s", volume.configuration.name) @@ -788,6 +792,7 @@ def delete_volume(self, volume: Volume): def attach_volume( self, volume: Volume, provisioning_data: JobProvisioningData ) -> VolumeAttachmentData: + assert isinstance(volume.configuration, AWSVolumeConfiguration) ec2_client = self.session.client("ec2", region_name=volume.configuration.region) instance_id = provisioning_data.instance_id @@ -826,6 +831,7 @@ def attach_volume( def detach_volume( self, volume: Volume, provisioning_data: JobProvisioningData, force: bool = False ): + assert isinstance(volume.configuration, AWSVolumeConfiguration) ec2_client = self.session.client("ec2", region_name=volume.configuration.region) instance_id = provisioning_data.instance_id @@ -848,6 +854,7 @@ def detach_volume( logger.debug("Detached EBS volume %s from instance %s", volume.volume_id, instance_id) def is_volume_detached(self, volume: Volume, provisioning_data: JobProvisioningData) -> bool: + assert isinstance(volume.configuration, AWSVolumeConfiguration) ec2_client = self.session.client("ec2", region_name=volume.configuration.region) instance_id = provisioning_data.instance_id diff --git a/src/dstack/_internal/core/backends/gcp/compute.py b/src/dstack/_internal/core/backends/gcp/compute.py index a7f7d9673..7a0101ed0 100644 --- a/src/dstack/_internal/core/backends/gcp/compute.py +++ b/src/dstack/_internal/core/backends/gcp/compute.py @@ -71,6 +71,7 @@ from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import JobProvisioningData, Requirements from dstack._internal.core.models.volumes import ( + GCPVolumeConfiguration, Volume, VolumeAttachmentData, VolumeProvisioningData, @@ -645,6 +646,7 @@ def terminate_gateway( ) def register_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, GCPVolumeConfiguration) logger.debug("Requesting persistent disk %s", 
volume.configuration.volume_id) zones = gcp_resources.get_availability_zones( regions_client=self.regions_client, @@ -676,6 +678,7 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData: raise ComputeError(f"Persistent disk {volume.configuration.volume_id} not found") def create_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, GCPVolumeConfiguration) zones = gcp_resources.get_availability_zones( regions_client=self.regions_client, project_id=self.config.project_id, diff --git a/src/dstack/_internal/core/backends/kubernetes/compute.py b/src/dstack/_internal/core/backends/kubernetes/compute.py index b8a48429c..4ea833bd0 100644 --- a/src/dstack/_internal/core/backends/kubernetes/compute.py +++ b/src/dstack/_internal/core/backends/kubernetes/compute.py @@ -3,6 +3,7 @@ import subprocess import tempfile import time +from decimal import Decimal from enum import Enum from typing import List, Optional @@ -17,11 +18,14 @@ ComputeWithInstanceVolumesSupport, ComputeWithMultinodeSupport, ComputeWithPrivilegedSupport, + ComputeWithVolumeSupport, generate_unique_gateway_instance_name, generate_unique_instance_name_for_job, generate_unique_name, + generate_unique_volume_name, get_docker_commands, get_dstack_gateway_commands, + merge_tags, ) from dstack._internal.core.backends.kubernetes.models import ( KubernetesConfig, @@ -32,13 +36,15 @@ AMD_GPU_NAME_TO_DEVICE_IDS, AMD_GPU_NODE_TAINT, AMD_GPU_RESOURCE, - DUMMY_REGION, NVIDIA_GPU_NAME_TO_GPU_INFO, NVIDIA_GPU_NODE_TAINT, NVIDIA_GPU_PRODUCT_LABEL, NVIDIA_GPU_RESOURCE, + OBJECT_NAME_MAX_LENGTH, PodPhase, TaintEffect, + filter_invalid_labels, + format_dstack_label_key, format_memory, get_amd_gpu_from_node_labels, get_gpu_request_from_gpu_spec, @@ -49,6 +55,7 @@ get_nvidia_gpu_from_node_labels, is_hard_taint, is_taint_tolerated, + parse_quantity, ) from dstack._internal.core.backends.kubernetes.utils import ( call_api_method, @@ -56,6 +63,7 @@ ) from 
dstack._internal.core.consts import DSTACK_RUNNER_SSH_PORT from dstack._internal.core.errors import ComputeError, ProvisioningError +from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.gateways import ( GatewayComputeConfiguration, @@ -70,7 +78,13 @@ from dstack._internal.core.models.resources import CPUSpec, GPUSpec from dstack._internal.core.models.routers import AnyGatewayRouterConfig from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import InstanceMountPoint, Volume +from dstack._internal.core.models.volumes import ( + InstanceMountPoint, + KubernetesVolumeConfiguration, + Volume, + VolumeMountPoint, + VolumeProvisioningData, +) from dstack._internal.utils.common import get_or_error from dstack._internal.utils.logging import get_logger @@ -100,6 +114,7 @@ class KubernetesCompute( ComputeWithFilteredOffersCached, ComputeWithPrivilegedSupport, ComputeWithInstanceVolumesSupport, + ComputeWithVolumeSupport, ComputeWithGatewaySupport, ComputeWithMultinodeSupport, Compute, @@ -204,27 +219,60 @@ def run_job( ) ) + volume_name_path_map: dict[str, str] = {} mount_points = job.job_spec.volumes if mount_points is None: # Legacy JobSpec without volumes mount_points = run.run_spec.configuration.volumes for mount_point in mount_points: - assert isinstance(mount_point, InstanceMountPoint) - # "Must be a DNS_LABEL and unique within the pod" - volume_name = generate_unique_name(prefix="host-path", max_length=253) + if isinstance(mount_point, VolumeMountPoint): + if isinstance(mount_point.name, str): + volume_names = [mount_point.name] + else: + volume_names = mount_point.name + for volume_name in volume_names: + volume_name_path_map[volume_name] = mount_point.path + elif isinstance(mount_point, InstanceMountPoint): + # "Must be a DNS_LABEL and unique within the pod" + volume_name = 
generate_unique_name( + prefix="host-path", max_length=OBJECT_NAME_MAX_LENGTH + ) + volumes_.append( + client.V1Volume( + name=volume_name, + host_path=client.V1HostPathVolumeSource( + path=mount_point.instance_path, + type="DirectoryOrCreate", + ), + ), + ) + volume_mounts.append( + client.V1VolumeMount( + name=volume_name, + mount_path=mount_point.path, + ) + ) + else: + assert False, f"unexpected mount point: {mount_point!r}" + for volume in volumes: + pvc_name = volume.volume_id + assert pvc_name is not None, f"missing PVC name: {volume!r}" + mount_path = volume_name_path_map.get(volume.name) + assert mount_path is not None, f"missing mount path: {volume!r}" + volume_name = generate_unique_name(prefix="pvc", max_length=OBJECT_NAME_MAX_LENGTH) volumes_.append( client.V1Volume( name=volume_name, - host_path=client.V1HostPathVolumeSource( - path=mount_point.instance_path, - type="DirectoryOrCreate", + persistent_volume_claim=client.V1PersistentVolumeClaimVolumeSource( + claim_name=pvc_name, + read_only=False, ), ), ) volume_mounts.append( client.V1VolumeMount( name=volume_name, - mount_path=mount_point.path, + mount_path=mount_path, ) ) @@ -481,9 +529,8 @@ def create_gateway( namespace=self.config.namespace, service_name=_get_pod_service_name(instance_name), ) - region = DUMMY_REGION if address is None: - self.terminate_instance(instance_name, region=region) + self.terminate_instance(instance_name, region="") raise ComputeError( "Failed to get gateway hostname. " "Ensure the Kubernetes cluster supports Load Balancer services." 
@@ -491,7 +538,7 @@ def create_gateway( return GatewayProvisioningData( instance_id=instance_name, ip_address=address, - region=region, + region="", ) def terminate_gateway( @@ -506,6 +553,102 @@ def terminate_gateway( backend_data=backend_data, ) + def register_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, KubernetesVolumeConfiguration) + pvc_name = volume.configuration.claim_name + assert pvc_name is not None + + pvc = call_api_method( + self.api.read_namespaced_persistent_volume_claim, + expected=404, + namespace=self.config.namespace, + name=pvc_name, + ) + if pvc is None: + raise ComputeError(f"PersistentVolumeClaim {pvc_name} not found") + + capacity_bytes: Optional[Decimal] = None + if pvc.status is not None: + actual_capacity_qty = (pvc.status.capacity or {}).get("storage") + if actual_capacity_qty is not None: + capacity_bytes = parse_quantity(actual_capacity_qty) + if capacity_bytes is None and pvc.spec is not None and pvc.spec.resources is not None: + requested_capacity_qty = (pvc.spec.resources.requests or {}).get("storage") + if requested_capacity_qty is not None: + capacity_bytes = parse_quantity(requested_capacity_qty) + if capacity_bytes is None: + raise ComputeError(f"Failed to detect PersistentVolumeClaim {pvc_name} capacity") + + return VolumeProvisioningData( + backend=BackendType.KUBERNETES, + volume_id=pvc_name, + size_gb=int(capacity_bytes // 2**30), + attachable=False, + detachable=False, + ) + + def create_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, KubernetesVolumeConfiguration) + assert volume.configuration.size is not None + + labels = { + format_dstack_label_key("owner"): "dstack", + format_dstack_label_key("project"): volume.project_name, + format_dstack_label_key("name"): volume.name, + format_dstack_label_key("user"): volume.user, + } + labels = merge_tags( + base_tags=labels, + resource_tags=volume.configuration.tags, + ) + 
labels = filter_invalid_labels(labels) + + pvc_name = generate_unique_volume_name(volume, max_length=OBJECT_NAME_MAX_LENGTH) + pvc = client.V1PersistentVolumeClaim( + metadata=client.V1ObjectMeta( + name=pvc_name, + labels=labels, + ), + spec=client.V1PersistentVolumeClaimSpec( + access_modes=volume.configuration.access_modes, + storage_class_name=volume.configuration.storage_class_name, + resources=client.V1VolumeResourceRequirements( + requests={ + "storage": format_memory(volume.configuration.size), + }, + ), + ), + ) + self.api.create_namespaced_persistent_volume_claim( + namespace=self.config.namespace, + body=pvc, + ) + logger.debug("Created PVC %s for volume %s", pvc_name, volume.name) + + return VolumeProvisioningData( + backend=BackendType.KUBERNETES, + volume_id=pvc_name, + size_gb=volume.configuration.size_gb, + attachable=False, + detachable=False, + ) + + def delete_volume(self, volume: Volume): + assert isinstance(volume.configuration, KubernetesVolumeConfiguration) + pvc_name = volume.volume_id + assert pvc_name is not None + + pvc = call_api_method( + self.api.delete_namespaced_persistent_volume_claim, + expected=404, + namespace=self.config.namespace, + name=pvc_name, + ) + if pvc is None: + logger.debug("PVC %s for volume %s not found", pvc_name, volume.name) + else: + logger.debug("Deleted PVC %s for volume %s", pvc_name, volume.name) + def _get_pod_spec_parameters_for_gpu( api: client.CoreV1Api, gpu_spec: GPUSpec diff --git a/src/dstack/_internal/core/backends/kubernetes/resources.py b/src/dstack/_internal/core/backends/kubernetes/resources.py index d5cb1739f..ef443583a 100644 --- a/src/dstack/_internal/core/backends/kubernetes/resources.py +++ b/src/dstack/_internal/core/backends/kubernetes/resources.py @@ -1,4 +1,5 @@ import dataclasses +import re from collections.abc import Mapping from decimal import Decimal from enum import Enum @@ -30,7 +31,18 @@ logger = get_logger(__name__) -DUMMY_REGION = "-" +# 
https://kubernetes.io/docs/concepts/overview/working-with-objects/names/#names +OBJECT_NAME_MAX_LENGTH = 253 + +# https://kubernetes.io/docs/concepts/overview/working-with-objects/labels/#syntax-and-character-set +LABEL_KEY_PREFIX_MAX_LENGTH = 253 +LABEL_KEY_PREFIX_REGEX = re.compile( + r"^[a-z0-9](?:[a-z0-9-]*[a-z0-9])?(?:\.[a-z0-9]([a-z0-9-]*[a-z0-9])?)*$" +) +LABEL_KEY_NAME_MAX_LENGTH = 63 +LABEL_KEY_NAME_REGEX = re.compile(r"^[A-Za-z0-9](?:[A-Za-z0-9_.-]*[A-Za-z0-9])?$") +LABEL_VALUE_MAX_LENGTH = 63 +LABEL_VALUE_REGEX = re.compile(r"^(?:[A-Za-z0-9](?:[A-Za-z0-9_.-]*[A-Za-z0-9])?)?$") NVIDIA_GPU_RESOURCE = "nvidia.com/gpu" NVIDIA_GPU_NODE_TAINT = NVIDIA_GPU_RESOURCE @@ -120,6 +132,53 @@ def __sub__(self, other: Self) -> Self: return type(self)(**dct) +def filter_invalid_labels(labels: dict[str, str]) -> dict[str, str]: + filtered_labels: dict[str, str] = {} + for k, v in labels.items(): + try: + validate_label_key(k) + validate_label_value(v) + except ValueError as e: + logger.warning("Skipping invalid label %s=%s: %s", k, v, e) + continue + filtered_labels[k] = v + return filtered_labels + + +def validate_label_key(key: str) -> None: + parts = key.split("/") + if len(parts) > 2: + raise ValueError("Too many segments") + name: str + if len(parts) == 2: + prefix, name = parts + if not prefix: + raise ValueError("Empty prefix") + if len(prefix) > LABEL_KEY_PREFIX_MAX_LENGTH: + raise ValueError("Prefix too long") + if LABEL_KEY_PREFIX_REGEX.fullmatch(prefix) is None: + raise ValueError("Invalid prefix") + else: + name = parts[0] + if not name: + raise ValueError("Empty name") + if len(name) > LABEL_KEY_NAME_MAX_LENGTH: + raise ValueError("Name too long") + if LABEL_KEY_NAME_REGEX.fullmatch(name) is None: + raise ValueError("Invalid name") + + +def validate_label_value(value: str) -> None: + if len(value) > LABEL_VALUE_MAX_LENGTH: + raise ValueError("Value too long") + if LABEL_VALUE_REGEX.fullmatch(value) is None: + raise ValueError("Invalid value") + + +def 
format_dstack_label_key(name: str) -> str: + return f"k8s.dstack.ai/{name}" + + parse_quantity = cast( Callable[[Union[str, int, float, Decimal]], Decimal], _kubernetes_utils.parse_quantity ) @@ -306,7 +365,7 @@ def _get_instance_offer_from_node( ), ), price=0, - region=DUMMY_REGION, + region="", availability=InstanceAvailability.AVAILABLE, instance_runtime=InstanceRuntime.RUNNER, ) diff --git a/src/dstack/_internal/core/backends/runpod/compute.py b/src/dstack/_internal/core/backends/runpod/compute.py index 85349fea4..52bc1da9e 100644 --- a/src/dstack/_internal/core/backends/runpod/compute.py +++ b/src/dstack/_internal/core/backends/runpod/compute.py @@ -39,7 +39,11 @@ from dstack._internal.core.models.placement import PlacementGroup from dstack._internal.core.models.resources import Memory, Range from dstack._internal.core.models.runs import Job, JobProvisioningData, Requirements, Run -from dstack._internal.core.models.volumes import Volume, VolumeProvisioningData +from dstack._internal.core.models.volumes import ( + RunpodVolumeConfiguration, + Volume, + VolumeProvisioningData, +) from dstack._internal.utils.common import get_current_datetime, get_or_error from dstack._internal.utils.logging import get_logger @@ -389,6 +393,7 @@ def update_provisioning_data( provisioning_data.ssh_port = port["publicPort"] def register_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, RunpodVolumeConfiguration) volume_data = self.api_client.get_network_volume( volume_id=get_or_error(volume.configuration.volume_id) ) @@ -405,6 +410,7 @@ def register_volume(self, volume: Volume) -> VolumeProvisioningData: ) def create_volume(self, volume: Volume) -> VolumeProvisioningData: + assert isinstance(volume.configuration, RunpodVolumeConfiguration) volume_name = generate_unique_volume_name(volume, max_length=MAX_RESOURCE_NAME_LEN) size_gb = volume.configuration.size_gb # Runpod regions must be uppercase. 
diff --git a/src/dstack/_internal/core/compatibility/volumes.py b/src/dstack/_internal/core/compatibility/volumes.py index 191005e13..c66819724 100644 --- a/src/dstack/_internal/core/compatibility/volumes.py +++ b/src/dstack/_internal/core/compatibility/volumes.py @@ -1,5 +1,5 @@ from dstack._internal.core.models.common import IncludeExcludeDictType -from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeSpec +from dstack._internal.core.models.volumes import AnyVolumeConfiguration, VolumeSpec def get_volume_spec_excludes(volume_spec: VolumeSpec) -> IncludeExcludeDictType: @@ -13,7 +13,7 @@ def get_volume_spec_excludes(volume_spec: VolumeSpec) -> IncludeExcludeDictType: return spec_excludes -def get_create_volume_excludes(configuration: VolumeConfiguration) -> IncludeExcludeDictType: +def get_create_volume_excludes(configuration: AnyVolumeConfiguration) -> IncludeExcludeDictType: """ Returns an exclude mapping to exclude certain fields from the create volume request. 
Use this method to exclude new fields when they are not set to keep @@ -25,7 +25,7 @@ def get_create_volume_excludes(configuration: VolumeConfiguration) -> IncludeExc def _get_volume_configuration_excludes( - configuration: VolumeConfiguration, + configuration: AnyVolumeConfiguration, ) -> IncludeExcludeDictType: configuration_excludes: IncludeExcludeDictType = {} diff --git a/src/dstack/_internal/core/models/configurations.py b/src/dstack/_internal/core/models/configurations.py index 3d2d30683..07acebe61 100644 --- a/src/dstack/_internal/core/models/configurations.py +++ b/src/dstack/_internal/core/models/configurations.py @@ -31,7 +31,14 @@ from dstack._internal.core.models.routers import AnyServiceRouterConfig from dstack._internal.core.models.services import AnyModel, OpenAIChatModel from dstack._internal.core.models.unix import UnixUser -from dstack._internal.core.models.volumes import MountPoint, VolumeConfiguration, parse_mount_point +from dstack._internal.core.models.volumes import ( + AnyVolumeConfiguration, + BaseVolumeConfiguration, + MountPoint, + VolumeConfiguration, + parse_mount_point, + parse_volume_configuration, +) from dstack._internal.core.services import is_valid_replica_group_name from dstack._internal.utils.common import has_duplicates, list_enum_values_for_annotation from dstack._internal.utils.json_schema import add_extra_schema_types @@ -1114,26 +1121,55 @@ class ApplyConfigurationType(str, Enum): AnyRunConfiguration, FleetConfiguration, GatewayConfiguration, - VolumeConfiguration, + AnyVolumeConfiguration, ] -class ApplyConfiguration(CoreModel): +class BaseApplyConfiguration(CoreModel): + """ + `BaseApplyConfiguration` parses the configuration based on the `type` discriminator field, + but further dispatching (reparsing) may be required if there is another discriminator field, + e.g., `BaseVolumeConfiguration` should be parsed again to get a backend-specific configuration + based on the `backend` discriminator field. 
+ + Don't use this model directly, use `parse_apply_configuration()` instead. + """ + __root__: Annotated[ - AnyApplyConfiguration, + Union[ + # Final configurations + AnyRunConfiguration, + FleetConfiguration, + GatewayConfiguration, + # Base configurations (further parsing required to get a concrete AnyApplyConfiguration) + BaseVolumeConfiguration, + ], Field(discriminator="type"), ] def parse_apply_configuration(data: dict) -> AnyApplyConfiguration: try: - conf = ApplyConfiguration.parse_obj(data).__root__ + # First-pass parsing ignoring extra fields, to get the base (or final) configuration + conf = BaseApplyConfiguration.__response__.parse_obj(data).__root__ + if not isinstance(conf, BaseVolumeConfiguration): + # If it's a final configuration (currently, any configuration other than + # BaseVolumeConfiguration), parse again rejecting extra fields + # for validation purposes only and return the final configuration + _ = BaseApplyConfiguration.parse_obj(data).__root__ + return conf except ValidationError as e: raise ConfigurationError(e) - return conf + # Otherwise, delegate further parsing to more specific parser + return parse_volume_configuration(data) -AnyDstackConfiguration = AnyApplyConfiguration +AnyDstackConfiguration = Union[ + AnyRunConfiguration, + FleetConfiguration, + GatewayConfiguration, + VolumeConfiguration, +] class DstackConfiguration(CoreModel): diff --git a/src/dstack/_internal/core/models/volumes.py b/src/dstack/_internal/core/models/volumes.py index 701611402..fbbd0b155 100644 --- a/src/dstack/_internal/core/models/volumes.py +++ b/src/dstack/_internal/core/models/volumes.py @@ -2,11 +2,12 @@ from datetime import datetime from enum import Enum from pathlib import PurePosixPath -from typing import Dict, List, Literal, Optional, Tuple, Union +from typing import Any, Dict, List, Literal, Optional, Tuple, Union -from pydantic import Field, validator +from pydantic import Field, ValidationError, validator from typing_extensions import 
Annotated, Self +from dstack._internal.core.errors import ConfigurationError from dstack._internal.core.models.backends.base import BackendType from dstack._internal.core.models.common import CoreModel from dstack._internal.core.models.profiles import parse_idle_duration @@ -32,29 +33,26 @@ def finished_statuses(cls) -> List["VolumeStatus"]: return [cls.FAILED] -class VolumeConfiguration(CoreModel): +class BaseVolumeConfiguration(CoreModel): type: Literal["volume"] = "volume" + backend: Any + """`backend` is used as a tagged union discriminator. Subclasses must override its type + with `Literal[BackendType.]` annotation. Annotated as `Any` since `BackendType` + triggers type checker error: + > Variable is mutable so its type is invariant + """ name: Annotated[Optional[str], Field(description="The volume name")] = None - backend: Annotated[BackendType, Field(description="The volume backend")] - region: Annotated[str, Field(description="The volume region")] - availability_zone: Annotated[ - Optional[str], Field(description="The volume availability zone") - ] = None size: Annotated[ Optional[Memory], Field(description="The volume size. Must be specified when creating new volumes"), ] = None - volume_id: Annotated[ - Optional[str], - Field(description="The volume ID. Must be specified when registering external volumes"), - ] = None auto_cleanup_duration: Annotated[ Optional[Union[str, int]], Field( description=( "Time to wait after volume is no longer used by any job before deleting it. " "Defaults to keep the volume indefinitely. " - "Use the value 'off' or -1 to disable auto-cleanup." 
+ "Use the value `off` or `-1` to disable auto-cleanup" ) ), ] = None @@ -74,13 +72,126 @@ class VolumeConfiguration(CoreModel): "auto_cleanup_duration", pre=True, allow_reuse=True )(parse_idle_duration) + @property + def external_volume_id(self) -> Optional[str]: + """ + Returns the value of a configuration field denoting a user-provided volume identifier + when an existing volume is registered rather than a new one being created. + """ + return None + + @property + def is_external(self) -> bool: + return self.external_volume_id is not None + @property def size_gb(self) -> int: return int(get_or_error(self.size)) +class VolumeConfigurationWithRegion(BaseVolumeConfiguration): + region: Annotated[str, Field(description="The volume region")] + + +class VolumeConfigurationWithAvailibilityZone(VolumeConfigurationWithRegion): + availability_zone: Annotated[ + Optional[str], Field(description="The volume availability zone") + ] = None + + +class VolumeConfigurationWithVolumeID(BaseVolumeConfiguration): + volume_id: Annotated[ + Optional[str], + Field(description="The volume ID. 
Must be specified when registering external volumes"), + ] = None + + @property + def external_volume_id(self) -> Optional[str]: + return self.volume_id + + +class AWSVolumeConfiguration( + VolumeConfigurationWithAvailibilityZone, VolumeConfigurationWithVolumeID +): + backend: Annotated[Literal[BackendType.AWS], Field(description="The volume backend")] = ( + BackendType.AWS + ) + + +class GCPVolumeConfiguration( + VolumeConfigurationWithAvailibilityZone, VolumeConfigurationWithVolumeID +): + backend: Annotated[Literal[BackendType.GCP], Field(description="The volume backend")] = ( + BackendType.GCP + ) + + +class RunpodVolumeConfiguration(VolumeConfigurationWithRegion, VolumeConfigurationWithVolumeID): + backend: Annotated[Literal[BackendType.RUNPOD], Field(description="The volume backend")] = ( + BackendType.RUNPOD + ) + availability_zone: Annotated[Optional[str], Field(exclude=True)] = None + """Runpod doesn't have AZs but we accept this field for compatibility with older clients.""" + + +class KubernetesVolumeConfiguration(BaseVolumeConfiguration): + backend: Annotated[ + Literal[BackendType.KUBERNETES], Field(description="The volume backend") + ] = BackendType.KUBERNETES + size: Annotated[ + Optional[Memory], + Field( + description=( + "The requested volume size. Must be specified when creating new PVCs." + " Ignored if `claim_name` is set" + ) + ), + ] = None + """`size` is overridden to provide Kubernetes-specific description. + The signature is the same as in the base class.""" + claim_name: Annotated[ + Optional[str], + Field( + description=( + "The `PersistentVolumeClaim` name. Must be specified when registering" + " the existing PVC instead of creating a new one" + ) + ), + ] = None + storage_class_name: Annotated[ + Optional[str], Field(description="The `StorageClass` name. Ignored if `claim_name` is set") + ] = None + access_modes: Annotated[ + list[str], + Field(description="A list of accepted access modes. 
Ignored if `claim_name` is set"), + ] = ["ReadWriteOnce"] + + @property + def external_volume_id(self) -> Optional[str]: + return self.claim_name + + +AnyVolumeConfiguration = Union[ + AWSVolumeConfiguration, + GCPVolumeConfiguration, + RunpodVolumeConfiguration, + KubernetesVolumeConfiguration, +] + + +class VolumeConfiguration(CoreModel): + __root__: Annotated[AnyVolumeConfiguration, Field(discriminator="backend")] + + +def parse_volume_configuration(data: dict) -> AnyVolumeConfiguration: + try: + return VolumeConfiguration.parse_obj(data).__root__ + except ValidationError as e: + raise ConfigurationError(e) + + class VolumeSpec(CoreModel): - configuration: VolumeConfiguration + configuration: Annotated[AnyVolumeConfiguration, Field(discriminator="backend")] configuration_path: Optional[str] = None @@ -119,7 +230,7 @@ class Volume(CoreModel): name: str user: str project_name: str - configuration: VolumeConfiguration + configuration: Annotated[AnyVolumeConfiguration, Field(discriminator="backend")] external: bool created_at: datetime last_processed_at: datetime @@ -145,6 +256,28 @@ def get_attachment_data_for_instance(self, instance_id: str) -> Optional[VolumeA # volume was attached before attachments were introduced return self.attachment_data + def get_backend(self) -> BackendType: + return self.configuration.backend + + def get_region(self) -> str: + """ + Returns the volume region or an empty string if the volume (that is, its backend) + has no such thing as a "region". 
+ """ + if isinstance(self.configuration, VolumeConfigurationWithRegion): + return self.configuration.region + return "" + + def get_availability_zone(self) -> Optional[str]: + """ + Returns the volume availability zone or `None` if: + * the volume (that is, its backend) has no such thing as an "availability zone" + * `VolumeProvisioningData` is not set for some reason + """ + if self.provisioning_data is None: + return None + return self.provisioning_data.availability_zone + class VolumePlan(CoreModel): project_name: str diff --git a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py index 0e4358d6f..2811f34ce 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py +++ b/src/dstack/_internal/server/background/pipeline_tasks/jobs_submitted.py @@ -1538,8 +1538,8 @@ async def _process_volume_attachments( ): raise ServerClientError("Cannot attach a volume locked for processing") if ( - job_provisioning_data.get_base_backend() != volume.configuration.backend - or job_provisioning_data.region.lower() != volume.configuration.region.lower() + job_provisioning_data.get_base_backend() != volume.get_backend() + or job_provisioning_data.region.lower() != volume.get_region().lower() ): continue if volume.provisioning_data is None or not volume.provisioning_data.attachable: @@ -2217,8 +2217,8 @@ def _get_offer_mount_point_volume( ) -> Volume: for volume in volumes: if ( - volume.configuration.backend != offer.backend - or volume.configuration.region.lower() != offer.region.lower() + volume.get_backend() != offer.backend + or volume.get_region().lower() != offer.region.lower() ): continue return volume diff --git a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py index ea90ab90d..cb8d6ae79 100644 --- a/src/dstack/_internal/server/background/pipeline_tasks/volumes.py +++ 
b/src/dstack/_internal/server/background/pipeline_tasks/volumes.py @@ -328,7 +328,7 @@ async def _process_submitted_volume(volume_model: VolumeModel) -> _ProcessResult compute = backend.compute() assert isinstance(compute, ComputeWithVolumeSupport) try: - if volume.configuration.volume_id is not None: + if volume.configuration.is_external: logger.info("Registering external volume %s", volume_model.name) vpd = await run_async( compute.register_volume, diff --git a/src/dstack/_internal/server/schemas/volumes.py b/src/dstack/_internal/server/schemas/volumes.py index 1a63c49b9..ff1b106f9 100644 --- a/src/dstack/_internal/server/schemas/volumes.py +++ b/src/dstack/_internal/server/schemas/volumes.py @@ -1,11 +1,11 @@ from datetime import datetime -from typing import List, Optional +from typing import Annotated, List, Optional from uuid import UUID from pydantic import Field from dstack._internal.core.models.common import CoreModel -from dstack._internal.core.models.volumes import VolumeConfiguration +from dstack._internal.core.models.volumes import AnyVolumeConfiguration class ListVolumesRequest(CoreModel): @@ -22,7 +22,7 @@ class GetVolumeRequest(CoreModel): class CreateVolumeRequest(CoreModel): - configuration: VolumeConfiguration + configuration: Annotated[AnyVolumeConfiguration, Field(discriminator="backend")] class DeleteVolumesRequest(CoreModel): diff --git a/src/dstack/_internal/server/services/instances.py b/src/dstack/_internal/server/services/instances.py index 23b42520d..a311d7d95 100644 --- a/src/dstack/_internal/server/services/instances.py +++ b/src/dstack/_internal/server/services/instances.py @@ -406,23 +406,25 @@ def filter_instances( volumes: Optional[List[List[Volume]]] = None, shared: bool = False, ) -> List[InstanceModel]: - backend_types = profile.backends - regions = profile.regions - zones = profile.availability_zones + backend_types: Optional[list[BackendType]] = profile.backends + regions: Optional[list[str]] = profile.regions + zones: 
Optional[list[str]] = profile.availability_zones + # (BackendType, region.lower() | "", availability_zone.lower() | None) + volumes_locations: Optional[set[tuple[BackendType, str, Optional[str]]]] = None if volumes: - mount_point_volumes = volumes[0] - backend_types = [v.configuration.backend for v in mount_point_volumes] - regions = [v.configuration.region for v in mount_point_volumes] - volume_zones = [ - v.provisioning_data.availability_zone - for v in mount_point_volumes - if v.provisioning_data is not None - and v.provisioning_data.availability_zone is not None - ] - if zones is None: - zones = volume_zones - zones = [z for z in zones if z in volume_zones] + volumes_locations = set() + for volume in volumes[0]: + volume_backend = volume.get_backend() + volume_region = volume.get_region().lower() + # If the volume has an AZ, it's added twice -- with and without an AZ. + # When the instance location is checked against the available volumes locations (see + # below) the instance with an AZ matches only the volume with the same AZ, while + # the instance without an AZ matches any volume with the same region regardless of AZs. + # This reflects the logic used before this stricter volumes_locations check was added. 
+ volumes_locations.add((volume_backend, volume_region, None)) + if (volume_zone := volume.get_availability_zone()) is not None: + volumes_locations.add((volume_backend, volume_region, volume_zone.lower())) if multinode: if backend_types is None: @@ -465,6 +467,17 @@ def filter_instances( requirements=requirements, ): continue + if volumes_locations is not None: + jpd = get_instance_provisioning_data(instance) + # instance_matches_constraints() also skips filtering if JPD is not set + if jpd is not None: + instance_backend = jpd.get_base_backend() + instance_region = jpd.region.lower() + instance_zone = jpd.availability_zone + if instance_zone is not None: + instance_zone = instance_zone.lower() + if (instance_backend, instance_region, instance_zone) not in volumes_locations: + continue filtered_instances.append(instance) return filtered_instances diff --git a/src/dstack/_internal/server/services/jobs/__init__.py b/src/dstack/_internal/server/services/jobs/__init__.py index e37c3a871..32828cf3d 100644 --- a/src/dstack/_internal/server/services/jobs/__init__.py +++ b/src/dstack/_internal/server/services/jobs/__init__.py @@ -447,18 +447,12 @@ def check_can_attach_job_volumes(volumes: List[List[Volume]]): """ if len(volumes) == 0: return - expected_backends = {v.configuration.backend for v in volumes[0]} - expected_regions = {v.configuration.region for v in volumes[0]} + expected_locations = {(v.get_backend(), v.get_region().lower()) for v in volumes[0]} for mount_point_volumes in volumes: - backends = {v.configuration.backend for v in mount_point_volumes} - regions = {v.configuration.region for v in mount_point_volumes} - if backends != expected_backends: + locations = {(v.get_backend(), v.get_region().lower()) for v in mount_point_volumes} + if locations != expected_locations: raise ServerClientError( - "Volumes from different backends specified for different mount points" - ) - if regions != expected_regions: - raise ServerClientError( - "Volumes from different 
regions specified for different mount points" + "Volumes from different locations specified for different mount points" ) for volume in mount_point_volumes: if volume.status != VolumeStatus.ACTIVE: @@ -557,16 +551,14 @@ def _get_job_mount_point_attached_volume( """ for volume in volumes: if ( - volume.configuration.backend != job_provisioning_data.get_base_backend() - or volume.configuration.region.lower() != job_provisioning_data.region.lower() + volume.get_backend() != job_provisioning_data.get_base_backend() + or volume.get_region().lower() != job_provisioning_data.region.lower() ): continue if ( - volume.provisioning_data is not None - and volume.provisioning_data.availability_zone is not None + (volume_availability_zone := volume.get_availability_zone()) is not None and job_provisioning_data.availability_zone is not None - and volume.provisioning_data.availability_zone.lower() - != job_provisioning_data.availability_zone.lower() + and volume_availability_zone.lower() != job_provisioning_data.availability_zone.lower() ): continue return volume diff --git a/src/dstack/_internal/server/services/offers.py b/src/dstack/_internal/server/services/offers.py index b00e3b382..3ac8b7ed6 100644 --- a/src/dstack/_internal/server/services/offers.py +++ b/src/dstack/_internal/server/services/offers.py @@ -1,5 +1,5 @@ import itertools -from collections.abc import Iterable, Iterator +from collections.abc import Container, Iterable, Iterator from typing import List, Literal, Optional, Tuple, Union import gpuhunt @@ -42,21 +42,15 @@ async def get_offers_by_requirements( ) -> List[Tuple[Backend, InstanceOfferWithAvailability]]: backends: List[Backend] = await backends_services.get_project_backends(project=project) - backend_types = profile.backends - regions = profile.regions - availability_zones = profile.availability_zones - instance_types = profile.instance_types + backend_types: Optional[list[BackendType]] = profile.backends + regions: Optional[list[str]] = profile.regions + 
availability_zones: Optional[list[str]] = profile.availability_zones + instance_types: Optional[list[str]] = profile.instance_types + # (BackendType, region.lower() | "") + volumes_locations: Optional[set[tuple[BackendType, str]]] = None if volumes: - mount_point_volumes = volumes[0] - volumes_backend_types = [v.configuration.backend for v in mount_point_volumes] - if backend_types is None: - backend_types = volumes_backend_types - backend_types = [b for b in backend_types if b in volumes_backend_types] - volumes_regions = [v.configuration.region for v in mount_point_volumes] - if regions is None: - regions = volumes_regions - regions = [r for r in regions if r in volumes_regions] + volumes_locations = {(v.get_backend(), v.get_region().lower()) for v in volumes[0]} if multinode: if backend_types is None: @@ -107,6 +101,7 @@ async def get_offers_by_requirements( availability_zones=availability_zones, instance_types=instance_types, placement_group=placement_group, + volumes_locations=volumes_locations, ) if blocks != 1: @@ -192,6 +187,7 @@ def _filter_offers( availability_zones: Optional[List[str]] = None, instance_types: Optional[List[str]] = None, placement_group: Optional[PlacementGroup] = None, + volumes_locations: Optional[Container[tuple[BackendType, str]]] = None, ) -> Iterator[Tuple[Backend, InstanceOfferWithAvailability]]: """ Yields filtered offers. May return modified offers to match the filters. 
@@ -224,6 +220,13 @@ def _filter_offers( if not new_offer.availability_zones: continue offer = new_offer + # Offer is further filtered against volumes AZs in Compute implementation, see + # ComputeWithCreateInstanceSupport._restrict_instance_offer_az_to_volumes_az() + if ( + volumes_locations is not None + and (offer.backend, offer.region.lower()) not in volumes_locations + ): + continue yield (b, offer) diff --git a/src/dstack/_internal/server/services/volumes.py b/src/dstack/_internal/server/services/volumes.py index 8638009ec..9ec85ad8d 100644 --- a/src/dstack/_internal/server/services/volumes.py +++ b/src/dstack/_internal/server/services/volumes.py @@ -16,6 +16,7 @@ ) from dstack._internal.core.models.profiles import parse_duration from dstack._internal.core.models.volumes import ( + AnyVolumeConfiguration, Volume, VolumeAttachment, VolumeAttachmentData, @@ -259,7 +260,7 @@ async def create_volume( session: AsyncSession, project: ProjectModel, user: UserModel, - configuration: VolumeConfiguration, + configuration: AnyVolumeConfiguration, pipeline_hinter: PipelineHinterProtocol, ) -> Volume: spec = await apply_plugin_policies( @@ -399,7 +400,7 @@ def volume_model_to_volume(volume_model: VolumeModel) -> Volume: project_name=volume_model.project.name, user=volume_model.user.name, configuration=configuration, - external=configuration.volume_id is not None, + external=configuration.is_external, created_at=volume_model.created_at, last_processed_at=volume_model.last_processed_at, status=volume_model.status, @@ -416,8 +417,8 @@ def volume_model_to_volume(volume_model: VolumeModel) -> Volume: return volume -def get_volume_configuration(volume_model: VolumeModel) -> VolumeConfiguration: - return VolumeConfiguration.__response__.parse_raw(volume_model.configuration) +def get_volume_configuration(volume_model: VolumeModel) -> AnyVolumeConfiguration: + return VolumeConfiguration.__response__.parse_raw(volume_model.configuration).__root__ def
get_volume_provisioning_data(volume_model: VolumeModel) -> Optional[VolumeProvisioningData]: @@ -467,9 +468,9 @@ async def generate_volume_name(session: AsyncSession, project: ProjectModel) -> return name -def _validate_volume_configuration(configuration: VolumeConfiguration): - if configuration.volume_id is None and configuration.size is None: - raise ServerClientError("Volume must specify either volume_id or size") +def _validate_volume_configuration(configuration: AnyVolumeConfiguration): + if configuration.external_volume_id is None and configuration.size is None: + raise ServerClientError("Volume must specify either existing identifier or size") backends_services.check_backend_type_available(configuration.backend) if configuration.backend not in BACKENDS_WITH_VOLUMES_SUPPORT: raise ServerClientError( @@ -479,7 +480,7 @@ def _validate_volume_configuration(configuration: VolumeConfiguration): if configuration.name is not None: validate_dstack_resource_name(configuration.name) - if configuration.volume_id is not None and configuration.auto_cleanup_duration is not None: + if configuration.is_external and configuration.auto_cleanup_duration is not None: if ( isinstance(configuration.auto_cleanup_duration, int) and configuration.auto_cleanup_duration > 0 @@ -488,7 +489,7 @@ def _validate_volume_configuration(configuration: VolumeConfiguration): and configuration.auto_cleanup_duration not in ("off", "-1") ): raise ServerClientError( - "External volumes (with volume_id) do not support auto_cleanup_duration. " + "External volumes do not support auto_cleanup_duration. " "Auto-cleanup only works for volumes created and managed by dstack." 
) @@ -546,6 +547,6 @@ def _get_volume_cost(volume: Volume) -> float: ) -def _get_autocleanup_enabled(configuration: VolumeConfiguration) -> bool: +def _get_autocleanup_enabled(configuration: AnyVolumeConfiguration) -> bool: auto_cleanup_duration = parse_duration(configuration.auto_cleanup_duration) return auto_cleanup_duration is not None and auto_cleanup_duration > 0 diff --git a/src/dstack/_internal/server/testing/common.py b/src/dstack/_internal/server/testing/common.py index 99e147d68..c955f5ae1 100644 --- a/src/dstack/_internal/server/testing/common.py +++ b/src/dstack/_internal/server/testing/common.py @@ -85,6 +85,8 @@ ) from dstack._internal.core.models.users import GlobalRole from dstack._internal.core.models.volumes import ( + AnyVolumeConfiguration, + KubernetesVolumeConfiguration, Volume, VolumeAttachment, VolumeConfiguration, @@ -472,6 +474,7 @@ def get_job_provisioning_data( dockerized: bool = False, backend: BackendType = BackendType.AWS, region: str = "us-east-1", + availability_zone: Optional[str] = None, gpu_count: int = 0, gpu_memory_gib: float = 16, gpu_name: str = "T4", @@ -507,6 +510,7 @@ def get_job_provisioning_data( hostname=hostname, internal_ip=internal_ip, region=region, + availability_zone=availability_zone, price=price, username=username, ssh_port=ssh_port, @@ -795,7 +799,8 @@ async def create_instance( backend: BackendType = BackendType.VERDA, termination_policy: Optional[TerminationPolicy] = None, termination_idle_time: int = DEFAULT_FLEET_TERMINATION_IDLE_TIME, - region: str = "eu-west", + region: Optional[str] = None, + availability_zone: Optional[str] = None, remote_connection_info: Optional[RemoteConnectionInfo] = None, offer: Optional[Union[InstanceOfferWithAvailability, Literal["auto"]]] = "auto", job_provisioning_data: Optional[Union[JobProvisioningData, Literal["auto"]]] = "auto", @@ -808,11 +813,14 @@ async def create_instance( ) -> InstanceModel: if instance_id is None: instance_id = uuid.uuid4() + if region is None: + 
region = "" if backend == BackendType.KUBERNETES else "eu-west" if job_provisioning_data == "auto": job_provisioning_data = get_job_provisioning_data( dockerized=True, backend=backend, region=region, + availability_zone=availability_zone, spot=spot, hostname="running_instance.ip", internal_ip=None, @@ -997,7 +1005,7 @@ async def create_volume( created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), last_processed_at: Optional[datetime] = None, last_job_processed_at: Optional[datetime] = None, - configuration: Optional[VolumeConfiguration] = None, + configuration: Optional[AnyVolumeConfiguration] = None, volume_provisioning_data: Optional[VolumeProvisioningData] = None, deleted_at: Optional[datetime] = None, backend: BackendType = BackendType.AWS, @@ -1033,7 +1041,7 @@ def get_volume( name: str = "test_volume", user: str = "test_user", project_name: str = "test_project", - configuration: Optional[VolumeConfiguration] = None, + configuration: Optional[AnyVolumeConfiguration] = None, external: bool = False, created_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), last_processed_at: datetime = datetime(2023, 1, 2, 3, 4, tzinfo=timezone.utc), @@ -1077,13 +1085,33 @@ def get_volume_configuration( size: Optional[Memory] = Memory(100), volume_id: Optional[str] = None, auto_cleanup_duration: Optional[Union[str, int]] = None, -) -> VolumeConfiguration: - return VolumeConfiguration( +) -> AnyVolumeConfiguration: + assert backend != BackendType.KUBERNETES, "use get_kubernetes_volume_configuration() instead" + return VolumeConfiguration.parse_obj( + dict( + name=name, + backend=backend, + region=region, + size=size, + volume_id=volume_id, + auto_cleanup_duration=auto_cleanup_duration, + ) + ).__root__ + + +def get_kubernetes_volume_configuration( + name: str = "test-volume", + size: Optional[Memory] = Memory(100), + claim_name: Optional[str] = None, + auto_cleanup_duration: Optional[Union[str, int]] = None, + storage_class_name: Optional[str] = 
None, +) -> KubernetesVolumeConfiguration: + return KubernetesVolumeConfiguration( name=name, - backend=backend, - region=region, + backend=BackendType.KUBERNETES, size=size, - volume_id=volume_id, + claim_name=claim_name, + storage_class_name=storage_class_name, auto_cleanup_duration=auto_cleanup_duration, ) diff --git a/src/dstack/api/server/_volumes.py b/src/dstack/api/server/_volumes.py index 502007e80..5cf56afc3 100644 --- a/src/dstack/api/server/_volumes.py +++ b/src/dstack/api/server/_volumes.py @@ -3,7 +3,7 @@ from pydantic import parse_obj_as from dstack._internal.core.compatibility.volumes import get_create_volume_excludes -from dstack._internal.core.models.volumes import Volume, VolumeConfiguration +from dstack._internal.core.models.volumes import AnyVolumeConfiguration, Volume from dstack._internal.server.schemas.volumes import ( CreateVolumeRequest, DeleteVolumesRequest, @@ -25,7 +25,7 @@ def get(self, project_name: str, name: str) -> Volume: def create( self, project_name: str, - configuration: VolumeConfiguration, + configuration: AnyVolumeConfiguration, ) -> Volume: body = CreateVolumeRequest(configuration=configuration) resp = self._request( diff --git a/src/tests/_internal/core/backends/kubernetes/test_resources.py b/src/tests/_internal/core/backends/kubernetes/test_resources.py index 6a74233a5..1839a4191 100644 --- a/src/tests/_internal/core/backends/kubernetes/test_resources.py +++ b/src/tests/_internal/core/backends/kubernetes/test_resources.py @@ -6,6 +6,8 @@ from dstack._internal.core.backends.kubernetes.resources import ( get_amd_gpu_from_node_labels, get_nvidia_gpu_from_node_labels, + validate_label_key, + validate_label_value, ) from dstack._internal.core.models.instances import Gpu @@ -51,3 +53,61 @@ def test_returns_none_if_multiple_gpu_models(self, caplog: pytest.LogCaptureFixt labels = {"beta.amd.com/gpu.device-id.74b5": "4", "beta.amd.com/gpu.device-id.74a5": "4"} assert get_amd_gpu_from_node_labels(labels) is None assert "Multiple 
AMD GPU models detected" in caplog.text + + +class TestLabelValidation: + @pytest.mark.parametrize( + "key", + [ + pytest.param("env", id="private"), + pytest.param("k8s.example.com/Valid.Label_Name-1", id="prefixed"), + ], + ) + def test_valid_key(self, key: str): + validate_label_key(key) + + @pytest.mark.parametrize( + ["key", "expected_error"], + [ + pytest.param("app.kubernetes.io//name", "Too many segments", id="too-many-segments"), + pytest.param("/name", "Empty prefix", id="empty-prefix"), + pytest.param("a" * 254 + "/name", "Prefix too long", id="too-long-prefix"), + pytest.param("invalid prefix/name", "Invalid prefix", id="space-in-prefix"), + pytest.param("my_app/name", "Invalid prefix", id="underscore-in-prefix"), + pytest.param("-invalid/name", "Invalid prefix", id="leading-dash-in-prefix"), + pytest.param("invalid-/name", "Invalid prefix", id="trailing-dash-in-prefix"), + pytest.param("Invalid/name", "Invalid prefix", id="uppercase-in-prefix"), + pytest.param("", "Empty name", id="empty-name-no-prefix"), + pytest.param("prefix/", "Empty name", id="empty-name-with-prefix"), + pytest.param("a" * 64, "Name too long", id="too-long-name-no-prefix"), + pytest.param("prefix/" + "a" * 64, "Name too long", id="too-long-name-with-prefix"), + pytest.param("-name", "Invalid name", id="leading-dash-in-name"), + pytest.param("name-", "Invalid name", id="trailing-dash-in-name"), + ], + ) + def test_invalid_key(self, key: str, expected_error: str): + with pytest.raises(ValueError, match=expected_error): + validate_label_key(key) + + @pytest.mark.parametrize( + "value", + [ + pytest.param("", id="empty"), + pytest.param("Valid.Label_Value-1", id="non-empty"), + ], + ) + def test_valid_value(self, value: str): + validate_label_value(value) + + @pytest.mark.parametrize( + ["value", "expected_error"], + [ + pytest.param("a" * 64, "Value too long", id="too-long"), + pytest.param("invalid value", "Invalid value", id="space"), + pytest.param("-invalid", "Invalid value", 
id="leading-dash"), + pytest.param("invalid-", "Invalid value", id="trailing-dash"), + ], + ) + def test_invalid_value(self, value: str, expected_error: str): + with pytest.raises(ValueError, match=expected_error): + validate_label_value(value) diff --git a/src/tests/_internal/server/services/test_instances.py b/src/tests/_internal/server/services/test_instances.py index c1960102f..1f5e1fb52 100644 --- a/src/tests/_internal/server/services/test_instances.py +++ b/src/tests/_internal/server/services/test_instances.py @@ -26,8 +26,10 @@ create_repo, create_run, create_user, + get_kubernetes_volume_configuration, get_volume, get_volume_configuration, + get_volume_provisioning_data, list_events, ) from dstack._internal.utils.common import get_current_datetime @@ -160,6 +162,77 @@ async def test_returns_volume_instances(self, test_db, session: AsyncSession): ) assert res == [runpod_instance2] + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_volume_instances_with_az(self, test_db, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + aws_instance_1 = await create_instance( + session=session, + project=project, + backend=BackendType.AWS, + region="us-1", + availability_zone="us-1a", + ) + aws_instance_2 = await create_instance( + session=session, + project=project, + backend=BackendType.AWS, + region="us-1", + availability_zone="us-1b", + ) + gcp_instance = await create_instance( + session=session, + project=project, + backend=BackendType.GCP, + region="us-1", + availability_zone="us-1b", + ) + instances = [aws_instance_1, aws_instance_2, gcp_instance] + volume = get_volume( + configuration=get_volume_configuration(backend=BackendType.AWS, region="us-1"), + provisioning_data=get_volume_provisioning_data( + backend=BackendType.AWS, availability_zone="us-1b" + ), + ) + res = instances_services.filter_instances( + 
instances=instances, + profile=Profile(name="test"), + volumes=[[volume]], + ) + assert res == [aws_instance_2] + + @pytest.mark.asyncio + @pytest.mark.parametrize("test_db", ["sqlite", "postgres"], indirect=True) + async def test_returns_volume_instances_without_region(self, test_db, session: AsyncSession): + user = await create_user(session=session) + project = await create_project(session=session, owner=user) + aws_instance = await create_instance( + session=session, + project=project, + backend=BackendType.AWS, + ) + # Kubernetes does not support "create instance" feature, but for the sake of this test + # it does not matter + kubernetes_instance = await create_instance( + session=session, + project=project, + backend=BackendType.KUBERNETES, + ) + instances = [aws_instance, kubernetes_instance] + volume = get_volume( + configuration=get_kubernetes_volume_configuration(), + provisioning_data=get_volume_provisioning_data( + backend=BackendType.KUBERNETES, availability_zone=None + ), + ) + res = instances_services.filter_instances( + instances=instances, + profile=Profile(name="test"), + volumes=[[volume]], + ) + assert res == [kubernetes_instance] + @pytest.mark.asyncio @pytest.mark.usefixtures("image_config_mock") diff --git a/src/tests/_internal/server/services/test_offers.py b/src/tests/_internal/server/services/test_offers.py index 685369fc4..25ce8021a 100644 --- a/src/tests/_internal/server/services/test_offers.py +++ b/src/tests/_internal/server/services/test_offers.py @@ -9,6 +9,7 @@ from dstack._internal.server.services.offers import get_offers_by_requirements from dstack._internal.server.testing.common import ( get_instance_offer_with_availability, + get_kubernetes_volume_configuration, get_volume, get_volume_configuration, ) @@ -99,6 +100,35 @@ async def test_returns_volume_offers(self): m.assert_awaited_once() assert res == [(runpod_backend_mock, runpod_offer2)] + @pytest.mark.asyncio + async def test_returns_volume_offers_without_region(self): + 
profile = Profile(name="test") + requirements = Requirements(resources=ResourcesSpec()) + with patch("dstack._internal.server.services.backends.get_project_backends") as m: + aws_backend_mock = Mock() + aws_backend_mock.TYPE = BackendType.AWS + aws_offer = get_instance_offer_with_availability(backend=BackendType.AWS) + aws_backend_mock.compute.return_value.get_offers.return_value = [aws_offer] + kubernetes_backend_mock = Mock() + kubernetes_backend_mock.TYPE = BackendType.KUBERNETES + kubernetes_offer = get_instance_offer_with_availability( + backend=BackendType.KUBERNETES, + region="", + availability_zones=None, + ) + kubernetes_backend_mock.compute.return_value.get_offers.return_value = [ + kubernetes_offer + ] + m.return_value = [aws_backend_mock, kubernetes_backend_mock] + res = await get_offers_by_requirements( + project=Mock(), + profile=profile, + requirements=requirements, + volumes=[[get_volume(configuration=get_kubernetes_volume_configuration())]], + ) + m.assert_awaited_once() + assert res == [(kubernetes_backend_mock, kubernetes_offer)] + @pytest.mark.asyncio async def test_returns_az_offers(self): profile = Profile(name="test", availability_zones=["az1", "az3"]) diff --git a/src/tests/_internal/server/services/test_volumes.py b/src/tests/_internal/server/services/test_volumes.py index 6bfb9bae6..82477812a 100644 --- a/src/tests/_internal/server/services/test_volumes.py +++ b/src/tests/_internal/server/services/test_volumes.py @@ -5,7 +5,7 @@ from dstack._internal.core.errors import ServerClientError from dstack._internal.core.models.backends.base import BackendType -from dstack._internal.core.models.volumes import VolumeConfiguration, VolumeStatus +from dstack._internal.core.models.volumes import AWSVolumeConfiguration, VolumeStatus from dstack._internal.server.services.volumes import ( _get_volume_cost, _validate_volume_configuration, @@ -19,7 +19,7 @@ class TestValidateVolumeConfiguration: def 
test_external_volume_with_auto_cleanup_duration_raises_error(self): """External volumes (with volume_id) should not allow auto_cleanup_duration""" - config = VolumeConfiguration( + config = AWSVolumeConfiguration( backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456", @@ -32,7 +32,7 @@ def test_external_volume_with_auto_cleanup_duration_raises_error(self): def test_external_volume_with_auto_cleanup_duration_int_raises_error(self): """External volumes with integer auto_cleanup_duration should also raise error""" - config = VolumeConfiguration( + config = AWSVolumeConfiguration( backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456", @@ -45,13 +45,13 @@ def test_external_volume_with_auto_cleanup_duration_int_raises_error(self): def test_external_volume_with_auto_cleanup_disabled_succeeds(self): """External volumes with auto_cleanup_duration='off' or -1 should be allowed""" - config1 = VolumeConfiguration( + config1 = AWSVolumeConfiguration( backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456", auto_cleanup_duration="off", ) - config2 = VolumeConfiguration( + config2 = AWSVolumeConfiguration( backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456", @@ -63,7 +63,7 @@ def test_external_volume_with_auto_cleanup_disabled_succeeds(self): def test_external_volume_without_auto_cleanup_succeeds(self): """External volumes without auto_cleanup_duration should be allowed""" - config = VolumeConfiguration( + config = AWSVolumeConfiguration( backend=BackendType.AWS, region="us-east-1", volume_id="vol-123456" ) # Should not raise any errors @@ -71,7 +71,7 @@ def test_external_volume_without_auto_cleanup_succeeds(self): def test_new_volume_with_auto_cleanup_duration_succeeds(self): """New volumes (without volume_id) with auto_cleanup_duration should be allowed""" - config = VolumeConfiguration( + config = AWSVolumeConfiguration( backend=BackendType.AWS, region="us-east-1", size=100, auto_cleanup_duration="1h" ) # Should not raise 
any errors