Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
56 changes: 2 additions & 54 deletions src/dstack/_internal/core/backends/verda/compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -251,65 +251,13 @@ def _create_startup_script(client: VerdaClient, name: str, script: str) -> str:
def _delete_startup_script(client: VerdaClient, startup_script_id: Optional[str]) -> None:
if startup_script_id is None:
return
try:
client.startup_scripts.delete_by_id(startup_script_id)
except APIException as e:
if _is_startup_script_not_found_error(e):
logger.debug(
"Skipping startup script %s deletion. Startup script not found.",
startup_script_id,
)
return
raise
client.startup_scripts.delete_by_id(startup_script_id)


def _delete_ssh_keys(client: VerdaClient, ssh_key_ids: Optional[List[str]]) -> None:
if not ssh_key_ids:
return
for ssh_key_id in ssh_key_ids:
_delete_ssh_key(client, ssh_key_id)


def _delete_ssh_key(client: VerdaClient, ssh_key_id: str) -> None:
try:
client.ssh_keys.delete_by_id(ssh_key_id)
except APIException as e:
if _is_ssh_key_not_found_error(e):
logger.debug("Skipping ssh key %s deletion. SSH key not found.", ssh_key_id)
return
raise


def _is_ssh_key_not_found_error(error: APIException) -> bool:
code = (error.code or "").lower()
message = (error.message or "").lower()
if code == "not_found":
return True
if code not in {"", "invalid_request"}:
return False
return (
message == "invalid ssh-key id"
or message == "invalid ssh key id"
or message == "not found"
or ("ssh-key id" in message and "invalid" in message)
or ("ssh key id" in message and "invalid" in message)
)


def _is_startup_script_not_found_error(error: APIException) -> bool:
code = (error.code or "").lower()
message = (error.message or "").lower()
if code == "not_found":
return True
if code not in {"", "invalid_request"}:
return False
return (
message == "invalid startup script id"
or message == "invalid script id"
or message == "not found"
or ("startup script id" in message and "invalid" in message)
or ("script id" in message and "invalid" in message)
)
client.ssh_keys.delete(ssh_key_ids)


def _get_instance_by_id(
Expand Down
129 changes: 5 additions & 124 deletions src/tests/_internal/core/backends/verda/test_compute.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,6 @@
VerdaInstanceBackendData,
_create_ssh_key,
_create_startup_script,
_is_ssh_key_not_found_error,
_is_startup_script_not_found_error,
)
from dstack._internal.core.errors import BackendError, NoCapacityError

Expand Down Expand Up @@ -302,7 +300,7 @@ def test_terminate_instance_without_backend_data(self):

_assert_terminate_call(compute.client.instances.action)
compute.client.startup_scripts.delete_by_id.assert_not_called()
compute.client.ssh_keys.delete_by_id.assert_not_called()
compute.client.ssh_keys.delete.assert_not_called()

def test_terminate_instance_deletes_startup_script(self):
compute = VerdaCompute.__new__(VerdaCompute)
Expand All @@ -316,7 +314,7 @@ def test_terminate_instance_deletes_startup_script(self):

_assert_terminate_call(compute.client.instances.action)
compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id")
assert compute.client.ssh_keys.delete_by_id.call_count == 2
compute.client.ssh_keys.delete.assert_called_once_with(["ssh-key-id-1", "ssh-key-id-2"])

def test_terminate_instance_still_deletes_script_when_instance_is_missing(self):
compute = VerdaCompute.__new__(VerdaCompute)
Expand All @@ -330,41 +328,7 @@ def test_terminate_instance_still_deletes_script_when_instance_is_missing(self):
compute.terminate_instance("instance-id", "FIN-01", backend_data)

compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id")
compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1")

def test_terminate_instance_ignores_missing_startup_script(self):
compute = VerdaCompute.__new__(VerdaCompute)
compute.client = MagicMock()
compute.client.startup_scripts.delete_by_id.side_effect = APIException(
"",
"Invalid startup script id",
)
backend_data = VerdaInstanceBackendData(
startup_script_id="script-id",
ssh_key_ids=["ssh-key-id-1"],
).json()

compute.terminate_instance("instance-id", "FIN-01", backend_data)

_assert_terminate_call(compute.client.instances.action)
compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1")

def test_terminate_instance_ignores_missing_startup_script_invalid_script_id(self):
compute = VerdaCompute.__new__(VerdaCompute)
compute.client = MagicMock()
compute.client.startup_scripts.delete_by_id.side_effect = APIException(
"invalid_request",
"Invalid script ID",
)
backend_data = VerdaInstanceBackendData(
startup_script_id="script-id",
ssh_key_ids=["ssh-key-id-1"],
).json()

compute.terminate_instance("instance-id", "FIN-01", backend_data)

_assert_terminate_call(compute.client.instances.action)
compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1")
compute.client.ssh_keys.delete.assert_called_once_with(["ssh-key-id-1"])

def test_terminate_instance_retries_on_script_delete_error(self):
compute = VerdaCompute.__new__(VerdaCompute)
Expand All @@ -380,99 +344,16 @@ def test_terminate_instance_retries_on_script_delete_error(self):
with pytest.raises(APIException):
compute.terminate_instance("instance-id", "FIN-01", backend_data)

compute.client.ssh_keys.delete_by_id.assert_not_called()

def test_terminate_instance_ignores_missing_ssh_key(self):
compute = VerdaCompute.__new__(VerdaCompute)
compute.client = MagicMock()
compute.client.ssh_keys.delete_by_id.side_effect = APIException(
"invalid_request",
"Invalid ssh-key ID",
)
backend_data = VerdaInstanceBackendData(
startup_script_id="script-id",
ssh_key_ids=["ssh-key-id-1"],
).json()

compute.terminate_instance("instance-id", "FIN-01", backend_data)

_assert_terminate_call(compute.client.instances.action)
compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id")
compute.client.ssh_keys.delete_by_id.assert_called_once_with("ssh-key-id-1")

def test_terminate_instance_deletes_remaining_ssh_keys_when_one_missing(self):
compute = VerdaCompute.__new__(VerdaCompute)
compute.client = MagicMock()
compute.client.ssh_keys.delete_by_id.side_effect = [
APIException("invalid_request", "Invalid ssh-key ID"),
None,
]
backend_data = VerdaInstanceBackendData(
startup_script_id="script-id",
ssh_key_ids=["ssh-key-id-1", "ssh-key-id-2"],
).json()

compute.terminate_instance("instance-id", "FIN-01", backend_data)

compute.client.startup_scripts.delete_by_id.assert_called_once_with("script-id")
compute.client.ssh_keys.delete_by_id.assert_any_call("ssh-key-id-1")
compute.client.ssh_keys.delete_by_id.assert_any_call("ssh-key-id-2")
assert compute.client.ssh_keys.delete_by_id.call_count == 2
compute.client.ssh_keys.delete.assert_not_called()

def test_terminate_instance_retries_on_ssh_key_delete_error(self):
compute = VerdaCompute.__new__(VerdaCompute)
compute.client = MagicMock()
compute.client.ssh_keys.delete_by_id.side_effect = APIException("", "Random API error")
compute.client.ssh_keys.delete.side_effect = APIException("", "Random API error")
backend_data = VerdaInstanceBackendData(
startup_script_id="script-id",
ssh_key_ids=["ssh-key-id-1"],
).json()

with pytest.raises(APIException):
compute.terminate_instance("instance-id", "FIN-01", backend_data)


class TestIsStartupScriptNotFoundError:
def test_returns_true_for_not_found_code_even_with_custom_message(self):
assert _is_startup_script_not_found_error(
APIException("not_found", "Startup script does not exist anymore")
)

def test_returns_true_for_invalid_script_id(self):
assert _is_startup_script_not_found_error(
APIException("invalid_request", "Invalid script ID")
)

def test_returns_true_for_not_found(self):
assert _is_startup_script_not_found_error(APIException("not_found", "Not Found"))

def test_returns_false_for_unrelated_error(self):
assert not _is_startup_script_not_found_error(
APIException("forbidden", "Permission denied")
)

def test_returns_false_for_unrelated_invalid_request(self):
assert not _is_startup_script_not_found_error(
APIException("invalid_request", "Some other invalid request")
)


class TestIsSSHKeyNotFoundError:
def test_returns_true_for_not_found_code_even_with_custom_message(self):
assert _is_ssh_key_not_found_error(
APIException("not_found", "SSH key does not exist anymore")
)

def test_returns_true_for_invalid_ssh_key_id(self):
assert _is_ssh_key_not_found_error(APIException("invalid_request", "Invalid ssh-key ID"))

def test_returns_true_for_not_found(self):
assert _is_ssh_key_not_found_error(APIException("not_found", "Not Found"))

def test_returns_false_for_unrelated_error(self):
assert not _is_ssh_key_not_found_error(APIException("forbidden", "Permission denied"))

def test_returns_false_for_unrelated_invalid_request(self):
assert not _is_ssh_key_not_found_error(
APIException("invalid_request", "Some other invalid request")
)
Loading