From 754b956cb0b35cecce072ec560a077bab3f3d795 Mon Sep 17 00:00:00 2001 From: alcholiclg Date: Wed, 22 Apr 2026 04:35:23 +0800 Subject: [PATCH] support acp and a2a --- .gitignore | 1 + ms_agent/a2a/__init__.py | 63 ++ ms_agent/a2a/agent_card.py | 114 ++++ ms_agent/a2a/client.py | 129 ++++ ms_agent/a2a/errors.py | 80 +++ ms_agent/a2a/executor.py | 173 ++++++ ms_agent/a2a/session_store.py | 166 +++++ ms_agent/a2a/translator.py | 109 ++++ ms_agent/acp/__init__.py | 20 + ms_agent/acp/client.py | 132 ++++ ms_agent/acp/config.py | 76 +++ ms_agent/acp/errors.py | 77 +++ ms_agent/acp/http_adapter.py | 268 ++++++++ ms_agent/acp/permissions.py | 114 ++++ ms_agent/acp/proxy.py | 385 ++++++++++++ ms_agent/acp/proxy_session.py | 193 ++++++ ms_agent/acp/registry.py | 62 ++ ms_agent/acp/server.py | 398 ++++++++++++ ms_agent/acp/session_store.py | 183 ++++++ ms_agent/acp/translator.py | 466 ++++++++++++++ ms_agent/cli/a2a_cmd.py | 203 ++++++ ms_agent/cli/acp_cmd.py | 198 ++++++ ms_agent/cli/acp_proxy_cmd.py | 50 ++ ms_agent/cli/cli.py | 8 + ms_agent/tools/a2a_agent_tool.py | 78 +++ ms_agent/tools/acp_agent_tool.py | 79 +++ ms_agent/tools/tool_manager.py | 12 + requirements/a2a.txt | 1 + requirements/acp.txt | 1 + setup.py | 2 + tests/test_a2a/__init__.py | 0 tests/test_a2a/test_a2a_e2e.py | 104 ++++ tests/test_a2a/test_a2a_protocol.py | 175 ++++++ tests/test_a2a/test_a2a_unit.py | 304 +++++++++ tests/test_acp/__init__.py | 0 tests/test_acp/proxy_opencode.yaml | 10 + tests/test_acp/test_acp_e2e.py | 74 +++ tests/test_acp/test_acp_e2e_real.py | 205 ++++++ tests/test_acp/test_acp_protocol.py | 655 ++++++++++++++++++++ tests/test_acp/test_acp_proxy.py | 255 ++++++++ tests/test_acp/test_acp_proxy_e2e.py | 291 +++++++++ tests/test_acp/test_acp_real.py | 821 +++++++++++++++++++++++++ tests/test_acp/test_acp_unit.py | 384 ++++++++++++ tests/test_acp/test_openai_config.yaml | 18 + 44 files changed, 7137 insertions(+) create mode 100644 ms_agent/a2a/__init__.py create mode 100644 
ms_agent/a2a/agent_card.py create mode 100644 ms_agent/a2a/client.py create mode 100644 ms_agent/a2a/errors.py create mode 100644 ms_agent/a2a/executor.py create mode 100644 ms_agent/a2a/session_store.py create mode 100644 ms_agent/a2a/translator.py create mode 100644 ms_agent/acp/__init__.py create mode 100644 ms_agent/acp/client.py create mode 100644 ms_agent/acp/config.py create mode 100644 ms_agent/acp/errors.py create mode 100644 ms_agent/acp/http_adapter.py create mode 100644 ms_agent/acp/permissions.py create mode 100644 ms_agent/acp/proxy.py create mode 100644 ms_agent/acp/proxy_session.py create mode 100644 ms_agent/acp/registry.py create mode 100644 ms_agent/acp/server.py create mode 100644 ms_agent/acp/session_store.py create mode 100644 ms_agent/acp/translator.py create mode 100644 ms_agent/cli/a2a_cmd.py create mode 100644 ms_agent/cli/acp_cmd.py create mode 100644 ms_agent/cli/acp_proxy_cmd.py create mode 100644 ms_agent/tools/a2a_agent_tool.py create mode 100644 ms_agent/tools/acp_agent_tool.py create mode 100644 requirements/a2a.txt create mode 100644 requirements/acp.txt create mode 100644 tests/test_a2a/__init__.py create mode 100644 tests/test_a2a/test_a2a_e2e.py create mode 100644 tests/test_a2a/test_a2a_protocol.py create mode 100644 tests/test_a2a/test_a2a_unit.py create mode 100644 tests/test_acp/__init__.py create mode 100644 tests/test_acp/proxy_opencode.yaml create mode 100644 tests/test_acp/test_acp_e2e.py create mode 100644 tests/test_acp/test_acp_e2e_real.py create mode 100644 tests/test_acp/test_acp_protocol.py create mode 100644 tests/test_acp/test_acp_proxy.py create mode 100644 tests/test_acp/test_acp_proxy_e2e.py create mode 100644 tests/test_acp/test_acp_real.py create mode 100644 tests/test_acp/test_acp_unit.py create mode 100644 tests/test_acp/test_openai_config.yaml diff --git a/.gitignore b/.gitignore index f526ef422..febce981a 100644 --- a/.gitignore +++ b/.gitignore @@ -120,6 +120,7 @@ venv.bak/ .vscode .idea +.cursor # 
def __getattr__(name):
    """Resolve SDK-dependent exports lazily (module-level ``__getattr__``).

    The executor and agent-card helpers live in submodules that import the
    A2A SDK at module import time, so they are imported only when one of
    their symbols is first requested from this package.

    Raises:
        AttributeError: If ``name`` is not one of the lazily exported names.
    """
    if name in ('MSAgentA2AExecutor', 'configure_a2a_logging'):
        from . import executor as lazy_module
    elif name in ('build_agent_card', 'generate_agent_card_json'):
        from . import agent_card as lazy_module
    else:
        raise AttributeError(
            f'module {__name__!r} has no attribute {name!r}')
    return getattr(lazy_module, name)
def build_agent_card(
    config_path: str | None = None,
    host: str = '0.0.0.0',
    port: int = 5000,
    version: str = _DEFAULT_VERSION,
    title: str = 'MS-Agent',
    description: str = ('Lightweight framework for empowering agents '
                        'with autonomous exploration'),
    skills: list[dict] | None = None,
) -> Any:
    """Build an A2A ``AgentCard`` from ms-agent config and arguments.

    Note that the return value is an ``a2a.types.AgentCard`` model
    instance, not a plain dict; ``generate_agent_card_json`` converts it
    with ``model_dump`` before serialising.

    Args:
        config_path: Optional agent config path. When it exists, the
            config's ``name`` / ``description`` attributes (if set)
            override ``title`` / ``description``.
        host: Bind host; ``'0.0.0.0'`` is advertised as ``localhost``.
        port: Port used in the advertised URL.
        version: Version string placed on the card.
        title: Human-readable agent name; the card ``name`` is this value
            lower-cased with spaces replaced by ``-``.
        description: Card description and default skill description.
        skills: Optional list of skill dicts (``id``, ``name``,
            ``description``, ``tags``, ``examples``). When empty or
            ``None`` a single ``general`` skill is generated.

    Returns:
        The constructed ``AgentCard``.
    """
    # Imported lazily so this module can be imported without a2a-sdk.
    from a2a.types import (
        AgentCard,
        AgentCapabilities,
        AgentSkill,
    )

    # A wildcard bind address is not reachable by clients; advertise
    # localhost instead.
    resolved_host = host if host != '0.0.0.0' else 'localhost'
    url = f'http://{resolved_host}:{port}/'

    if config_path and os.path.exists(config_path):
        try:
            from ms_agent.config.config import Config
            config = Config.from_task(config_path)
            cfg_desc = getattr(config, 'description', None)
            if cfg_desc:
                description = str(cfg_desc)
            cfg_name = getattr(config, 'name', None)
            if cfg_name:
                title = str(cfg_name)
        except Exception:
            # Metadata from the config is optional: fall back to the
            # arguments when the config cannot be loaded.
            logger.debug(
                'Could not load config for agent card metadata', exc_info=True)

    skill_list: list[AgentSkill] = []
    if skills:
        for s in skills:
            # Missing fields fall back to the card-level title/description.
            skill_list.append(
                AgentSkill(
                    id=s.get('id', 'general'),
                    name=s.get('name', title),
                    description=s.get('description', description),
                    tags=s.get('tags', []),
                    examples=s.get('examples', []),
                ))
    else:
        skill_list.append(
            AgentSkill(
                id='general',
                name=title,
                description=description,
                tags=['general', 'agent'],
                examples=['Help me research a topic'],
            ))

    card = AgentCard(
        name=title.lower().replace(' ', '-'),
        description=description,
        url=url,
        version=version,
        capabilities=AgentCapabilities(streaming=True),
        skills=skill_list,
        defaultInputModes=['text'],
        defaultOutputModes=['text'],
    )
    return card
async def call_agent(
    self,
    agent_name: str,
    query: str,
) -> str:
    """Send a message to a remote A2A agent and return the text response.

    Discovers the agent via its Agent Card, then sends ``query`` with the
    A2A SDK client and gathers text from direct message replies, task
    status messages and task artifacts.

    Args:
        agent_name: Key of the agent in the configured ``a2a_agents``.
        query: Plain-text prompt to send.

    Returns:
        The collected text joined by newlines, ``'(no output)'`` when the
        agent produced no text, or an ``'Error ...'`` string when the
        agent is not configured or the call fails (this method does not
        raise for call failures).
    """
    cfg = self._config.get(agent_name)
    if cfg is None:
        return f'Error: A2A agent "{agent_name}" not configured'

    url = cfg.get('url', '')
    if not url:
        return f'Error: A2A agent "{agent_name}" has no URL configured'

    def part_texts(parts) -> list:
        """Return the text of every part that carries one."""
        texts = []
        for part in parts or []:
            # SDK parts may be wrapped in a ``root`` discriminated union.
            part_obj = getattr(part, 'root', part)
            text = getattr(part_obj, 'text', None)
            if text is not None:
                texts.append(str(text))
        return texts

    # Client created only for this call when auth headers are needed; it
    # is not the shared client, so it must be closed here.
    auth_client = None
    try:
        from a2a.client import (
            A2ACardResolver,
            ClientConfig,
            ClientFactory,
        )
        from a2a.client.helpers import create_text_message_object

        auth_headers = self._build_auth_headers(cfg)
        if auth_headers:
            auth_client = httpx.AsyncClient(
                timeout=300.0, headers=auth_headers)
            http_client = auth_client
        else:
            http_client = self._get_http_client()

        resolver = A2ACardResolver(httpx_client=http_client, base_url=url)
        card = await resolver.get_agent_card()

        factory = ClientFactory(
            config=ClientConfig(httpx_client=http_client))
        client = factory.create(card)

        message = create_text_message_object(content=query)
        result_parts: list[str] = []
        # Latest text per artifact. NOTE(review): the task delivered with
        # each stream event appears to be the cumulative task state, so
        # appending its artifacts on every event would duplicate output.
        artifact_texts: dict = {}

        async for event in client.send_message(message):
            if hasattr(event, 'parts'):
                result_parts.extend(part_texts(event.parts))
            elif isinstance(event, tuple) and len(event) == 2:
                task, update = event
                if update and hasattr(update, 'status'):
                    msg = getattr(update.status, 'message', None)
                    if msg and hasattr(msg, 'parts'):
                        result_parts.extend(part_texts(msg.parts))
                if task and hasattr(task, 'artifacts'):
                    for index, artifact in enumerate(task.artifacts or []):
                        key = (getattr(artifact, 'artifact_id', None)
                               or getattr(artifact, 'artifactId', None)
                               or index)
                        artifact_texts[key] = part_texts(artifact.parts)

        for texts in artifact_texts.values():
            result_parts.extend(texts)

        return '\n'.join(result_parts) if result_parts else '(no output)'

    except Exception as e:
        logger.error(
            'A2A call to %s failed: %s', agent_name, e, exc_info=True)
        return f'Error calling A2A agent "{agent_name}": {e}'
    finally:
        if auth_client is not None:
            try:
                await auth_client.aclose()
            except Exception:
                logger.debug(
                    'Failed to close A2A auth client', exc_info=True)
@staticmethod
def _build_auth_headers(cfg: dict) -> dict[str, str]:
    """Build authentication headers from one agent's config.

    Only ``auth.type == 'bearer'`` (case-insensitive) is supported. The
    token is taken from ``auth.token`` or, when that is empty, from the
    environment variable named by ``auth.token_env``.

    Args:
        cfg: The agent's config mapping; may lack an ``auth`` section.

    Returns:
        ``{'Authorization': 'Bearer <token>'}`` when a bearer token is
        available, otherwise an empty dict.
    """
    auth = cfg.get('auth')
    if not auth:
        return {}

    # A key that is present but null (e.g. ``type:`` in YAML) must not
    # crash on ``.lower()``.
    auth_type = str(auth.get('type') or '').lower()
    if auth_type != 'bearer':
        return {}

    token = auth.get('token') or ''
    if not token:
        token_env = auth.get('token_env')
        # os.environ rejects non-str keys, so only look up a real name.
        if token_env:
            token = os.environ.get(str(token_env), '')

    return {'Authorization': f'Bearer {token}'} if token else {}
# Built-in exception types and the JSON-RPC code/message reported for them.
# Checked in order; the first matching type wins.
_EXCEPTION_MAP: list[tuple[type, int, str]] = [
    (FileNotFoundError, -32002, 'Resource not found'),
    (PermissionError, -32000, 'Permission denied'),
    (TimeoutError, -32004, 'Request timed out'),
    (ValueError, -32602, 'Invalid params'),
]


def wrap_a2a_error(exc: Exception) -> dict:
    """Translate an exception into a JSON-RPC-style error dict.

    ``A2AServerError`` instances supply their own code, message and data.
    Exceptions listed in ``_EXCEPTION_MAP`` get the mapped code/message,
    and anything else becomes ``-32603`` / ``'Internal error'``; in both
    of those cases ``data`` is ``{'detail': str(exc)}``.

    Returns:
        A dict with ``code``, ``message`` and ``data`` keys.
    """
    if isinstance(exc, A2AServerError):
        code, message, data = exc.code, exc.message, exc.data
    else:
        code, message = next(
            ((mapped_code, mapped_msg)
             for exc_type, mapped_code, mapped_msg in _EXCEPTION_MAP
             if isinstance(exc, exc_type)),
            (-32603, 'Internal error'),
        )
        data = {'detail': str(exc)}
    return {'code': code, 'message': message, 'data': data}
class MSAgentA2AExecutor(AgentExecutor):
    """A2A ``AgentExecutor`` backed by ms-agent's ``LLMAgent``.

    Each A2A task maps to an agent instance managed by ``A2AAgentStore``.
    The executor translates the incoming A2A message to a user query,
    runs the agent, and reports status/artifacts through the event queue.
    """

    # Bookkeeping shared by all ``_suppress_stdout`` contexts. Several
    # tasks can be inside the context at once (they interleave at await
    # points), so the real stdout is saved by the first entrant and
    # restored by the last one to leave.
    _stdout_depth: int = 0
    _stdout_saved: Any = None
    _stdout_sink: Any = None

    def __init__(
        self,
        config_path: str,
        trust_remote_code: bool = False,
        max_tasks: int = 8,
        task_timeout: int = 3600,
    ) -> None:
        """Create the executor and its backing ``A2AAgentStore``.

        Args:
            config_path: Path to the agent config used for every task.
            trust_remote_code: Forwarded to the agent store.
            max_tasks: Upper bound on concurrently held agent instances.
            task_timeout: Idle seconds before an agent may be evicted.
        """
        self.config_path = config_path
        self.trust_remote_code = trust_remote_code
        self._store = A2AAgentStore(
            config_path=config_path,
            trust_remote_code=trust_remote_code,
            max_tasks=max_tasks,
            task_timeout=task_timeout,
        )

    @staticmethod
    @contextmanager
    def _suppress_stdout():
        """Redirect ``sys.stdout`` to devnull while running agent logic.

        ``LLMAgent.step()`` writes streaming tokens to ``sys.stdout``,
        which would corrupt any stdio-based transport.

        The redirection is reference counted: closing "whatever
        ``sys.stdout`` currently is" on exit would close the real stdout
        (or another task's sink) when two tasks overlap.
        """
        cls = MSAgentA2AExecutor
        if cls._stdout_depth == 0:
            cls._stdout_saved = sys.stdout
            cls._stdout_sink = open(os.devnull, 'w')
            sys.stdout = cls._stdout_sink
        cls._stdout_depth += 1
        try:
            yield
        finally:
            cls._stdout_depth -= 1
            if cls._stdout_depth == 0:
                sys.stdout = cls._stdout_saved
                cls._stdout_sink.close()
                cls._stdout_saved = None
                cls._stdout_sink = None

    async def execute(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Execute the agent's logic for an inbound A2A message.

        Publishes a ``working`` status, runs the agent on the user text,
        then completes the task with a ``response`` artifact (or a
        placeholder message when there is no text). Failures are logged
        and reported as a ``failed`` status instead of being raised.
        """
        user_text = context.get_user_input()
        if not user_text and context.message:
            user_text = extract_text_from_a2a_message(context.message)

        task = context.current_task
        if not task:
            task = new_task(context.message)
            await event_queue.enqueue_event(task)

        updater = TaskUpdater(event_queue, task.id, task.context_id)

        try:
            await updater.update_status(TaskState.working)

            entry = await self._store.get_or_create(task.id)
            entry.is_running = True

            try:
                # Keep stdout suppressed while the result is consumed:
                # a streaming run does its work during iteration, not
                # when ``run`` returns.
                with self._suppress_stdout():
                    result = await entry.agent.run(user_text, stream=True)

                    if hasattr(result, '__aiter__'):
                        async for chunk in result:
                            entry.messages = chunk
                            if entry.cancelled:
                                await updater.cancel()
                                return
                    elif isinstance(result, list):
                        entry.messages = result

                response_text = ms_messages_to_text(entry.messages)
                if not response_text:
                    from .translator import collect_full_response
                    response_text = collect_full_response(entry.messages)

                if response_text:
                    await updater.add_artifact(
                        [Part(root=TextPart(text=response_text))],
                        name='response',
                    )
                    await updater.complete()
                else:
                    await updater.complete(
                        new_agent_text_message(
                            '(Agent completed with no text output)',
                            task.context_id,
                            task.id,
                        ))

            finally:
                # Always clear the flag so the entry becomes evictable.
                entry.is_running = False

        except Exception as e:
            logger.error(
                'A2A execute error for task %s: %s', task.id, e, exc_info=True)
            err_info = wrap_a2a_error(e)
            try:
                await updater.failed(
                    new_agent_text_message(
                        f'Error: {err_info["message"]}',
                        task.context_id,
                        task.id,
                    ))
            except Exception:
                logger.warning('Failed to send error status', exc_info=True)

    async def cancel(
        self,
        context: RequestContext,
        event_queue: EventQueue,
    ) -> None:
        """Cancel a running task.

        Flags the backing agent entry (if one exists) for cancellation
        and publishes a cancelled status for the task.
        """
        task_id = context.task_id
        updater = TaskUpdater(
            event_queue,
            task_id,
            context.context_id,
        )

        entry = self._store.get(task_id) if task_id else None
        if entry:
            entry.request_cancel()

        await updater.cancel()
        logger.info('A2A task %s cancel requested', task_id)

    async def cleanup(self) -> None:
        """Shut down all agent instances."""
        await self._store.close_all()
async def get_or_create(self, task_id: str) -> A2ATaskEntry:
    """Return the entry for ``task_id``, building a new agent if needed.

    An existing entry is touched and returned. Otherwise, when the store
    is full, the least-recently-used idle entry is evicted first; then
    the agent is built from ``self.config_path`` and registered.

    Raises:
        MaxTasksError: The store is full and every entry is running.
        ConfigError: ``config_path`` is unset or does not exist.
        AgentLoadError: Loading the config or building the agent failed.
    """
    existing = self._tasks.get(task_id)
    if existing is not None:
        existing.touch()
        return existing

    if len(self._tasks) >= self.max_tasks:
        victim_id = self._evict_lru()
        if victim_id is None:
            raise MaxTasksError(self.max_tasks)
        await self._cleanup_entry(victim_id)

    if not (self.config_path and os.path.exists(self.config_path)):
        raise ConfigError(f'Config not found: {self.config_path}')

    try:
        config = Config.from_task(self.config_path)
        agent = AgentLoader.build(
            config_dir_or_id=self.config_path,
            config=config,
            trust_remote_code=self.trust_remote_code,
        )
    except Exception as e:
        raise AgentLoadError(str(e)) from e

    timestamp = monotonic()
    new_entry = A2ATaskEntry(
        task_id=task_id,
        agent=agent,
        config=config,
        config_path=self.config_path,
        created_at=timestamp,
        last_activity=timestamp,
    )
    self._tasks[task_id] = new_entry
    # The periodic eviction task is started lazily with the first entry.
    self._ensure_cleanup_running()
    logger.info('A2A agent created for task: %s (config=%s)', task_id,
                self.config_path)
    return new_entry
def extract_text_from_a2a_message(message: Any) -> str:
    """Extract concatenated text from an A2A ``Message`` object.

    Handles text, file and data parts within ``message.parts``. Parts are
    recognised by their ``kind`` (or legacy ``type``) discriminator, or
    by the presence of a ``text`` / ``file`` / ``data`` attribute.

    Args:
        message: An A2A message-like object, or ``None``.

    Returns:
        One line per part joined with newlines. File parts are rendered
        as a ``[File: ...]`` placeholder and data parts as JSON. If the
        message has no parts, ``str(message)`` (or ``''`` when falsy).
    """
    if message is None:
        return ''

    parts = getattr(message, 'parts', None)
    if not parts:
        return str(message) if message else ''

    text_parts: list[str] = []
    for part in parts:
        # SDK parts may be wrapped in a ``root`` discriminated union.
        part_obj = getattr(part, 'root', part)

        kind = (getattr(part_obj, 'kind', None)
                or getattr(part_obj, 'type', None))
        if kind == 'text' or hasattr(part_obj, 'text'):
            text = getattr(part_obj, 'text', str(part_obj))
            # A ``None`` text would make the final join raise TypeError.
            if text is not None:
                text_parts.append(str(text))
        elif kind == 'file' or hasattr(part_obj, 'file'):
            file_obj = getattr(part_obj, 'file', None)
            if file_obj:
                name = getattr(file_obj, 'name', None) or 'unnamed'
                # NOTE(review): the attribute is presumably ``mime_type``
                # on snake_case SDK models and ``mimeType`` on older
                # ones; both are tried -- confirm against the pinned SDK.
                mime = (getattr(file_obj, 'mime_type', None)
                        or getattr(file_obj, 'mimeType', None) or '')
                if hasattr(file_obj, 'bytes'):
                    text_parts.append(
                        f'[File: {name} ({mime}), binary content]')
                elif hasattr(file_obj, 'uri'):
                    text_parts.append(
                        f'[File: {name} ({mime}), uri={file_obj.uri}]')
        elif kind == 'data' or hasattr(part_obj, 'data'):
            data = getattr(part_obj, 'data', part_obj)
            try:
                text_parts.append(json.dumps(data, default=str))
            except (TypeError, ValueError):
                text_parts.append(str(data))
        else:
            text_parts.append(str(part_obj))

    return '\n'.join(text_parts)
def collect_full_response(messages: 'List[Message]') -> str:
    """Collect all assistant text across the full message history.

    Useful for assembling a complete response from multi-step agent runs
    where the LLM interleaves tool calls with text fragments.

    String content is used as-is. List content is flattened the same way
    ``ms_messages_to_text`` does (``text`` of dict blocks, ``str()`` of
    anything else, joined by newlines) instead of being dropped.

    Args:
        messages: Message history; ``None`` or empty yields ``''``.

    Returns:
        Non-blank assistant texts joined by blank lines.
    """
    parts: list[str] = []
    for msg in (messages or []):
        if msg.role != 'assistant':
            continue
        content = msg.content
        if isinstance(content, list):
            content = '\n'.join(
                str(item.get('text', str(item)))
                if isinstance(item, dict) else str(item)
                for item in content)
        if isinstance(content, str) and content.strip():
            parts.append(content)
    return '\n\n'.join(parts)
class _CollectorClient(Client):
    """Minimal ACP client that accumulates streamed text from an agent.

    Text from ``agent_message_chunk`` updates is stored per session id.
    Permission requests are resolved by the configured *policy*: with
    ``'auto_approve'`` the first option whose kind contains ``'allow'``
    is selected; in every other case the request is cancelled.
    """

    def __init__(self, permission_policy: str = 'auto_approve'):
        # session id -> ordered list of received text chunks
        self.collected: Dict[str, List[str]] = {}
        self.permission_policy = permission_policy

    async def session_update(self, session_id: str, update: Any,
                             **kwargs: Any) -> None:
        """Record the text of an ``agent_message_chunk`` update.

        Other update types are ignored.
        """
        update_type = getattr(update, 'session_update', None)
        if update_type != 'agent_message_chunk':
            return
        content = getattr(update, 'content', None)
        if content is None:
            return
        text = getattr(content, 'text', None)
        # Fall back to ``str(content)`` only when there is no text at all:
        # an empty-string chunk must stay empty rather than turn into the
        # repr of the content object.
        chunk = str(content) if text is None else str(text)
        self.collected.setdefault(session_id, []).append(chunk)

    async def request_permission(self, options: list, session_id: str,
                                 tool_call: Any, **kwargs: Any):
        """Answer a permission request according to the policy.

        Returns:
            A ``RequestPermissionResponse`` selecting an allow option
            under ``'auto_approve'`` (when one is offered), otherwise a
            cancelled outcome.
        """
        from acp.schema import (RequestPermissionResponse, AllowedOutcome,
                                DeniedOutcome)
        if self.permission_policy == 'auto_approve':
            allow = next(
                (o for o in options
                 if 'allow' in (getattr(o, 'kind', '') or '')),
                None,
            )
            if allow:
                option_id = getattr(allow, 'option_id', 'allow_once')
                return RequestPermissionResponse(
                    outcome=AllowedOutcome(
                        outcome='selected',
                        option_id=option_id,
                    ))
        return RequestPermissionResponse(
            outcome=DeniedOutcome(outcome='cancelled'))

    def get_output(self, session_id: str) -> str:
        """Return all text collected for ``session_id`` (``''`` if none)."""
        parts = self.collected.get(session_id, [])
        return ''.join(parts)

    def clear(self, session_id: str) -> None:
        """Forget the text collected for ``session_id``."""
        self.collected.pop(session_id, None)
async def call_agent(
    self,
    agent_name: str,
    query: str,
    cwd: str = '/tmp',
) -> str:
    """Send a single-turn prompt to an external ACP agent and return
    the accumulated text response.

    A fresh agent process is spawned for the call and shut down
    afterwards.

    Args:
        agent_name: Key of the agent in the configured ``acp_agents``.
        query: Plain-text prompt to send.
        cwd: Working directory passed to the new ACP session.

    Returns:
        The streamed text, ``'(no output)'`` when nothing was streamed,
        or an ``'Error ...'`` string when the agent is not configured,
        has no command, or the call fails (failures are not raised).
    """
    cfg = self._config.get(agent_name)
    if cfg is None:
        return f'Error: ACP agent "{agent_name}" not configured'

    # Report a missing command like any other configuration problem
    # instead of letting a KeyError escape to the caller.
    command = cfg.get('command')
    if not command:
        return f'Error: ACP agent "{agent_name}" has no command configured'

    policy = cfg.get('permission_policy', 'auto_approve')
    args = cfg.get('args', []) or []
    client = _CollectorClient(permission_policy=policy)

    try:
        ctx = spawn_agent_process(client, command, *args)
        conn, _proc = await ctx.__aenter__()

        try:
            await conn.initialize(protocol_version=1)
            session = await conn.new_session(cwd=cwd, mcp_servers=[])
            sid = session.session_id

            await conn.prompt(
                session_id=sid,
                prompt=[text_block(query)],
            )
            return client.get_output(sid) or '(no output)'
        finally:
            # Shutdown is best-effort so a failing exit does not discard
            # the collected output, but leave a trace for debugging.
            try:
                await ctx.__aexit__(None, None, None)
            except Exception:
                logger.debug(
                    'Error shutting down ACP agent %s',
                    agent_name,
                    exc_info=True)
    except Exception as e:
        logger.error(
            'ACP call to %s failed: %s', agent_name, e, exc_info=True)
        return f'Error calling ACP agent "{agent_name}": {e}'
def build_config_options(
    config,
    available_models: list[str] | None = None,
) -> list | None:
    """Derive ACP ``configOptions`` from an ms-agent config object.

    Args:
        config: Agent config; the active model is read via ``_get_model_id``.
        available_models: Optional list of selectable model ids.  When
            omitted, the active model is the only choice.

    Returns:
        A one-element list holding the ``model`` selector, or ``None`` when
        the config carries no model id.
    """
    model_id = _get_model_id(config)
    if not model_id:
        return None

    models = list(available_models) if available_models else []
    # The selector's current value must be one of its own options, otherwise
    # a client has nothing to highlight as the active entry.
    if model_id not in models:
        models.insert(0, model_id)

    values = [SessionConfigSelectOption(value=m, name=m) for m in models]
    return [
        SessionConfigOptionSelect(
            type='select',
            id='model',
            name='LLM Model',
            category='model',
            current_value=model_id,
            options=values,
        ),
    ]
+ """ + from omegaconf import OmegaConf + + if config_id == 'model': + if hasattr(config, 'llm') and hasattr(config.llm, 'model'): + OmegaConf.update(config, 'llm.model', value, merge=True) + logger.info('Config option updated: llm.model = %s', value) + return True + return False + + +def _get_model_id(config) -> str | None: + """Extract the current model identifier from config.""" + if hasattr(config, 'llm') and hasattr(config.llm, 'model'): + return str(config.llm.model) + return None diff --git a/ms_agent/acp/errors.py b/ms_agent/acp/errors.py new file mode 100644 index 000000000..92b4dd0b8 --- /dev/null +++ b/ms_agent/acp/errors.py @@ -0,0 +1,77 @@ +from acp import RequestError + + +class ACPError(Exception): + """Base exception for ACP-specific errors in ms-agent.""" + + def __init__(self, code: int, message: str, data: dict | None = None): + self.code = code + self.message = message + self.data = data or {} + super().__init__(message) + + +class SessionNotFoundError(ACPError): + + def __init__(self, session_id: str): + super().__init__(-32001, 'Session not found', + {'sessionId': session_id}) + + +class ResourceNotFoundError(ACPError): + + def __init__(self, path: str): + super().__init__(-32002, 'Resource not found', {'path': path}) + + +class LLMError(ACPError): + + def __init__(self, detail: str): + super().__init__(-32003, 'LLM generation failed', {'detail': detail}) + + +class RateLimitError(ACPError): + + def __init__(self, detail: str = ''): + super().__init__(-32004, 'Rate limit exceeded', {'detail': detail}) + + +class ConfigError(ACPError): + + def __init__(self, detail: str): + super().__init__(-32005, 'Invalid configuration', {'detail': detail}) + + +class MaxSessionsError(ACPError): + + def __init__(self, max_sessions: int): + super().__init__(-32006, 'Maximum concurrent sessions reached', + {'max': max_sessions}) + + +# Map known ms-agent / Python exception types to ACP JSON-RPC errors. 
def wrap_agent_error(exc: Exception) -> RequestError:
    """Translate any exception into the ``acp.RequestError`` to raise.

    Resolution order:
      1. ``ACPError`` -> its own code, message and data.
      2. ``RequestError`` -> returned unchanged.
      3. First matching entry of ``_EXCEPTION_MAP`` -> mapped code/message.
      4. Anything else -> ``-32603`` "Internal error".

    For cases 3 and 4 the stringified exception is attached as ``detail``.
    """
    if isinstance(exc, ACPError):
        return RequestError(exc.code, exc.message, exc.data)
    if isinstance(exc, RequestError):
        return exc

    fallback = (-32603, 'Internal error')
    code, message = next(
        ((mapped_code, mapped_msg)
         for known_type, mapped_code, mapped_msg in _EXCEPTION_MAP
         if isinstance(exc, known_type)),
        fallback,
    )
    return RequestError(code, message, {'detail': str(exc)})
class _PermissionOutcome(dict):
    """Permission outcome readable both as a dict and through attributes.

    ``request_tool_permission`` inspects outcomes with
    ``getattr(outcome, 'outcome', ...)`` / ``getattr(outcome, 'option_id',
    ...)``.  A plain dict answers neither, so an approval was read as a
    denial.  Dict access is kept so existing key-based readers still work.
    """

    def __getattr__(self, name):
        try:
            return self[name]
        except KeyError:
            raise AttributeError(name) from None


class _DummyConn:
    """Minimal stand-in for the SDK ``ClientSideConnection``.

    When the server runs over stdio the SDK provides a real connection
    object.  Over HTTP, ``session_update`` calls are captured in a
    per-session ``asyncio.Queue`` so they can be streamed back as SSE events.
    """
    server = None

    def __init__(self):
        self._queues: dict[str, asyncio.Queue] = {}

    def get_queue(self, session_id: str) -> asyncio.Queue:
        """Return the update queue for ``session_id``, creating it on demand."""
        if session_id not in self._queues:
            self._queues[session_id] = asyncio.Queue()
        return self._queues[session_id]

    async def session_update(self, session_id: str, update: Any,
                             **kwargs) -> None:
        """Queue one update, dumped by alias when it has ``model_dump``."""
        q = self.get_queue(session_id)
        data = update.model_dump(
            by_alias=True) if hasattr(update, 'model_dump') else update
        await q.put(data)

    async def request_permission(self, session_id: str, tool_call: Any,
                                 options: list, **kwargs) -> Any:
        """Auto-answer a permission request without asking anyone.

        Picks the first option whose ``kind`` contains ``'allow'``.  Returns
        an object whose ``outcome`` is ``selected`` with that option's id,
        exposed as both ``id`` and ``option_id``.  If no such option exists
        the outcome is ``cancelled``.
        """
        from types import SimpleNamespace

        allow = next(
            (o for o in options if 'allow' in (getattr(o, 'kind', '') or '')),
            None,
        )
        if allow is None:
            return SimpleNamespace(
                outcome=_PermissionOutcome(outcome='cancelled'))

        option_id = getattr(allow, 'option_id', 'allow_once')
        return SimpleNamespace(
            outcome=_PermissionOutcome(
                outcome='selected',
                id=option_id,
                option_id=option_id,
            ))
HTTPException(403, 'Invalid API key') + + +@router.post('/rpc') +async def rpc_endpoint( + req: RPCRequest, + _auth: None = Depends(_check_api_key), +): + """Single JSON-RPC endpoint for all ACP methods. + + For ``session/prompt`` the response is SSE; for everything else + it is a regular JSON response. + """ + if _server is None: + raise HTTPException(503, 'ACP server not initialised') + + method = req.method + params = req.params + rpc_id = req.id + + try: + if method == 'initialize': + result = await _server.initialize( + protocol_version=params.get('protocolVersion', 1), ) + return _jsonrpc_ok(rpc_id, result) + + elif method == 'session/new': + result = await _server.new_session( + cwd=params.get('cwd', '/tmp'), + mcp_servers=params.get('mcpServers', []), + ) + return _jsonrpc_ok(rpc_id, result) + + elif method == 'session/list': + result = await _server.list_sessions( + cursor=params.get('cursor'), + cwd=params.get('cwd'), + ) + return _jsonrpc_ok(rpc_id, result) + + elif method == 'session/load': + result = await _server.load_session( + cwd=params.get('cwd', '/tmp'), + session_id=params.get('sessionId', ''), + ) + return _jsonrpc_ok(rpc_id, result) + + elif method == 'session/prompt': + return await _handle_prompt_sse(rpc_id, params) + + elif method == 'session/cancel': + await _server.cancel(session_id=params.get('sessionId', '')) + return JSONResponse({ + 'jsonrpc': '2.0', + 'id': rpc_id, + 'result': None + }) + + elif method == 'session/setConfigOption': + result = await _server.set_config_option( + config_id=params.get('configId', ''), + session_id=params.get('sessionId', ''), + value=params.get('value', ''), + ) + return _jsonrpc_ok(rpc_id, result) + + else: + return JSONResponse( + { + 'jsonrpc': '2.0', + 'id': rpc_id, + 'error': { + 'code': -32601, + 'message': f'Method not found: {method}' + } + }, + status_code=200, + ) + + except Exception as e: + from ms_agent.acp.errors import wrap_agent_error, ACPError + rpc_err = wrap_agent_error(e) + return 
async def _handle_prompt_sse(rpc_id, params):
    """Run a prompt and stream updates as SSE events.

    Prompt entries of the form ``{'type': 'text', 'text': ...}`` become ACP
    text blocks; any other entry is stringified.  Each queued session update
    is emitted as one ``data:`` event.  The last event is the JSON-RPC
    result, or a JSON-RPC error if the prompt raised.
    """
    from acp import text_block as tb
    session_id = params.get('sessionId', '')
    prompt_blocks = params.get('prompt', [])

    acp_blocks = []
    for b in prompt_blocks:
        if isinstance(b, dict) and b.get('type') == 'text':
            acp_blocks.append(tb(b['text']))
        else:
            acp_blocks.append(tb(str(b)))

    conn = _server.connection
    q = conn.get_queue(session_id)

    def sse(payload) -> str:
        return f'data: {json.dumps(payload, default=str)}\n\n'

    async def event_stream():
        prompt_task = asyncio.create_task(
            _server.prompt(
                prompt=acp_blocks,
                session_id=session_id,
            ))
        try:
            # Poll with a short timeout so prompt completion is noticed even
            # when no update arrives.
            while not prompt_task.done():
                try:
                    update = await asyncio.wait_for(q.get(), timeout=0.5)
                except asyncio.TimeoutError:
                    continue
                yield sse(update)

            # Flush updates queued between the last poll and completion.
            while not q.empty():
                yield sse(q.get_nowait())

            result = prompt_task.result()
            final = result.model_dump(
                by_alias=True) if hasattr(result, 'model_dump') else result
            yield sse({'jsonrpc': '2.0', 'id': rpc_id, 'result': final})

        except Exception as e:
            from ms_agent.acp.errors import wrap_agent_error
            rpc_err = wrap_agent_error(e)
            # Same error shape as the non-streaming endpoint, incl. ``data``.
            yield sse({
                'jsonrpc': '2.0',
                'id': rpc_id,
                'error': {
                    'code': rpc_err.code,
                    'message': rpc_err.message,
                    'data': getattr(rpc_err, 'data', None)
                }
            })
        finally:
            # If the consumer goes away (generator closed or cancelled), do
            # not leave the prompt running with nobody reading its output.
            if not prompt_task.done():
                prompt_task.cancel()

    return StreamingResponse(
        event_stream(),
        media_type='text/event-stream',
        headers={
            'Cache-Control': 'no-cache',
            'Connection': 'keep-alive',
            'X-Accel-Buffering': 'no',
        },
    )
'2.0', 'id': rpc_id, 'result': data}) diff --git a/ms_agent/acp/permissions.py b/ms_agent/acp/permissions.py new file mode 100644 index 000000000..fd23e4fae --- /dev/null +++ b/ms_agent/acp/permissions.py @@ -0,0 +1,114 @@ +"""Fine-grained permission policies for ACP tool calls. + +Policies control how ms-agent handles ``request_permission`` when a tool +call is about to execute. + +Supported policies: + - ``auto_approve``: silently approve everything (dev/testing) + - ``always_ask``: always prompt the client for approval + - ``remember_choice``: ask once per tool name, remember the answer +""" + +from __future__ import annotations +from typing import Any, Dict, Optional + +from ms_agent.utils.logger import get_logger + +logger = get_logger() + + +class PermissionPolicy: + """Manages permission decisions for tool calls within a session.""" + + def __init__(self, policy: str = 'auto_approve'): + self.policy = policy + self._remembered: Dict[str, bool] = {} + + def should_ask(self, tool_name: str) -> bool: + """Return True if the client should be asked for this tool call.""" + if self.policy == 'auto_approve': + return False + if self.policy == 'remember_choice': + return tool_name not in self._remembered + return True # always_ask + + def auto_decision(self, tool_name: str) -> str | None: + """Return a pre-determined decision if available. + + Returns ``'allow_once'`` for auto-approve, the remembered outcome + for remember_choice, or ``None`` if the client must be asked. 
+ """ + if self.policy == 'auto_approve': + return 'allow_once' + if self.policy == 'remember_choice' and tool_name in self._remembered: + return 'allow_always' if self._remembered[ + tool_name] else 'deny_once' + return None + + def record_choice(self, tool_name: str, allowed: bool) -> None: + """Record a user's permission decision for future lookups.""" + if self.policy == 'remember_choice': + self._remembered[tool_name] = allowed + logger.info('Permission remembered for %s: %s', tool_name, + 'allowed' if allowed else 'denied') + + def reset(self) -> None: + """Clear all remembered decisions.""" + self._remembered.clear() + + +async def request_tool_permission( + connection, + session_id: str, + tool_call_id: str, + tool_name: str, + policy: PermissionPolicy, +) -> bool: + """Execute the permission flow for a tool call. + + Returns ``True`` if the tool should proceed, ``False`` if denied. + """ + decision = policy.auto_decision(tool_name) + if decision is not None: + return 'allow' in decision + + from acp.schema import PermissionOption, ToolCall + options = [ + PermissionOption( + option_id='allow_once', name='Allow', kind='allow_once'), + PermissionOption( + option_id='allow_always', name='Always allow', + kind='allow_always'), + PermissionOption( + option_id='deny_once', name='Deny', kind='reject_once'), + ] + tool_call = ToolCall( + tool_call_id=tool_call_id, + title=tool_name, + status='pending', + ) + + try: + result = await connection.request_permission( + session_id=session_id, + tool_call=tool_call, + options=options, + ) + outcome = getattr(result, 'outcome', None) + if outcome is None: + return True + + outcome_type = getattr(outcome, 'outcome', '') + if outcome_type == 'cancelled': + policy.record_choice(tool_name, False) + return False + + selected_id = getattr(outcome, 'option_id', '') + allowed = 'allow' in selected_id + policy.record_choice(tool_name, allowed) + return allowed + + except Exception: + logger.warning('Permission request failed for 
@dataclass
class ProxyConfig:
    """Top-level proxy configuration parsed from YAML."""
    max_sessions: int = 8
    session_timeout: int = 3600
    default_backend: str = ''
    backends: Dict[str, BackendConfig] = field(default_factory=dict)

    @classmethod
    def from_yaml(cls, path: str) -> 'ProxyConfig':
        """Load a proxy configuration file.

        The file is a mapping with optional ``proxy`` and ``backends``
        sections.  Backends that are not mappings or lack ``command`` are
        skipped with a warning.  If ``proxy.default_backend`` is unset, the
        first valid backend becomes the default.

        Raises:
            ConfigError: The path is empty or missing, the document is not a
                mapping, or one of the two sections is not a mapping.
        """
        if not path or not os.path.exists(path):
            raise ConfigError(f'Proxy config not found: {path}')
        with open(path, encoding='utf-8') as f:
            raw = yaml.safe_load(f)
        if not isinstance(raw, dict):
            raise ConfigError(f'Invalid proxy config format in {path}')

        # A key written without a value (``proxy:``) loads as None, which
        # ``dict.get(key, {})`` would hand back unchanged.
        proxy_section = raw.get('proxy') or {}
        backends_section = raw.get('backends') or {}
        if not isinstance(proxy_section, dict) or not isinstance(
                backends_section, dict):
            raise ConfigError(f'Invalid proxy config format in {path}')

        backends: Dict[str, BackendConfig] = {}
        for name, cfg in backends_section.items():
            if not isinstance(cfg, dict) or 'command' not in cfg:
                logger.warning('Skipping backend %s: missing "command"', name)
                continue
            # YAML turns bare scalars such as ``8080`` into ints; process
            # arguments and environment values must be strings.
            backends[name] = BackendConfig(
                name=name,
                command=cfg['command'],
                args=[str(a) for a in (cfg.get('args') or [])],
                description=cfg.get('description', f'ACP agent: {name}'),
                env={
                    str(k): str(v)
                    for k, v in (cfg.get('env') or {}).items()
                },
            )

        default = proxy_section.get('default_backend') or ''
        if not default and backends:
            default = next(iter(backends))
        elif default and default not in backends:
            logger.warning(
                'default_backend %s is not among the configured backends',
                default)

        return cls(
            max_sessions=proxy_section.get('max_sessions', 8),
            session_timeout=proxy_section.get('session_timeout', 3600),
            default_backend=default,
            backends=backends,
        )
async def prompt(
    self,
    prompt: list,
    session_id: str,
    **kwargs: Any,
) -> PromptResponse:
    """Forward one prompt turn to the backend bound to ``session_id``.

    The session is flagged as running for the duration of the call.  Any
    failure of the backend call is logged and re-raised through
    ``wrap_agent_error``.  An unknown ``session_id`` raises whatever
    ``session_store.get`` raises, unwrapped.
    """
    entry = self.session_store.get(session_id)
    entry.is_running = True
    try:
        return await entry.backend_conn.prompt(
            session_id=entry.backend_sid,
            prompt=prompt,
        )
    except Exception as e:
        logger.error('Proxy prompt error: %s', e, exc_info=True)
        raise wrap_agent_error(e) from e
    finally:
        entry.is_running = False
        # ``get`` only stamped the start of the turn.  Refresh on completion
        # so a long prompt does not leave the session looking idle to
        # LRU/TTL eviction.
        entry.touch()
async def set_config_option(
    self,
    config_id: str,
    session_id: str,
    value: str | bool,
    **kwargs: Any,
):
    """Change a session option.

    ``config_id == 'backend'`` replaces the session's backend: the old
    backend session is removed, a new one is spawned in the same ``cwd``,
    and it is re-registered under the client's ``session_id``.  Any other
    option is forwarded to the current backend.  If that call fails, the
    failure is logged and an empty option list is returned.

    Raises:
        ConfigError: ``value`` does not name a configured backend.
    """
    from acp.schema import SetSessionConfigOptionResponse

    if config_id == 'backend':
        new_backend = str(value)
        if new_backend not in self.config.backends:
            raise ConfigError(f'Unknown backend: {new_backend}')

        entry = self.session_store.get(session_id)
        old_cwd = entry.cwd

        await self.session_store.remove(session_id)
        new_entry = await self._spawn_backend_session(new_backend, old_cwd)

        # The spawn helper registered the entry under a freshly generated
        # id.  Drop that key before re-keying, otherwise the same entry
        # counts twice against ``max_sessions`` and is torn down twice.
        self.session_store._sessions.pop(new_entry.id, None)
        new_entry.id = session_id
        self.session_store._sessions[session_id] = new_entry
        # NOTE(review): the relay client created inside
        # ``_spawn_backend_session`` was given the generated id, so updates
        # from the new backend may still carry it -- confirm and pass the
        # client-facing id through if so.

        config_options = self._build_config_options(new_backend)
        return SetSessionConfigOptionResponse(
            config_options=config_options or [])

    entry = self.session_store.get(session_id)
    try:
        return await entry.backend_conn.set_config_option(
            config_id=config_id,
            session_id=entry.backend_sid,
            value=value,
        )
    except Exception:
        logger.warning('Backend set_config_option failed', exc_info=True)
        return SetSessionConfigOptionResponse(config_options=[])
_spawn_backend_session( + self, + backend_name: str, + cwd: str, + ): + """Spawn a backend agent process, initialize it, create a session, + and register everything in the proxy session store.""" + import uuid as _uuid + + bcfg = self.config.backends.get(backend_name) + if bcfg is None: + raise ConfigError(f'Backend not configured: {backend_name}') + + proxy_sid = f'pxy_{_uuid.uuid4().hex[:12]}' + + relay = _RelayClient(self.connection, proxy_sid) + + env = dict(os.environ) + env.update(bcfg.env) + + ctx = spawn_agent_process(relay, bcfg.command, *bcfg.args, env=env) + conn, proc = await ctx.__aenter__() + + try: + await conn.initialize(protocol_version=PROTOCOL_VERSION) + session_resp = await conn.new_session(cwd=cwd, mcp_servers=[]) + backend_sid = session_resp.session_id + except Exception: + try: + await ctx.__aexit__(None, None, None) + except Exception: + pass + raise + + entry = self.session_store.register( + backend_name=backend_name, + backend_sid=backend_sid, + backend_conn=conn, + backend_proc=proc, + ctx_manager=ctx, + cwd=cwd, + ) + old_id = entry.id + entry.id = proxy_sid + self.session_store._sessions.pop(old_id, None) + self.session_store._sessions[proxy_sid] = entry + + return entry + + def _build_config_options( + self, + current_backend: str, + ) -> list | None: + if len(self.config.backends) <= 1: + return None + + values = [ + SessionConfigSelectOption( + value=name, + name=f'{name}: {bcfg.description}', + ) for name, bcfg in self.config.backends.items() + ] + return [ + SessionConfigOptionSelect( + type='select', + id='backend', + name='Backend Agent', + category='model', + current_value=current_backend, + options=values, + ), + ] + + +def configure_proxy_logging(log_file: str | None = None) -> None: + """Set up logging so nothing leaks onto stdout (the ACP wire).""" + handler: logging.Handler + if log_file: + handler = logging.FileHandler(log_file) + else: + handler = logging.StreamHandler(sys.stderr) + + fmt = logging.Formatter( + 
'%(asctime)s [%(name)s] %(levelname)s: %(message)s') + handler.setFormatter(fmt) + + root = logging.getLogger() + root.handlers.clear() + root.addHandler(handler) + root.setLevel(logging.INFO) + + +def serve_proxy( + config_path: str, + log_file: str | None = None, +) -> None: + """Entry point: run the ACP proxy server over stdio.""" + configure_proxy_logging(log_file) + + config = ProxyConfig.from_yaml(config_path) + logger.info( + 'Proxy starting: %d backends configured [%s], default=%s', + len(config.backends), + ', '.join(config.backends.keys()), + config.default_backend, + ) + + proxy = MSAgentACPProxy(config) + + import asyncio + asyncio.run(_run_proxy(proxy)) + + +async def _run_proxy(proxy: MSAgentACPProxy) -> None: + try: + await run_agent(proxy) + finally: + await proxy._shutdown() diff --git a/ms_agent/acp/proxy_session.py b/ms_agent/acp/proxy_session.py new file mode 100644 index 000000000..5cbab431c --- /dev/null +++ b/ms_agent/acp/proxy_session.py @@ -0,0 +1,193 @@ +import asyncio +import uuid +from dataclasses import dataclass, field +from time import monotonic +from typing import Any, Dict, List, Optional + +from ms_agent.utils.logger import get_logger + +from .errors import MaxSessionsError, SessionNotFoundError + +logger = get_logger() + + +@dataclass +class ProxySessionEntry: + """A proxy session maps a client-facing session ID to a backend ACP + connection. 
class ProxySessionStore:
    """Manages proxy session lifecycle with LRU eviction and TTL cleanup.

    Parameters:
        max_sessions: Upper bound on concurrent proxy sessions.
        session_timeout: Seconds of inactivity before a session is eligible
            for eviction.
        cleanup_interval: Seconds between periodic cleanup sweeps.
    """

    def __init__(
        self,
        max_sessions: int = 8,
        session_timeout: int = 3600,
        cleanup_interval: int = 300,
    ):
        self.max_sessions = max_sessions
        self.session_timeout = session_timeout
        self.cleanup_interval = cleanup_interval
        self._sessions: Dict[str, 'ProxySessionEntry'] = {}
        self._cleanup_task: Optional[asyncio.Task] = None
        # Strong references to fire-and-forget teardown tasks.  The event
        # loop only keeps weak references, so an untracked task may be
        # garbage-collected before it finishes.
        self._background_tasks: set = set()

    def register(
        self,
        backend_name: str,
        backend_sid: str,
        backend_conn: Any,
        backend_proc: Any,
        ctx_manager: Any,
        cwd: str,
    ) -> 'ProxySessionEntry':
        """Register a newly established backend connection as a session.

        When the store is full, the least recently used idle session is
        evicted first; its teardown runs in the background.

        Raises:
            MaxSessionsError: The store is full and every session is running.
        """
        if len(self._sessions) >= self.max_sessions:
            evicted_id = self._evict_lru()
            if evicted_id is None:
                raise MaxSessionsError(self.max_sessions)
            self._force_remove(evicted_id)

        now = monotonic()
        session_id = f'pxy_{uuid.uuid4().hex[:12]}'
        entry = ProxySessionEntry(
            id=session_id,
            backend_name=backend_name,
            backend_sid=backend_sid,
            backend_conn=backend_conn,
            backend_proc=backend_proc,
            ctx_manager=ctx_manager,
            cwd=cwd,
            created_at=now,
            last_activity=now,
        )
        self._sessions[session_id] = entry
        self._ensure_cleanup_running()
        logger.info(
            'Proxy session created: %s -> backend %s (sid=%s)',
            session_id,
            backend_name,
            backend_sid,
        )
        return entry

    def get(self, session_id: str) -> 'ProxySessionEntry':
        """Return the entry for ``session_id`` and refresh its activity time.

        Raises:
            SessionNotFoundError: No such session is registered.
        """
        try:
            entry = self._sessions[session_id]
        except KeyError:
            raise SessionNotFoundError(session_id) from None
        entry.touch()
        return entry

    def list_sessions(self) -> List[Dict[str, Any]]:
        """Return a summary dict for every registered session."""
        return [{
            'session_id': e.id,
            'backend': e.backend_name,
            'cwd': e.cwd,
            'is_running': e.is_running,
        } for e in self._sessions.values()]

    async def remove(self, session_id: str) -> None:
        """Remove one session and wait for its backend teardown."""
        await self._cleanup_session(session_id)

    async def close_all(self) -> None:
        """Tear down every session and stop the periodic cleanup task."""
        for sid in list(self._sessions):
            await self._cleanup_session(sid)
        # Wait for teardowns that evictions started in the background.
        if self._background_tasks:
            await asyncio.gather(
                *list(self._background_tasks), return_exceptions=True)
        if self._cleanup_task and not self._cleanup_task.done():
            self._cleanup_task.cancel()
            try:
                await self._cleanup_task
            except asyncio.CancelledError:
                pass

    def _evict_lru(self) -> Optional[str]:
        """Return the id of the least recently active idle session, if any."""
        idle = [(sid, s) for sid, s in self._sessions.items()
                if not s.is_running]
        if not idle:
            return None
        return min(idle, key=lambda x: x[1].last_activity)[0]

    def _force_remove(self, session_id: str) -> None:
        """Synchronous best-effort removal; teardown runs in the background."""
        entry = self._sessions.pop(session_id, None)
        if entry is None:
            return
        if entry.ctx_manager is not None:
            try:
                loop = asyncio.get_running_loop()
            except RuntimeError:
                # Nothing can await the teardown here.  Do not even create
                # the coroutine, to avoid a "never awaited" warning.
                logger.warning(
                    'No running event loop; backend of proxy session %s '
                    'was not shut down', session_id)
            else:
                task = loop.create_task(
                    self._teardown(session_id, entry.ctx_manager))
                self._background_tasks.add(task)
                task.add_done_callback(self._background_tasks.discard)
        logger.info('Proxy session force-removed: %s', session_id)

    async def _teardown(self, session_id: str, ctx_manager: Any) -> None:
        """Exit a backend context manager, logging instead of raising."""
        try:
            await ctx_manager.__aexit__(None, None, None)
        except Exception:
            logger.warning(
                'Error cleaning up proxy session %s',
                session_id,
                exc_info=True)

    async def _cleanup_session(self, session_id: str) -> None:
        """Remove a session and await its backend teardown."""
        entry = self._sessions.pop(session_id, None)
        if entry is None:
            return
        if entry.ctx_manager is not None:
            await self._teardown(session_id, entry.ctx_manager)
        logger.info('Proxy session removed: %s', session_id)

    def _ensure_cleanup_running(self) -> None:
        """Start the periodic sweep if a loop is running and none is active."""
        try:
            loop = asyncio.get_running_loop()
        except RuntimeError:
            return
        if self._cleanup_task is None or self._cleanup_task.done():
            self._cleanup_task = loop.create_task(self._periodic_cleanup())

    async def _periodic_cleanup(self) -> None:
        """Every ``cleanup_interval`` seconds, drop idle timed-out sessions."""
        while True:
            await asyncio.sleep(self.cleanup_interval)
            now = monotonic()
            expired = [
                sid for sid, s in self._sessions.items()
                if (now - s.last_activity > self.session_timeout
                    and not s.is_running)
            ]
            for sid in expired:
                logger.info('Evicting timed-out proxy session %s', sid)
                await self._cleanup_session(sid)
+ + The manifest follows the ACP Agent Registry specification so that + tools like Zed's agent picker can auto-discover ms-agent. + """ + exe = 'ms-agent' + args = ['acp'] + if config_path: + args.extend(['--config', config_path]) + + manifest: Dict[str, Any] = { + 'name': 'ms-agent', + 'title': title, + 'version': version, + 'description': description, + 'protocol': 'acp', + 'protocolVersion': 1, + 'transport': { + 'type': 'stdio', + 'command': exe, + 'args': args, + }, + 'capabilities': { + 'loadSession': False, + 'promptCapabilities': { + 'image': False, + 'audio': False, + 'embeddedContext': True, + }, + 'sessionCapabilities': { + 'list': {}, + }, + }, + } + + if output_path: + abs_path = os.path.abspath(output_path) + with open(abs_path, 'w') as f: + json.dump(manifest, f, indent=2) + logger.info('Agent manifest written to %s', abs_path) + + return manifest diff --git a/ms_agent/acp/server.py b/ms_agent/acp/server.py new file mode 100644 index 000000000..15ac7b9ee --- /dev/null +++ b/ms_agent/acp/server.py @@ -0,0 +1,398 @@ +import io +import logging +import os +import sys +from contextlib import contextmanager +from typing import Any + +import json +from acp import (PROTOCOL_VERSION, Agent, InitializeResponse, + NewSessionResponse, PromptResponse, run_agent, text_block) +from acp.schema import (AgentCapabilities, ClientCapabilities, Implementation, + PermissionOption, PromptCapabilities, + SessionCapabilities, SessionListCapabilities) +from ms_agent.utils.logger import get_logger + +from .config import (apply_config_option, build_config_options, + build_session_modes) +from .errors import wrap_agent_error +from .session_store import ACPSessionStore +from .translator import ACPTranslator + +logger = get_logger() + +SUPPORTED_PROTOCOL_VERSION: int = PROTOCOL_VERSION + +_VERSION = '0.1.0' + + +def configure_acp_logging(log_file: str | None = None) -> None: + """Set up logging so nothing leaks onto stdout (the ACP wire). 
+ + By default logs go to *stderr*; pass ``log_file`` to write to disk + instead. + """ + handler: logging.Handler + if log_file: + handler = logging.FileHandler(log_file) + else: + handler = logging.StreamHandler(sys.stderr) + + fmt = logging.Formatter( + '%(asctime)s [%(name)s] %(levelname)s: %(message)s') + handler.setFormatter(fmt) + + root = logging.getLogger() + root.handlers.clear() + root.addHandler(handler) + root.setLevel(logging.INFO) + + +class MSAgentACPServer(Agent): + """ACP Server that wraps ms-agent's ``LLMAgent`` (or any project agent).""" + + def __init__( + self, + config_path: str, + trust_remote_code: bool = False, + max_sessions: int = 8, + session_timeout: int = 3600, + ) -> None: + self.config_path = config_path + self.trust_remote_code = trust_remote_code + self.session_store = ACPSessionStore( + max_sessions=max_sessions, + session_timeout=session_timeout, + ) + self._translators: dict[str, ACPTranslator] = {} + + def _get_translator(self, session_id: str) -> ACPTranslator: + if session_id not in self._translators: + self._translators[session_id] = ACPTranslator() + return self._translators[session_id] + + @staticmethod + @contextmanager + def _suppress_stdout(): + """Redirect stdout to devnull while running agent logic. + + LLMAgent.step() writes streaming tokens to sys.stdout, which + would corrupt the ACP JSON-RPC wire when running over stdio. 
+ """ + real_stdout = sys.stdout + sys.stdout = open(os.devnull, 'w') + try: + yield + finally: + sys.stdout.close() + sys.stdout = real_stdout + + async def initialize( + self, + protocol_version: int, + client_capabilities: ClientCapabilities | None = None, + client_info: Implementation | None = None, + **kwargs: Any, + ) -> InitializeResponse: + negotiated = min(protocol_version, SUPPORTED_PROTOCOL_VERSION) + logger.info( + 'ACP initialize: client=%s negotiated_version=%d', + client_info.name if client_info else '', + negotiated, + ) + return InitializeResponse( + protocol_version=negotiated, + agent_capabilities=AgentCapabilities( + load_session=True, + prompt_capabilities=PromptCapabilities( + image=False, + audio=False, + embedded_context=True, + ), + session_capabilities=SessionCapabilities( + list=SessionListCapabilities(), ), + ), + agent_info=Implementation( + name='ms-agent', + title='MS-Agent', + version=_VERSION, + ), + auth_methods=[], + ) + + async def new_session( + self, + cwd: str, + mcp_servers: list | None = None, + **kwargs: Any, + ) -> NewSessionResponse: + meta = kwargs.get('_meta') or kwargs.get('field_meta') + try: + session = await self.session_store.create( + config_path=self.config_path, + cwd=cwd, + trust_remote_code=self.trust_remote_code, + mcp_servers=mcp_servers, + meta=meta, + ) + except Exception as e: + logger.error('new_session failed: %s', e, exc_info=True) + raise + config_options = build_config_options(session.config) + modes = build_session_modes() + return NewSessionResponse( + session_id=session.id, + config_options=config_options, + modes=modes, + ) + + async def prompt( + self, + prompt: list, + session_id: str, + **kwargs: Any, + ) -> PromptResponse: + session = self.session_store.get(session_id) + translator = self._get_translator(session_id) + + is_first_turn = len(session.messages) == 0 + prior_msg_count = len(session.messages) + + translator.reset_turn(prior_msg_count) + + if is_first_turn: + user_text = 
translator.prompt_to_messages(prompt)[0].content + run_input = user_text + else: + run_input = translator.prompt_to_messages(prompt, session.messages) + + session.is_running = True + session._cancel_event.clear() + + try: + with self._suppress_stdout(): + result = await session.agent.run(run_input, stream=True) + if hasattr(result, '__aiter__'): + async for chunk in result: + session.messages = chunk + updates = translator.messages_to_updates(chunk) + for update in updates: + try: + await self.connection.session_update( + session_id, update) + except Exception as send_err: + logger.warning( + 'Failed to send session_update: %s', + send_err) + if session.cancelled: + break + elif isinstance(result, list): + session.messages = result + updates = translator.messages_to_updates(result) + for update in updates: + try: + await self.connection.session_update( + session_id, update) + except Exception as send_err: + logger.warning('Failed to send session_update: %s', + send_err) + + plan_updates = self._extract_plan_updates(session, translator) + for pu in plan_updates: + try: + await self.connection.session_update(session_id, pu) + except Exception: + pass + + stop = translator.map_stop_reason(session) + return PromptResponse(stop_reason=stop) + except Exception as e: + logger.error('Error during prompt: %s', e, exc_info=True) + raise wrap_agent_error(e) + finally: + session.is_running = False + + @staticmethod + def _extract_plan_updates(session, translator) -> list: + """Extract plan updates from agent state (todo tool, etc.).""" + agent = session.agent + steps = [] + + if hasattr(agent, 'runtime') and agent.runtime: + todo_items = getattr(agent.runtime, 'todo_items', None) + if todo_items and isinstance(todo_items, list): + for item in todo_items: + if isinstance(item, dict): + steps.append({ + 'description': + item.get('description', item.get('content', '')), + 'status': + item.get('status', 'pending'), + 'priority': + item.get('priority', 'medium'), + }) + + if not 
steps: + for msg in reversed(session.messages or []): + if (msg.role == 'tool' and msg.name + in ('todo_write', 'todo_read', 'todo', 'split_task') + and msg.content): + try: + data = json.loads(msg.content) + todos = None + if isinstance(data, dict): + todos = data.get('todos', None) + if isinstance(data, list): + todos = data + if todos and isinstance(todos, list): + for item in todos: + if isinstance(item, dict): + steps.append({ + 'description': + item.get( + 'content', + item.get('description', + item.get('task', ''))), + 'status': + item.get('status', 'pending'), + 'priority': + item.get('priority', 'medium'), + }) + break + except (json.JSONDecodeError, TypeError): + pass + + if steps: + return [translator.build_plan_update(steps)] + return [] + + async def cancel(self, session_id: str, **kwargs: Any) -> None: + try: + session = self.session_store.get(session_id) + session.request_cancel() + logger.info('Session %s cancel requested', session_id) + except Exception: + logger.warning('Cancel for unknown session %s', session_id) + + async def load_session( + self, + cwd: str, + session_id: str, + mcp_servers: list | None = None, + **kwargs: Any, + ): + from acp.schema import LoadSessionResponse + try: + session = self.session_store.get(session_id) + except Exception: + logger.info('load_session: session %s not found, creating new', + session_id) + meta = kwargs.get('_meta') or kwargs.get('field_meta') + session = await self.session_store.create( + config_path=self.config_path, + cwd=cwd, + trust_remote_code=self.trust_remote_code, + mcp_servers=mcp_servers, + meta=meta, + ) + + translator = self._get_translator(session.id) + translator.reset_turn() + for i, msg in enumerate(session.messages): + partial = session.messages[:i + 1] + updates = translator.messages_to_updates(partial) + for update in updates: + try: + await self.connection.session_update(session.id, update) + except Exception: + pass + + config_options = build_config_options(session.config) + modes = 
build_session_modes() + return LoadSessionResponse( + session_id=session.id, + config_options=config_options, + modes=modes, + ) + + async def list_sessions( + self, + cursor: str | None = None, + cwd: str | None = None, + **kwargs: Any, + ): + from acp.schema import ListSessionsResponse, SessionInfo + entries = self.session_store.list_sessions() + items = [] + for e in entries: + items.append( + SessionInfo( + session_id=e['session_id'], + cwd=e.get('cwd'), + )) + return ListSessionsResponse(sessions=items) + + async def set_config_option( + self, + config_id: str, + session_id: str, + value: str | bool, + **kwargs: Any, + ): + from acp.schema import SetSessionConfigOptionResponse + session = self.session_store.get(session_id) + apply_config_option(session.config, config_id, str(value)) + new_options = build_config_options(session.config) or [] + return SetSessionConfigOptionResponse(config_options=new_options) + + async def set_session_mode( + self, + mode_id: str, + session_id: str, + **kwargs: Any, + ): + from acp.schema import SetSessionModeResponse, CurrentModeUpdate + # _session = self.session_store.get(session_id) + await self.connection.session_update( + session_id, + CurrentModeUpdate( + session_update='current_mode_update', + current_mode_id=mode_id, + ), + ) + return SetSessionModeResponse() + + def on_connect(self, conn) -> None: + self.connection = conn + + async def _shutdown(self) -> None: + await self.session_store.close_all() + + +def serve( + config_path: str, + trust_remote_code: bool = False, + max_sessions: int = 8, + session_timeout: int = 3600, + log_file: str | None = None, +) -> None: + """Entry point: run the ACP server over stdio.""" + configure_acp_logging(log_file) + logger.info( + 'serve() called: config_path=%s trust_remote_code=%s ' + 'sys.argv=%s', config_path, trust_remote_code, sys.argv) + server = MSAgentACPServer( + config_path=config_path, + trust_remote_code=trust_remote_code, + max_sessions=max_sessions, + 
session_timeout=session_timeout, + ) + import asyncio + asyncio.run(_run_server(server)) + + +async def _run_server(server: MSAgentACPServer) -> None: + try: + await run_agent(server) + finally: + await server._shutdown() diff --git a/ms_agent/acp/session_store.py b/ms_agent/acp/session_store.py new file mode 100644 index 000000000..f37ff1dd5 --- /dev/null +++ b/ms_agent/acp/session_store.py @@ -0,0 +1,183 @@ +import asyncio +import os +import uuid +from dataclasses import dataclass, field +from time import monotonic +from typing import Any, Dict, List, Optional + +from ms_agent.agent.base import Agent +from ms_agent.agent.loader import AgentLoader +from ms_agent.config.config import Config +from ms_agent.config.env import Env +from ms_agent.llm.utils import Message +from ms_agent.utils.logger import get_logger +from omegaconf import DictConfig, OmegaConf + +from .errors import ConfigError, MaxSessionsError, SessionNotFoundError + +logger = get_logger() + + +@dataclass +class ACPSessionEntry: + """In-memory representation of a single ACP session.""" + + id: str + agent: Agent + config: DictConfig + config_path: str + cwd: str + created_at: float + last_activity: float + messages: List[Message] = field(default_factory=list) + is_running: bool = False + _cancel_event: asyncio.Event = field(default_factory=asyncio.Event) + + def touch(self) -> None: + self.last_activity = monotonic() + + def request_cancel(self) -> None: + self._cancel_event.set() + if self.agent.runtime is not None: + self.agent.runtime.should_stop = True + + @property + def cancelled(self) -> bool: + return self._cancel_event.is_set() + + +class ACPSessionStore: + """Manages ACP session lifecycle with concurrency and timeout controls. + + Parameters: + max_sessions: Upper bound on concurrent in-memory sessions. + session_timeout: Seconds of inactivity before a session becomes + eligible for eviction. + cleanup_interval: Seconds between periodic cleanup sweeps. 
+ """ + + def __init__( + self, + max_sessions: int = 8, + session_timeout: int = 3600, + cleanup_interval: int = 300, + ): + self.max_sessions = max_sessions + self.session_timeout = session_timeout + self.cleanup_interval = cleanup_interval + self._sessions: Dict[str, ACPSessionEntry] = {} + self._cleanup_task: Optional[asyncio.Task] = None + + async def create( + self, + config_path: str, + cwd: str, + trust_remote_code: bool = False, + mcp_servers: list | None = None, + meta: dict | None = None, + ) -> ACPSessionEntry: + """Create a new session backed by a freshly loaded agent.""" + if len(self._sessions) >= self.max_sessions: + evicted_id = self._evict_lru() + if evicted_id is None: + raise MaxSessionsError(self.max_sessions) + await self._cleanup_session(evicted_id) + + if not config_path or not os.path.exists(config_path): + raise ConfigError(f'Config not found: {config_path}') + + config = Config.from_task(config_path) + agent = AgentLoader.build( + config_dir_or_id=config_path, + config=config, + trust_remote_code=trust_remote_code, + ) + + now = monotonic() + session_id = f'ses_{uuid.uuid4().hex[:12]}' + entry = ACPSessionEntry( + id=session_id, + agent=agent, + config=config, + config_path=config_path, + cwd=cwd, + created_at=now, + last_activity=now, + ) + self._sessions[session_id] = entry + self._ensure_cleanup_running() + logger.info('ACP session created: %s (config=%s)', session_id, + config_path) + return entry + + def get(self, session_id: str) -> ACPSessionEntry: + """Return a session entry or raise ``SessionNotFoundError``.""" + try: + entry = self._sessions[session_id] + except KeyError: + raise SessionNotFoundError(session_id) + entry.touch() + return entry + + def list_sessions(self) -> List[Dict[str, Any]]: + """Return metadata for every active session.""" + result = [] + for sid, entry in self._sessions.items(): + result.append({ + 'session_id': sid, + 'config_path': entry.config_path, + 'cwd': entry.cwd, + 'is_running': entry.is_running, + 
}) + return result + + async def remove(self, session_id: str) -> None: + await self._cleanup_session(session_id) + + async def close_all(self) -> None: + for sid in list(self._sessions): + await self._cleanup_session(sid) + if self._cleanup_task and not self._cleanup_task.done(): + self._cleanup_task.cancel() + try: + await self._cleanup_task + except asyncio.CancelledError: + pass + + def _evict_lru(self) -> Optional[str]: + """Pick the least-recently-used *idle* session for eviction.""" + idle = [(sid, s) for sid, s in self._sessions.items() + if not s.is_running] + if not idle: + return None + return min(idle, key=lambda x: x[1].last_activity)[0] + + async def _cleanup_session(self, session_id: str) -> None: + entry = self._sessions.pop(session_id, None) + if entry is None: + return + try: + if hasattr(entry.agent, 'cleanup_tools'): + await entry.agent.cleanup_tools() + except Exception: + logger.warning( + 'Error cleaning up session %s', session_id, exc_info=True) + logger.info('ACP session removed: %s', session_id) + + def _ensure_cleanup_running(self) -> None: + if self._cleanup_task is None or self._cleanup_task.done(): + self._cleanup_task = asyncio.create_task(self._periodic_cleanup()) + + async def _periodic_cleanup(self) -> None: + """Background loop that evicts timed-out sessions.""" + while True: + await asyncio.sleep(self.cleanup_interval) + now = monotonic() + expired = [ + sid for sid, s in self._sessions.items() + if (now - s.last_activity > self.session_timeout + and not s.is_running) + ] + for sid in expired: + logger.info('Evicting timed-out session %s', sid) + await self._cleanup_session(sid) diff --git a/ms_agent/acp/translator.py b/ms_agent/acp/translator.py new file mode 100644 index 000000000..7c9011b2a --- /dev/null +++ b/ms_agent/acp/translator.py @@ -0,0 +1,466 @@ +import uuid +from typing import Any, Dict, List, Optional + +import json +from acp import (plan_entry, start_edit_tool_call, start_read_tool_call, + start_tool_call, 
text_block, tool_content, tool_diff_content, + update_agent_message_text, update_agent_thought_text, + update_plan, update_tool_call) +from acp.schema import AgentPlanUpdate, ToolCallLocation +from ms_agent.llm.utils import Message +from ms_agent.utils.logger import get_logger + +logger = get_logger() + +_TOOL_KIND_MAP: Dict[str, str] = { + 'code_executor': 'execute', + 'local_code_executor': 'execute', + 'web_search': 'search', + 'arxiv_search': 'search', + 'exa_search': 'search', + 'serpapi_search': 'search', + 'google_search': 'search', + 'filesystem': 'read', + 'file_read': 'read', + 'file_write': 'edit', + 'todo': 'think', + 'evidence_store': 'think', + 'split_task': 'think', + 'browser': 'fetch', + 'web_browser': 'fetch', + 'mcp_client': 'other', +} + +_FILE_EDIT_METHODS = frozenset({ + 'write_file', + 'replace_file_contents', + 'file_operation', + 'create_file', + 'edit_file', + 'patch_file', + 'insert_lines', + 'delete_lines', +}) + +_FILE_READ_METHODS = frozenset({ + 'read_file', + 'list_files', + 'file_operation', + 'list_directory', + 'read_directory', +}) + +_EXEC_METHODS = frozenset({ + 'execute_code', + 'run_code', + 'execute', + 'run_command', + 'shell', + 'bash', +}) + + +def _parse_tool_args(arguments: str) -> dict: + """Best-effort parse tool call arguments JSON.""" + if not arguments: + return {} + try: + parsed = json.loads(arguments) + return parsed if isinstance(parsed, dict) else {} + except (json.JSONDecodeError, TypeError): + return {} + + +def _infer_kind_from_args(tool_name: str, args: dict) -> str: + """Infer ACP tool kind from tool name and parsed arguments.""" + if tool_name in _TOOL_KIND_MAP: + kind = _TOOL_KIND_MAP[tool_name] + if tool_name == 'filesystem': + method = args.get('method', '') or args.get('operation', '') + if method in _FILE_EDIT_METHODS: + return 'edit' + if method in _FILE_READ_METHODS: + return 'read' + return kind + + lower = tool_name.lower() + if any( + kw in lower + for kw in ('write', 'edit', 'create', 'patch', 
'replace')): + return 'edit' + if any(kw in lower for kw in ('read', 'list', 'get', 'show', 'cat')): + return 'read' + if any(kw in lower for kw in ('search', 'find', 'grep', 'query')): + return 'search' + if any( + kw in lower + for kw in ('exec', 'run', 'shell', 'code', 'bash', 'command')): + return 'execute' + if any(kw in lower for kw in ('delete', 'remove', 'rm')): + return 'delete' + if any(kw in lower for kw in ('move', 'rename', 'mv')): + return 'move' + if any( + kw in lower + for kw in ('fetch', 'download', 'browse', 'http', 'url')): + return 'fetch' + if any( + kw in lower + for kw in ('think', 'plan', 'reason', 'todo', 'evidence')): + return 'think' + return 'other' + + +def _build_title(tool_name: str, args: dict, kind: str) -> str: + """Build a human-readable title for the IDE's tool call UI.""" + path = ( + args.get('path') or args.get('file_path') or args.get('filename') + or '') + method = args.get('method') or args.get('operation') or '' + + if kind == 'edit': + if path: + return f'Edit {path}' + return f'{tool_name}: {method}' if method else tool_name + + if kind == 'read': + if path: + return f'Read {path}' + return f'{tool_name}: {method}' if method else tool_name + + if kind == 'execute': + code = args.get('code', '') or args.get('command', '') + if code: + preview = code.strip().split('\n')[0][:80] + return f'Run: {preview}' + return f'Execute {tool_name}' + + if kind == 'search': + query = args.get('query', '') or args.get('search_query', '') + if query: + return f'Search: {query[:60]}' + return f'Search ({tool_name})' + + if kind == 'think': + return f'Thinking: {tool_name}' + + if kind == 'fetch': + url = args.get('url', '') or args.get('uri', '') + if url: + return f'Fetch: {url[:60]}' + return f'Fetch ({tool_name})' + + if kind == 'delete': + if path: + return f'Delete {path}' + return f'Delete ({tool_name})' + + return tool_name + + +def _extract_locations(args: dict) -> list[ToolCallLocation] | None: + """Extract file locations from 
tool arguments for IDE follow-along.""" + path = ( + args.get('path') or args.get('file_path') or args.get('filename') + or '') + if not path: + return None + line = args.get('line') or args.get('line_number') + return [ + ToolCallLocation( + path=path, + line=int(line) if line is not None else None, + ) + ] + + +def _extract_file_path(args: dict) -> str: + """Extract file path from tool arguments.""" + return (args.get('path') or args.get('file_path') or args.get('filename') + or '') + + +def _try_parse_diff_from_result( + tool_name: str, + args: dict, + result_text: str, +) -> Optional[dict]: + """Attempt to extract diff info (path, old_text, new_text) from tool + arguments and result for file-edit operations. + + Returns a dict with path/old_text/new_text if applicable, else None. + """ + path = _extract_file_path(args) + if not path: + return None + + method = (args.get('method') or args.get('operation', '')).lower() + + if method == 'write' or tool_name == 'file_write': + new_text = args.get('content', '') + if new_text: + return {'path': path, 'old_text': None, 'new_text': new_text} + + if 'write_file' in tool_name or method == 'write_file': + new_text = args.get('content', '') + if new_text: + return {'path': path, 'old_text': None, 'new_text': new_text} + + if 'replace' in tool_name or 'replace' in method: + source = args.get('source', '') or args.get('old_text', '') + target = args.get('target', '') or args.get('new_text', '') + if source and target: + return {'path': path, 'old_text': source, 'new_text': target} + + return None + + +class ACPTranslator: + """Stateful translator: bidirectional mapping between ACP protocol schema + and ms-agent Message objects. + + Create one instance per session so delta tracking stays session-scoped. 
+ """ + + def __init__(self) -> None: + self._last_content_len: int = 0 + self._last_reasoning_len: int = 0 + self._emitted_tool_ids: set[str] = set() + self._completed_tool_ids: set[str] = set() + self._tool_args_cache: dict[str, dict] = {} + self._tool_name_cache: dict[str, str] = {} + self._last_seen_msg_count: int = 0 + + def reset_turn(self, prior_msg_count: int = 0) -> None: + """Reset per-turn delta tracking. Call at the start of each prompt. + + ``prior_msg_count`` is the number of messages that already existed + before this turn. Setting it correctly prevents the translator from + replaying old assistant content as new deltas in multi-turn sessions. + """ + self._last_content_len = 0 + self._last_reasoning_len = 0 + self._emitted_tool_ids.clear() + self._completed_tool_ids.clear() + self._tool_args_cache.clear() + self._tool_name_cache.clear() + self._last_seen_msg_count = prior_msg_count + + @staticmethod + def prompt_to_messages( + prompt: list, + existing_messages: List[Message] | None = None, + ) -> List[Message]: + """Convert an ACP prompt (list of ContentBlocks) to ms-agent Messages. + + If ``existing_messages`` is provided the new user message is appended; + otherwise a fresh list is returned. 
+ """ + parts: list[str] = [] + for block in prompt: + block_type = getattr(block, 'type', None) + if block_type == 'text': + parts.append(block.text) + elif block_type == 'resource': + res = block.resource + if hasattr(res, 'text'): + parts.append( + f'[Resource: {getattr(res, "uri", "")}]\n{res.text}') + elif hasattr(res, 'blob'): + parts.append( + f'[Binary resource: {getattr(res, "uri", "")}]') + elif block_type == 'resource_link': + uri = getattr(block, 'uri', '') + parts.append(f'[Resource link: {uri}]') + elif block_type == 'image': + parts.append('[Image content attached]') + else: + parts.append(str(block)) + + user_text = '\n'.join(parts) + user_msg = Message(role='user', content=user_text) + + if existing_messages is not None: + existing_messages.append(user_msg) + return existing_messages + return [user_msg] + + def messages_to_updates( + self, + messages: List[Message], + ) -> list: + """Diff the current message list against what was already sent + and return a list of ACP SessionUpdate objects for the new content. + + Processes ALL new messages since the last call, not just the last one. + This ensures tool results from parallel_tool_call are not missed. 
+ """ + updates: list = [] + if not messages: + return updates + + start = max(self._last_seen_msg_count, 0) + new_messages = messages[start:] + self._last_seen_msg_count = len(messages) + + if not new_messages: + last_msg = messages[-1] + if last_msg.role == 'assistant': + updates.extend(self._translate_assistant(last_msg)) + return updates + + for msg in new_messages: + if msg.role == 'assistant': + updates.extend(self._translate_assistant(msg)) + elif msg.role == 'tool': + updates.extend(self._translate_tool_result(msg)) + + return updates + + def _translate_assistant(self, msg: Message) -> list: + updates: list = [] + content = msg.content if isinstance(msg.content, str) else '' + reasoning = msg.reasoning_content or '' + + if reasoning and len(reasoning) > self._last_reasoning_len: + delta = reasoning[self._last_reasoning_len:] + self._last_reasoning_len = len(reasoning) + updates.append(update_agent_thought_text(delta)) + + if content and len(content) > self._last_content_len: + delta = content[self._last_content_len:] + self._last_content_len = len(content) + updates.append(update_agent_message_text(delta)) + + for tc in (msg.tool_calls or []): + tc_id = tc.get('id', '') or f'tc_{uuid.uuid4().hex[:8]}' + if tc_id in self._emitted_tool_ids: + continue + self._emitted_tool_ids.add(tc_id) + + tool_name = tc.get('tool_name', 'unknown') + raw_args = tc.get('arguments', '') + args = _parse_tool_args(raw_args) + self._tool_args_cache[tc_id] = args + self._tool_name_cache[tc_id] = tool_name + + kind = _infer_kind_from_args(tool_name, args) + title = _build_title(tool_name, args, kind) + locations = _extract_locations(args) + path = _extract_file_path(args) + + if kind == 'edit' and path: + content_preview = args.get('content', '') + updates.append( + start_edit_tool_call( + tool_call_id=tc_id, + title=title, + path=path, + content=content_preview, + )) + elif kind == 'read' and path: + updates.append( + start_read_tool_call( + tool_call_id=tc_id, + title=title, + 
path=path, + )) + else: + updates.append( + start_tool_call( + tool_call_id=tc_id, + title=title, + kind=kind, + status='in_progress', + locations=locations, + raw_input=raw_args if raw_args else None, + )) + + return updates + + def _translate_tool_result(self, msg: Message) -> list: + updates: list = [] + tc_id = msg.tool_call_id or '' + if not tc_id or tc_id in self._completed_tool_ids: + return updates + self._completed_tool_ids.add(tc_id) + + result_text = ( + msg.content if isinstance(msg.content, str) else str(msg.content)) + tool_name = self._tool_name_cache.get(tc_id, msg.name or '') + args = self._tool_args_cache.get(tc_id, {}) + kind = _infer_kind_from_args(tool_name, args) + is_error = self._looks_like_error(result_text) + + content_items: list = [] + + if kind == 'edit' and not is_error: + diff = _try_parse_diff_from_result(tool_name, args, result_text) + if diff: + content_items.append( + tool_diff_content( + path=diff['path'], + new_text=diff['new_text'], + old_text=diff.get('old_text'), + )) + if result_text: + content_items.append(tool_content(text_block(result_text))) + elif result_text: + content_items.append(tool_content(text_block(result_text))) + + status = 'failed' if is_error else 'completed' + + updates.append( + update_tool_call( + tool_call_id=tc_id, + status=status, + content=content_items if content_items else None, + raw_output=result_text or None, + )) + return updates + + @staticmethod + def _looks_like_error(text: str) -> bool: + if not text: + return False + lower = text.lower()[:200] + error_markers = ('error:', 'failed', 'exception', 'traceback', + '"success": false', "'success': false") + return any(m in lower for m in error_markers) + + @staticmethod + def build_plan_update(steps: list[dict]) -> AgentPlanUpdate: + """Build an ACP plan update from a list of research steps. + + Each step dict should have ``description``, ``status`` (pending / + in_progress / completed), and optionally ``priority``. 
+ """ + entries = [ + plan_entry( + content=s.get('description', ''), + status=s.get('status', 'pending'), + priority=s.get('priority', 'medium'), + ) for s in steps + ] + return update_plan(entries) + + @staticmethod + def map_stop_reason( + session, + cancelled: bool = False, + ) -> str: + """Map ms-agent runtime state to an ACP stop reason literal.""" + if cancelled or (hasattr(session, 'cancelled') and session.cancelled): + return 'cancelled' + + agent = session.agent + rt = getattr(agent, 'runtime', None) + if rt is None: + return 'end_turn' + + max_rounds = getattr(agent, 'max_chat_round', + getattr(agent, 'DEFAULT_MAX_CHAT_ROUND', 20)) + if rt.round >= max_rounds + 1: + return 'max_turn_requests' + + return 'end_turn' diff --git a/ms_agent/cli/a2a_cmd.py b/ms_agent/cli/a2a_cmd.py new file mode 100644 index 000000000..ebba08549 --- /dev/null +++ b/ms_agent/cli/a2a_cmd.py @@ -0,0 +1,203 @@ +import argparse +import os + +import json +from ms_agent.config.env import Env +from ms_agent.utils import strtobool + +from .base import CLICommand + + +def subparser_func(args): + return A2ACmd(args) + + +def registry_subparser_func(args): + return A2ARegistryCmd(args) + + +class A2ACmd(CLICommand): + """``ms-agent a2a`` -- start an A2A HTTP server.""" + + name = 'a2a' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: argparse.ArgumentParser): + parser: argparse.ArgumentParser = parsers.add_parser( + A2ACmd.name, + help='Start an A2A (Agent-to-Agent) protocol HTTP server', + ) + parser.add_argument( + '--config', + required=True, + type=str, + help='Path to the agent config YAML file (e.g. 
researcher.yaml)', + ) + parser.add_argument( + '--env', + required=False, + type=str, + default=None, + help='Path to a .env file', + ) + parser.add_argument( + '--trust_remote_code', + required=False, + type=str, + default='false', + help='Trust external code files referenced by the config', + ) + parser.add_argument( + '--host', + required=False, + type=str, + default='0.0.0.0', + help='Host to bind the A2A server (default: 0.0.0.0)', + ) + parser.add_argument( + '--port', + required=False, + type=int, + default=5000, + help='Port to bind the A2A server (default: 5000)', + ) + parser.add_argument( + '--max-tasks', + required=False, + type=int, + default=8, + help='Maximum concurrent A2A tasks (default: 8)', + ) + parser.add_argument( + '--task-timeout', + required=False, + type=int, + default=3600, + help='Task inactivity timeout in seconds (default: 3600)', + ) + parser.add_argument( + '--log-file', + required=False, + type=str, + default=None, + help='Write logs to this file instead of stderr', + ) + parser.set_defaults(func=subparser_func) + + def execute(self): + Env.load_dotenv_into_environ(getattr(self.args, 'env', None)) + + config_path = self.args.config + if not os.path.isabs(config_path): + config_path = os.path.abspath(config_path) + + trust_remote_code = strtobool(self.args.trust_remote_code) + + from ms_agent.a2a.executor import ( + MSAgentA2AExecutor, + configure_a2a_logging, + ) + from ms_agent.a2a.agent_card import build_agent_card + + configure_a2a_logging(self.args.log_file) + + agent_card = build_agent_card( + config_path=config_path, + host=self.args.host, + port=self.args.port, + ) + + executor = MSAgentA2AExecutor( + config_path=config_path, + trust_remote_code=trust_remote_code, + max_tasks=self.args.max_tasks, + task_timeout=self.args.task_timeout, + ) + + from a2a.server.apps import A2AStarletteApplication + from a2a.server.request_handlers import DefaultRequestHandler + from a2a.server.tasks import InMemoryTaskStore + + request_handler 
= DefaultRequestHandler( + agent_executor=executor, + task_store=InMemoryTaskStore(), + ) + + app = A2AStarletteApplication( + agent_card=agent_card, + http_handler=request_handler, + ) + + import uvicorn + uvicorn.run( + app.build(), + host=self.args.host, + port=self.args.port, + log_level='info', + ) + + +class A2ARegistryCmd(CLICommand): + """``ms-agent a2a-registry`` -- generate an A2A Agent Card JSON.""" + + name = 'a2a-registry' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: argparse.ArgumentParser): + parser: argparse.ArgumentParser = parsers.add_parser( + A2ARegistryCmd.name, + help='Generate an A2A Agent Card JSON for agent discovery', + ) + parser.add_argument( + '--config', + required=False, + type=str, + default=None, + help='Path to agent config YAML (used for metadata extraction)', + ) + parser.add_argument( + '--output', + required=False, + type=str, + default='agent-card.json', + help='Output path for the agent card (default: agent-card.json)', + ) + parser.add_argument( + '--host', + required=False, + type=str, + default='0.0.0.0', + help='Host the agent will be served on (default: 0.0.0.0)', + ) + parser.add_argument( + '--port', + required=False, + type=int, + default=5000, + help='Port the agent will be served on (default: 5000)', + ) + parser.add_argument( + '--title', + required=False, + type=str, + default='MS-Agent', + help='Agent display title in the card', + ) + parser.set_defaults(func=registry_subparser_func) + + def execute(self): + from ms_agent.a2a.agent_card import generate_agent_card_json + card = generate_agent_card_json( + config_path=self.args.config, + output_path=self.args.output, + host=self.args.host, + port=self.args.port, + title=self.args.title, + ) + print(json.dumps(card, indent=2))
diff --git a/ms_agent/cli/acp_cmd.py b/ms_agent/cli/acp_cmd.py
new file mode 100644
index 000000000..0302a5a77
--- /dev/null
+++ b/ms_agent/cli/acp_cmd.py
@@ -0,0 +1,218 @@
+import argparse
+import os
+
+import json
+from ms_agent.config.env import Env
+from ms_agent.utils import strtobool
+
+from .base import CLICommand
+
+# Environment variable consulted when ``--api-key`` is not given.
+ACP_API_KEY_ENV = 'MS_AGENT_ACP_API_KEY'
+
+
+def subparser_func(args):
+    """Build the ``acp`` command object for the parsed *args*."""
+    return ACPCmd(args)
+
+
+def registry_subparser_func(args):
+    """Build the ``acp-registry`` command object for the parsed *args*."""
+    return ACPRegistryCmd(args)
+
+
+class ACPCmd(CLICommand):
+    """``ms-agent acp`` -- serve an agent over ACP (stdio or HTTP/SSE)."""
+
+    name = 'acp'
+
+    def __init__(self, args):
+        self.args = args
+
+    @staticmethod
+    def define_args(parsers: argparse.ArgumentParser):
+        """Register the ``acp`` sub-command and its options on *parsers*."""
+        parser: argparse.ArgumentParser = parsers.add_parser(
+            ACPCmd.name,
+            help='Start an ACP (Agent Client Protocol) server over stdio',
+        )
+        parser.add_argument(
+            '--config',
+            required=True,
+            type=str,
+            help='Path to the agent config YAML file (e.g. researcher.yaml)',
+        )
+        parser.add_argument(
+            '--env',
+            required=False,
+            type=str,
+            default=None,
+            help='Path to a .env file',
+        )
+        parser.add_argument(
+            '--trust_remote_code',
+            required=False,
+            type=str,
+            default='false',
+            help='Trust external code files referenced by the config',
+        )
+        parser.add_argument(
+            '--max_sessions',
+            required=False,
+            type=int,
+            default=8,
+            help='Maximum concurrent ACP sessions (default: 8)',
+        )
+        parser.add_argument(
+            '--session_timeout',
+            required=False,
+            type=int,
+            default=3600,
+            help='Session inactivity timeout in seconds (default: 3600)',
+        )
+        parser.add_argument(
+            '--log-file',
+            required=False,
+            type=str,
+            default=None,
+            help='Write logs to this file instead of stderr',
+        )
+        parser.add_argument(
+            '--serve-http',
+            action='store_true',
+            default=False,
+            help='Start a non-standard HTTP/SSE service API instead of stdio',
+        )
+        parser.add_argument(
+            '--host',
+            required=False,
+            type=str,
+            default='0.0.0.0',
+            help='HTTP host to bind (only with --serve-http, default: 0.0.0.0)',
+        )
+        parser.add_argument(
+            '--port',
+            required=False,
+            type=int,
+            default=8080,
+            help='HTTP port to bind (only with --serve-http, default: 8080)',
+        )
+        parser.add_argument(
+            '--api-key',
+            required=False,
+            type=str,
+            default=None,
+            help=
+            'API key for HTTP authentication (or set MS_AGENT_ACP_API_KEY env)',
+        )
+        parser.set_defaults(func=subparser_func)
+
+    def execute(self):
+        """Load the optional ``.env`` file and start the ACP server.
+
+        Serves over stdio unless ``--serve-http`` was given.
+        """
+        Env.load_dotenv_into_environ(getattr(self.args, 'env', None))
+
+        config_path = self.args.config
+        if not os.path.isabs(config_path):
+            config_path = os.path.abspath(config_path)
+
+        trust_remote_code = strtobool(self.args.trust_remote_code)
+
+        if getattr(self.args, 'serve_http', False):
+            self._serve_http(config_path, trust_remote_code)
+        else:
+            from ms_agent.acp.server import serve
+            serve(
+                config_path=config_path,
+                trust_remote_code=trust_remote_code,
+                max_sessions=self.args.max_sessions,
+                session_timeout=self.args.session_timeout,
+                log_file=self.args.log_file,
+            )
+
+    def _serve_http(self, config_path: str, trust_remote_code: bool):
+        """Start the non-standard HTTP/SSE internal service API."""
+        import uvicorn
+        from fastapi import FastAPI
+
+        from ms_agent.acp.http_adapter import configure_http_adapter
+
+        # Apply the env fallback promised by the ``--api-key`` help text
+        # here, so it holds whatever the adapter does with ``None``.
+        api_key = self.args.api_key or os.environ.get(ACP_API_KEY_ENV)
+        # NOTE(review): ``--log-file`` is not forwarded on this path;
+        # confirm that HTTP mode is meant to ignore it.
+
+        app = FastAPI(
+            title='MS-Agent ACP Internal API',
+            description=(
+                'Non-standard HTTP/SSE service API for ms-agent ACP server. '
+                'This is NOT an ACP-standard transport.'),
+        )
+        acp_router = configure_http_adapter(
+            config_path=config_path,
+            trust_remote_code=trust_remote_code,
+            max_sessions=self.args.max_sessions,
+            session_timeout=self.args.session_timeout,
+            api_key=api_key,
+        )
+        app.include_router(acp_router)
+
+        uvicorn.run(
+            app,
+            host=self.args.host,
+            port=self.args.port,
+            log_level='info',
+        )
+
+
+class ACPRegistryCmd(CLICommand):
+    """``ms-agent acp-registry`` -- generate an ``agent.json`` manifest."""
+
+    name = 'acp-registry'
+
+    def __init__(self, args):
+        self.args = args
+
+    @staticmethod
+    def define_args(parsers: argparse.ArgumentParser):
+        """Register the ``acp-registry`` sub-command on *parsers*."""
+        parser: argparse.ArgumentParser = parsers.add_parser(
+            ACPRegistryCmd.name,
+            help='Generate an agent.json manifest for ACP Agent Registry',
+        )
+        parser.add_argument(
+            '--config',
+            required=False,
+            type=str,
+            default=None,
+            help=
+            'Path to agent config YAML (baked into manifest transport args)',
+        )
+        parser.add_argument(
+            '--output',
+            required=False,
+            type=str,
+            default='agent.json',
+            help='Output path for the manifest (default: agent.json)',
+        )
+        parser.add_argument(
+            '--title',
+            required=False,
+            type=str,
+            default='MS-Agent',
+            help='Agent display title in the manifest',
+        )
+        parser.set_defaults(func=registry_subparser_func)
+
+    def execute(self):
+        """Generate the manifest and print it as indented JSON."""
+        from ms_agent.acp.registry import generate_agent_manifest
+        manifest = generate_agent_manifest(
+            config_path=self.args.config,
+            output_path=self.args.output,
+            title=self.args.title,
+        )
+        print(json.dumps(manifest, indent=2))
diff --git a/ms_agent/cli/acp_proxy_cmd.py b/ms_agent/cli/acp_proxy_cmd.py new file mode 100644 index 000000000..2ba61c8d3 --- /dev/null +++ b/ms_agent/cli/acp_proxy_cmd.py @@ -0,0 +1,50 @@ +import argparse +import os + +from .base import CLICommand + + +def _subparser_func(args): + return ACPProxyCmd(args) + + +class ACPProxyCmd(CLICommand): + """``ms-agent acp-proxy`` -- start an ACP proxy that dispatches to + multiple
backend agents.""" + + name = 'acp-proxy' + + def __init__(self, args): + self.args = args + + @staticmethod + def define_args(parsers: argparse.ArgumentParser): + parser: argparse.ArgumentParser = parsers.add_parser( + ACPProxyCmd.name, + help='Start an ACP proxy that routes to multiple backend agents', + ) + parser.add_argument( + '--config', + required=True, + type=str, + help='Path to the proxy config YAML (defines backends)', + ) + parser.add_argument( + '--log-file', + required=False, + type=str, + default=None, + help='Write logs to this file instead of stderr', + ) + parser.set_defaults(func=_subparser_func) + + def execute(self): + config_path = self.args.config + if not os.path.isabs(config_path): + config_path = os.path.abspath(config_path) + + from ms_agent.acp.proxy import serve_proxy + serve_proxy( + config_path=config_path, + log_file=self.args.log_file, + ) diff --git a/ms_agent/cli/cli.py b/ms_agent/cli/cli.py index da709e98d..d28ceb907 100644 --- a/ms_agent/cli/cli.py +++ b/ms_agent/cli/cli.py @@ -1,5 +1,8 @@ import argparse +from ms_agent.cli.a2a_cmd import A2ACmd, A2ARegistryCmd +from ms_agent.cli.acp_cmd import ACPCmd, ACPRegistryCmd +from ms_agent.cli.acp_proxy_cmd import ACPProxyCmd from ms_agent.cli.app import AppCMD from ms_agent.cli.run import RunCMD from ms_agent.cli.ui import UICMD @@ -17,6 +20,11 @@ def run_cmd(): subparsers = parser.add_subparsers( help='ModelScope-agent commands helpers') + A2ACmd.define_args(subparsers) + A2ARegistryCmd.define_args(subparsers) + ACPCmd.define_args(subparsers) + ACPProxyCmd.define_args(subparsers) + ACPRegistryCmd.define_args(subparsers) RunCMD.define_args(subparsers) AppCMD.define_args(subparsers) UICMD.define_args(subparsers)
diff --git a/ms_agent/tools/a2a_agent_tool.py b/ms_agent/tools/a2a_agent_tool.py
new file mode 100644
index 000000000..5dd6270e7
--- /dev/null
+++ b/ms_agent/tools/a2a_agent_tool.py
@@ -0,0 +1,102 @@
+from typing import Any, Dict, List
+
+from ms_agent.llm.utils import Tool
+from ms_agent.tools.base import ToolBase
+from ms_agent.utils.logger import get_logger
+
+logger = get_logger()
+
+A2A_TOOL_PREFIX = 'a2a'
+
+
+class A2AAgentTool(ToolBase):
+    """Expose every configured remote A2A agent as an LLM tool.
+
+    Each agent is published under its own server key,
+    ``a2a_{agent_name}``, holding one tool named ``{agent_name}`` that
+    carries the agent's description, so the LLM can pick an agent by
+    capability. ``ACPAgentTool`` uses the same layout with ``acp_``.
+    """
+
+    def __init__(self, config, a2a_agents_config: dict | None = None):
+        """Store the agent settings and build the client manager.
+
+        Args:
+            config: Agent config, forwarded to ``ToolBase``.
+            a2a_agents_config: Mapping of agent name to its settings;
+                ``None`` or empty means no agents are exposed.
+        """
+        super().__init__(config)
+        self._a2a_config: dict = a2a_agents_config or {}
+        # Function-scope import: the A2A client is only loaded once a
+        # tool is actually constructed, not when this module is imported.
+        from ms_agent.a2a.client import A2AClientManager
+        self._client_manager = A2AClientManager(self._a2a_config)
+
+    @classmethod
+    def from_config(cls, config) -> 'A2AAgentTool | None':
+        """Create an ``A2AAgentTool`` if the config has ``a2a_agents``.
+
+        Returns ``None`` when the key is absent or resolves to an empty
+        container.
+        """
+        if not hasattr(config, 'a2a_agents'):
+            return None
+        from omegaconf import OmegaConf
+        raw = OmegaConf.to_container(config.a2a_agents, resolve=True)
+        if not raw:
+            return None
+        return cls(config, a2a_agents_config=raw)
+
+    async def connect(self) -> None:
+        """Do nothing; this tool has no connect-time setup."""
+
+    async def cleanup(self) -> None:
+        """Close every client held by the client manager."""
+        await self._client_manager.close_all()
+
+    async def _get_tools_inner(self) -> Dict[str, Any]:
+        """Return one single-tool entry per agent, keyed by server name."""
+        tools: Dict[str, List[Tool]] = {}
+        for agent_name, agent_cfg in self._a2a_config.items():
+            # Tolerate an entry whose settings are None (empty YAML value).
+            agent_cfg = agent_cfg or {}
+            server_name = f'{A2A_TOOL_PREFIX}_{agent_name}'
+            tool_entry: Tool = {
+                'tool_name':
+                agent_name,
+                'description':
+                agent_cfg.get('description',
+                              f'A2A remote agent: {agent_name}'),
+                'parameters': {
+                    'type': 'object',
+                    'properties': {
+                        'query': {
+                            'type':
+                            'string',
+                            'description':
+                            'The task or query to send to this remote agent',
+                        },
+                    },
+                    'required': ['query'],
+                },
+            }
+            tools[server_name] = [tool_entry]
+        return tools
+
+    async def call_tool(self, server_name: str, *, tool_name: str,
+                        tool_args: dict) -> str:
+        """Send ``tool_args['query']`` to the agent behind *server_name*.
+
+        Returns the client manager's reply, or an error string when no
+        non-empty ``query`` was supplied.
+        """
+        # Strip only a leading prefix; ``str.replace`` would also match
+        # the prefix text further inside a name that lacks it up front.
+        agent_name = server_name.removeprefix(f'{A2A_TOOL_PREFIX}_')
+        query = (tool_args or {}).get('query', '')
+        if not query:
+            return 'Error: "query" parameter is required'
+        logger.info('Calling A2A agent %s with query: %s', agent_name,
+                    query[:200])
+        return await self._client_manager.call_agent(agent_name, query)
diff --git a/ms_agent/tools/acp_agent_tool.py b/ms_agent/tools/acp_agent_tool.py
new file mode 100644
index 000000000..cd77e82c9
--- /dev/null
+++ b/ms_agent/tools/acp_agent_tool.py
@@ -0,0 +1,108 @@
+import tempfile
+from typing import Any, Dict, List
+
+from ms_agent.acp.client import ACPClientManager
+from ms_agent.llm.utils import Tool
+from ms_agent.tools.base import ToolBase
+from ms_agent.utils.logger import get_logger
+
+logger = get_logger()
+
+ACP_TOOL_PREFIX = 'acp'
+
+
+class ACPAgentTool(ToolBase):
+    """Expose every configured external ACP agent as an LLM tool.
+
+    Each agent is published under its own server key,
+    ``acp_{agent_name}``, holding one tool named ``{agent_name}`` that
+    carries the agent's description, so the LLM can pick an agent by
+    capability -- mirroring how ``MCPClient`` exposes ``server---tool``.
+    """
+
+    def __init__(self, config, acp_agents_config: dict | None = None):
+        """Store the agent settings and build the client manager.
+
+        Args:
+            config: Agent config, forwarded to ``ToolBase``; its
+                ``output_dir`` (when set) is passed as ``cwd`` on
+                every agent call.
+            acp_agents_config: Mapping of agent name to its settings;
+                ``None`` or empty means no agents are exposed.
+        """
+        super().__init__(config)
+        self._acp_config: dict = acp_agents_config or {}
+        self._client_manager = ACPClientManager(self._acp_config)
+        # Fall back to the platform temp dir rather than a literal
+        # ``/tmp``, which does not exist on every OS; ``or`` also covers
+        # an ``output_dir`` that is present but unset.
+        self._cwd: str = (getattr(config, 'output_dir', None)
+                          or tempfile.gettempdir())
+
+    @classmethod
+    def from_config(cls, config) -> 'ACPAgentTool | None':
+        """Create an ``ACPAgentTool`` if the config has ``acp_agents``.
+
+        Returns ``None`` when the key is absent or resolves to an empty
+        container.
+        """
+        if not hasattr(config, 'acp_agents'):
+            return None
+        from omegaconf import OmegaConf
+        raw = OmegaConf.to_container(config.acp_agents, resolve=True)
+        if not raw:
+            return None
+        return cls(config, acp_agents_config=raw)
+
+    async def connect(self) -> None:
+        """Do nothing; this tool has no connect-time setup."""
+
+    async def cleanup(self) -> None:
+        """Close every client held by the client manager."""
+        await self._client_manager.close_all()
+
+    async def _get_tools_inner(self) -> Dict[str, Any]:
+        """Return one single-tool entry per agent, keyed by server name."""
+        tools: Dict[str, List[Tool]] = {}
+        for agent_name, agent_cfg in self._acp_config.items():
+            # Tolerate an entry whose settings are None (empty YAML value).
+            agent_cfg = agent_cfg or {}
+            server_name = f'{ACP_TOOL_PREFIX}_{agent_name}'
+            tool_entry: Tool = {
+                'tool_name':
+                agent_name,
+                'description':
+                agent_cfg.get('description', f'ACP agent: {agent_name}'),
+                'parameters': {
+                    'type': 'object',
+                    'properties': {
+                        'query': {
+                            'type':
+                            'string',
+                            'description':
+                            'The task or query to send to this agent',
+                        },
+                    },
+                    'required': ['query'],
+                },
+            }
+            tools[server_name] = [tool_entry]
+        return tools
+
+    async def call_tool(self, server_name: str, *, tool_name: str,
+                        tool_args: dict) -> str:
+        """Send ``tool_args['query']`` to the agent behind *server_name*.
+
+        Returns the client manager's reply, or an error string when no
+        non-empty ``query`` was supplied.
+        """
+        # Strip only a leading prefix; ``str.replace`` would also match
+        # the prefix text further inside a name that lacks it up front.
+        agent_name = server_name.removeprefix(f'{ACP_TOOL_PREFIX}_')
+        query = (tool_args or {}).get('query', '')
+        if not query:
+            return 'Error: "query" parameter is required'
+        logger.info('Calling ACP agent %s with query: %s', agent_name,
+                    query[:200])
+        return await self._client_manager.call_agent(
+            agent_name, query, cwd=self._cwd)
diff --git a/ms_agent/tools/tool_manager.py b/ms_agent/tools/tool_manager.py index 58f019774..2d5a77c6d 100644 --- a/ms_agent/tools/tool_manager.py +++ b/ms_agent/tools/tool_manager.py @@ -11,6 +11,7 @@ import json from ms_agent.llm.utils
import Tool, ToolCall +from ms_agent.tools.acp_agent_tool import ACPAgentTool from ms_agent.tools.agent_tool import AgentTool from ms_agent.tools.base import ToolBase from ms_agent.tools.code import CodeExecutionTool, LocalCodeExecutionTool @@ -88,6 +89,17 @@ def __init__(self, self.extra_tools.append(TodoListTool(config)) if hasattr(config, 'tools') and hasattr(config.tools, 'web_search'): self.extra_tools.append(WebSearchTool(config)) + acp_tool = ACPAgentTool.from_config(config) + if acp_tool is not None: + self.extra_tools.append(acp_tool) + try: + from ms_agent.tools.a2a_agent_tool import A2AAgentTool + a2a_tool = A2AAgentTool.from_config(config) + if a2a_tool is not None: + self.extra_tools.append(a2a_tool) + except ImportError: + pass + self.tool_call_timeout = getattr(config, 'tool_call_timeout', TOOL_CALL_TIMEOUT) local_dir = self.config.local_dir if hasattr(self.config, diff --git a/requirements/a2a.txt b/requirements/a2a.txt new file mode 100644 index 000000000..cf506db9c --- /dev/null +++ b/requirements/a2a.txt @@ -0,0 +1 @@ +a2a-sdk[http-server]>=0.3.25,<1.0.0 diff --git a/requirements/acp.txt b/requirements/acp.txt new file mode 100644 index 000000000..aed5e9b7f --- /dev/null +++ b/requirements/acp.txt @@ -0,0 +1 @@ +agent-client-protocol>=0.9.0 diff --git a/setup.py b/setup.py index 1bded7e04..1da94eec5 100644 --- a/setup.py +++ b/setup.py @@ -238,6 +238,8 @@ def _build_and_copy_webui(self): 'requirements/research.txt') extra_requires['code'], _ = parse_requirements('requirements/code.txt') extra_requires['webui'], _ = parse_requirements('requirements/webui.txt') + extra_requires['acp'], _ = parse_requirements('requirements/acp.txt') + extra_requires['a2a'], _ = parse_requirements('requirements/a2a.txt') all_requires.extend(install_requires) all_requires.extend(extra_requires['research']) all_requires.extend(extra_requires['code']) diff --git a/tests/test_a2a/__init__.py b/tests/test_a2a/__init__.py new file mode 100644 index 000000000..e69de29bb 
diff --git a/tests/test_a2a/test_a2a_e2e.py b/tests/test_a2a/test_a2a_e2e.py
new file mode 100644
index 000000000..3c5402793
--- /dev/null
+++ b/tests/test_a2a/test_a2a_e2e.py
@@ -0,0 +1,141 @@
+"""End-to-end A2A test using the A2A SDK client.
+
+This validates the full A2A lifecycle by starting an A2A server subprocess
+and connecting to it with the A2A SDK client.
+
+**Requires** a valid agent config and a2a-sdk to be installed.
+When those are unavailable the test is skipped gracefully.
+"""
+
+import asyncio
+import os
+import socket
+import subprocess
+import sys
+import tempfile
+import time
+
+import pytest
+
+_SKIP_REASON = None
+try:
+    import httpx
+    from a2a.client import A2ACardResolver, ClientConfig, ClientFactory
+    from a2a.client.helpers import create_text_message_object
+except ImportError:
+    _SKIP_REASON = 'a2a-sdk or httpx not installed'
+
+_REPO_ROOT = os.path.dirname(
+    os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
+_DEFAULT_CONFIG = os.path.join(_REPO_ROOT, 'ms_agent', 'agent', 'agent.yaml')
+_A2A_TEST_CONFIG = os.environ.get('A2A_TEST_CONFIG', _DEFAULT_CONFIG)
+_A2A_TEST_PORT = int(os.environ.get('A2A_TEST_PORT', '19999'))
+_A2A_TEST_HOST = '127.0.0.1'
+# Upper bound on how long the fixture waits for the server to listen.
+_A2A_STARTUP_TIMEOUT = 30.0
+
+
+def _have_config() -> bool:
+    """Return True when the configured agent YAML exists on disk."""
+    return os.path.isfile(_A2A_TEST_CONFIG)
+
+
+def _wait_until_listening(proc, host: str, port: int,
+                          timeout: float) -> bool:
+    """Poll until ``host:port`` accepts a TCP connection.
+
+    Returns False as soon as *proc* has exited, or once *timeout*
+    seconds have passed without a successful connection.
+    """
+    deadline = time.monotonic() + timeout
+    while time.monotonic() < deadline:
+        if proc.poll() is not None:
+            return False
+        try:
+            with socket.create_connection((host, port), timeout=0.5):
+                return True
+        except OSError:
+            time.sleep(0.2)
+    return False
+
+
+@pytest.fixture(scope='module')
+def a2a_server():
+    """Start an A2A server as a subprocess and yield its base URL.
+
+    Skips when no config is found or the server is not reachable in time;
+    a still-running subprocess is always stopped on teardown.
+    """
+    if not _have_config():
+        pytest.skip('No agent config found for A2A E2E test')
+
+    # Send server output to a temp file instead of pipes: nothing reads the
+    # pipes while the tests run, so a full pipe buffer would stall the server.
+    with tempfile.TemporaryFile() as log_file:
+        proc = subprocess.Popen(
+            [sys.executable, '-m', 'ms_agent.cli.cli',
+             'a2a',
+             '--config', _A2A_TEST_CONFIG,
+             '--host', _A2A_TEST_HOST,
+             '--port', str(_A2A_TEST_PORT)],
+            stdout=subprocess.DEVNULL,
+            stderr=log_file,
+        )
+        try:
+            # Poll for readiness rather than sleeping a fixed 5 seconds.
+            ready = _wait_until_listening(
+                proc, _A2A_TEST_HOST, _A2A_TEST_PORT, _A2A_STARTUP_TIMEOUT)
+            if not ready:
+                log_file.seek(0)
+                stderr = log_file.read().decode(errors='replace')
+                pytest.skip(f'A2A server failed to start: {stderr[:500]}')
+
+            yield f'http://{_A2A_TEST_HOST}:{_A2A_TEST_PORT}'
+        finally:
+            if proc.poll() is None:
+                proc.terminate()
+                try:
+                    proc.wait(timeout=5)
+                except subprocess.TimeoutExpired:
+                    proc.kill()
+                    proc.wait()
+
+
+@pytest.mark.skipif(_SKIP_REASON is not None, reason=_SKIP_REASON or '')
+@pytest.mark.skipif(not _have_config(),
+                    reason='No agent config found for A2A E2E test')
+class TestA2AE2E:
+    """Exercise a live A2A server through the SDK client."""
+
+    @pytest.mark.asyncio
+    async def test_discover_agent_card(self, a2a_server):
+        """Verify the A2A server publishes an Agent Card."""
+        async with httpx.AsyncClient() as http:
+            resolver = A2ACardResolver(
+                httpx_client=http, base_url=a2a_server)
+            card = await resolver.get_agent_card()
+            assert card.name
+            assert card.capabilities is not None
+            assert len(card.skills) >= 1
+
+    @pytest.mark.asyncio
+    async def test_send_message(self, a2a_server):
+        """Send a simple message and verify we get a response."""
+        if not os.environ.get('OPENAI_API_KEY'):
+            pytest.skip('No OPENAI_API_KEY for LLM-backed test')
+
+        async with httpx.AsyncClient(timeout=120.0) as http:
+            resolver = A2ACardResolver(
+                httpx_client=http, base_url=a2a_server)
+            card = await resolver.get_agent_card()
+
+            factory = ClientFactory(
+                config=ClientConfig(httpx_client=http))
+            client = factory.create(card)
+
+            message = create_text_message_object(content='Say hello in 3 words')
+            events = []
+            async for event in client.send_message(message):
+                events.append(event)
+
+            assert len(events) >= 1
diff --git a/tests/test_a2a/test_a2a_protocol.py b/tests/test_a2a/test_a2a_protocol.py
new file mode 100644
index 000000000..ceabeff31
--- /dev/null
+++ b/tests/test_a2a/test_a2a_protocol.py
@@ -0,0 +1,175 @@
+"""Protocol-level tests for A2A components.
+
+Tests the full A2A protocol flow using mock agents and the A2A SDK types.
+Skipped if a2a-sdk is not installed.
+""" + +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +_SKIP_REASON = None +try: + from a2a.types import ( + AgentCapabilities, + AgentCard, + AgentSkill, + Part, + TaskState, + TextPart, + ) + from a2a.server.agent_execution import AgentExecutor, RequestContext + from a2a.server.events import EventQueue + from a2a.server.tasks import TaskUpdater, InMemoryTaskStore + from a2a.utils import new_agent_text_message, new_task +except ImportError: + _SKIP_REASON = 'a2a-sdk not installed' + + +pytestmark = pytest.mark.skipif( + _SKIP_REASON is not None, reason=_SKIP_REASON or '') + + +# ====================================================================== +# Agent Card tests +# ====================================================================== + +class TestAgentCard: + + def test_build_agent_card_defaults(self): + from ms_agent.a2a.agent_card import build_agent_card + card = build_agent_card() + assert card.name == 'ms-agent' + assert card.url == 'http://localhost:5000/' + assert card.capabilities.streaming is True + assert len(card.skills) >= 1 + + def test_build_agent_card_custom_host_port(self): + from ms_agent.a2a.agent_card import build_agent_card + card = build_agent_card(host='myhost', port=8080) + assert card.url == 'http://myhost:8080/' + + def test_build_agent_card_with_skills(self): + from ms_agent.a2a.agent_card import build_agent_card + skills = [ + {'id': 'research', 'name': 'Deep Research', + 'description': 'Research topics'}, + ] + card = build_agent_card(skills=skills) + assert len(card.skills) == 1 + assert card.skills[0].id == 'research' + + def test_generate_agent_card_json(self, tmp_path): + from ms_agent.a2a.agent_card import generate_agent_card_json + import json + out = tmp_path / 'card.json' + card_dict = generate_agent_card_json(output_path=str(out)) + assert out.exists() + with open(out) as f: + data = json.load(f) + assert data['name'] == 'ms-agent' + assert 'capabilities' in data + + +# 
====================================================================== +# Executor tests with mock agent +# ====================================================================== + +class TestExecutor: + + @pytest.mark.asyncio + async def test_executor_cancel_unknown_task(self): + from ms_agent.a2a.executor import MSAgentA2AExecutor + executor = MSAgentA2AExecutor( + config_path='/tmp/nonexistent.yaml', + max_tasks=2, + ) + event_queue = EventQueue() + + context = RequestContext( + task_id='task_unknown', + context_id='ctx_1', + ) + await executor.cancel(context, event_queue) + + @pytest.mark.asyncio + async def test_executor_cleanup(self): + from ms_agent.a2a.executor import MSAgentA2AExecutor + executor = MSAgentA2AExecutor( + config_path='/tmp/nonexistent.yaml', + ) + await executor.cleanup() + + +# ====================================================================== +# TaskUpdater integration tests +# ====================================================================== + +class TestTaskUpdater: + + @staticmethod + async def _drain_queue(event_queue: EventQueue) -> list: + """Drain all events from the queue without blocking.""" + events = [] + while True: + try: + event = await event_queue.dequeue_event(no_wait=True) + events.append(event) + event_queue.task_done() + except (asyncio.QueueEmpty, Exception): + break + return events + + @pytest.mark.asyncio + async def test_updater_lifecycle(self): + """Test the basic submit -> working -> complete lifecycle.""" + event_queue = EventQueue() + updater = TaskUpdater(event_queue, 'task_1', 'ctx_1') + + await updater.submit() + await updater.start_work() + await updater.add_artifact( + [Part(root=TextPart(text='result'))], + name='response', + ) + await updater.complete() + + events = await self._drain_queue(event_queue) + assert len(events) >= 3 + + @pytest.mark.asyncio + async def test_updater_failed(self): + event_queue = EventQueue() + updater = TaskUpdater(event_queue, 'task_2', 'ctx_2') + + await 
updater.submit() + await updater.start_work() + await updater.failed( + new_agent_text_message('something broke', 'ctx_2', 'task_2')) + + events = await self._drain_queue(event_queue) + assert len(events) >= 3 + + @pytest.mark.asyncio + async def test_updater_cancel(self): + event_queue = EventQueue() + updater = TaskUpdater(event_queue, 'task_3', 'ctx_3') + + await updater.submit() + await updater.cancel() + + events = await self._drain_queue(event_queue) + assert len(events) >= 2 + + +# ====================================================================== +# InMemoryTaskStore tests +# ====================================================================== + +class TestTaskStore: + + @pytest.mark.asyncio + async def test_in_memory_task_store_get_none(self): + store = InMemoryTaskStore() + result = await store.get('nonexistent') + assert result is None diff --git a/tests/test_a2a/test_a2a_unit.py b/tests/test_a2a/test_a2a_unit.py new file mode 100644 index 000000000..fb723996b --- /dev/null +++ b/tests/test_a2a/test_a2a_unit.py @@ -0,0 +1,304 @@ +"""Unit tests for A2A components (no external processes or SDK needed).""" + +import asyncio +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from ms_agent.a2a.errors import ( + A2AServerError, + AgentLoadError, + ConfigError, + LLMError, + MaxTasksError, + RateLimitError, + TaskNotFoundError, + wrap_a2a_error, +) +from ms_agent.a2a.translator import ( + a2a_message_to_ms_messages, + collect_full_response, + extract_text_from_a2a_message, + ms_messages_to_text, +) +from ms_agent.llm.utils import Message + + +# ====================================================================== +# Error mapping tests +# ====================================================================== + +class TestErrorMapping: + + def test_task_not_found(self): + err = TaskNotFoundError('task_abc') + assert err.code == -32001 + assert 'task_abc' in str(err.data) + + def test_agent_load_error(self): + err = 
AgentLoadError('config parse failure') + assert err.code == -32002 + assert 'config parse failure' in err.data['detail'] + + def test_llm_error(self): + err = LLMError('timeout') + assert err.code == -32003 + + def test_rate_limit_error(self): + err = RateLimitError('too many requests') + assert err.code == -32004 + + def test_config_error(self): + err = ConfigError('missing key') + assert err.code == -32005 + + def test_max_tasks_error(self): + err = MaxTasksError(8) + assert err.code == -32006 + assert err.data['max'] == 8 + + def test_wrap_a2a_server_error(self): + err = LLMError('timeout') + result = wrap_a2a_error(err) + assert result['code'] == -32003 + assert 'timeout' in result['data']['detail'] + + def test_wrap_file_not_found(self): + result = wrap_a2a_error(FileNotFoundError('/path')) + assert result['code'] == -32002 + + def test_wrap_value_error(self): + result = wrap_a2a_error(ValueError('bad param')) + assert result['code'] == -32602 + + def test_wrap_unknown_error(self): + result = wrap_a2a_error(RuntimeError('unexpected')) + assert result['code'] == -32603 + assert 'unexpected' in result['data']['detail'] + + def test_wrap_permission_error(self): + result = wrap_a2a_error(PermissionError('denied')) + assert result['code'] == -32000 + + def test_wrap_timeout_error(self): + result = wrap_a2a_error(TimeoutError('timed out')) + assert result['code'] == -32004 + + +# ====================================================================== +# Translator tests +# ====================================================================== + +class TestTranslator: + + def test_extract_text_from_text_part(self): + msg = MagicMock() + part = MagicMock() + part.root = MagicMock(type='text', text='Hello world') + msg.parts = [part] + assert extract_text_from_a2a_message(msg) == 'Hello world' + + def test_extract_text_multiple_parts(self): + msg = MagicMock() + p1 = MagicMock() + p1.root = MagicMock(type='text', text='Part A') + p2 = MagicMock() + p2.root = 
MagicMock(type='text', text='Part B') + msg.parts = [p1, p2] + result = extract_text_from_a2a_message(msg) + assert 'Part A' in result + assert 'Part B' in result + + def test_extract_text_file_part(self): + msg = MagicMock() + part = MagicMock() + file_obj = MagicMock(spec=['name', 'mimeType', 'uri']) + file_obj.name = 'test.txt' + file_obj.mimeType = 'text/plain' + file_obj.uri = 'file:///test.txt' + part.root = MagicMock(spec=['type', 'file']) + part.root.type = 'file' + part.root.file = file_obj + msg.parts = [part] + result = extract_text_from_a2a_message(msg) + assert 'test.txt' in result + + def test_extract_text_none_message(self): + assert extract_text_from_a2a_message(None) == '' + + def test_extract_text_no_parts(self): + msg = MagicMock() + msg.parts = None + result = extract_text_from_a2a_message(msg) + assert isinstance(result, str) + + def test_a2a_message_to_ms_messages(self): + msg = MagicMock() + part = MagicMock() + part.root = MagicMock(type='text', text='Hello') + msg.parts = [part] + result = a2a_message_to_ms_messages(msg) + assert len(result) == 1 + assert result[0].role == 'user' + assert result[0].content == 'Hello' + + def test_a2a_message_to_ms_messages_appends(self): + existing = [Message(role='system', content='You are helpful')] + msg = MagicMock() + part = MagicMock() + part.root = MagicMock(type='text', text='Query') + msg.parts = [part] + result = a2a_message_to_ms_messages(msg, existing) + assert len(result) == 2 + assert result is existing + + def test_ms_messages_to_text(self): + msgs = [ + Message(role='user', content='Hi'), + Message(role='assistant', content='Hello back!'), + ] + assert ms_messages_to_text(msgs) == 'Hello back!' 
+ + def test_ms_messages_to_text_empty(self): + assert ms_messages_to_text([]) == '' + + def test_ms_messages_to_text_no_assistant(self): + msgs = [Message(role='user', content='Hi')] + assert ms_messages_to_text(msgs) == '' + + def test_collect_full_response(self): + msgs = [ + Message(role='user', content='Hi'), + Message(role='assistant', content='Part 1'), + Message(role='tool', content='tool output', tool_call_id='tc_1'), + Message(role='assistant', content='Part 2'), + ] + result = collect_full_response(msgs) + assert 'Part 1' in result + assert 'Part 2' in result + + def test_collect_full_response_empty(self): + assert collect_full_response([]) == '' + + +# ====================================================================== +# Session store tests +# ====================================================================== + +class TestSessionStore: + + @pytest.mark.asyncio + async def test_get_or_create_missing_config(self): + from ms_agent.a2a.session_store import A2AAgentStore + store = A2AAgentStore( + config_path='/nonexistent/config.yaml', + max_tasks=2, + ) + with pytest.raises(ConfigError): + await store.get_or_create('task_1') + + @pytest.mark.asyncio + async def test_get_returns_none_for_unknown(self): + from ms_agent.a2a.session_store import A2AAgentStore + store = A2AAgentStore(config_path='/tmp/fake.yaml') + assert store.get('unknown_task') is None + + @pytest.mark.asyncio + async def test_close_all_empty(self): + from ms_agent.a2a.session_store import A2AAgentStore + store = A2AAgentStore(config_path='/tmp/fake.yaml') + await store.close_all() + + +# ====================================================================== +# Client manager tests +# ====================================================================== + +class TestClientManager: + + @pytest.mark.asyncio + async def test_call_unknown_agent(self): + from ms_agent.a2a.client import A2AClientManager + mgr = A2AClientManager({}) + result = await mgr.call_agent('unknown', 'hi') + assert 
'Error' in result + assert 'not configured' in result + + @pytest.mark.asyncio + async def test_call_agent_no_url(self): + from ms_agent.a2a.client import A2AClientManager + mgr = A2AClientManager({'test_agent': {'description': 'test'}}) + result = await mgr.call_agent('test_agent', 'hi') + assert 'Error' in result + assert 'no URL' in result + + @pytest.mark.asyncio + async def test_list_agents(self): + from ms_agent.a2a.client import A2AClientManager + mgr = A2AClientManager({ + 'agent_a': {'url': 'http://a'}, + 'agent_b': {'url': 'http://b'}, + }) + agents = mgr.list_agents() + assert 'agent_a' in agents + assert 'agent_b' in agents + + @pytest.mark.asyncio + async def test_close_all(self): + from ms_agent.a2a.client import A2AClientManager + mgr = A2AClientManager({}) + await mgr.close_all() + + def test_build_auth_headers_bearer(self): + from ms_agent.a2a.client import A2AClientManager + headers = A2AClientManager._build_auth_headers({ + 'auth': {'type': 'bearer', 'token': 'my_token'} + }) + assert headers['Authorization'] == 'Bearer my_token' + + def test_build_auth_headers_no_auth(self): + from ms_agent.a2a.client import A2AClientManager + headers = A2AClientManager._build_auth_headers({}) + assert headers == {} + + +# ====================================================================== +# Tool tests +# ====================================================================== + +class TestA2AAgentTool: + + def test_from_config_no_a2a_agents(self): + from ms_agent.tools.a2a_agent_tool import A2AAgentTool + config = MagicMock(spec=[]) + result = A2AAgentTool.from_config(config) + assert result is None + + @pytest.mark.asyncio + async def test_get_tools(self): + from ms_agent.tools.a2a_agent_tool import A2AAgentTool + config = MagicMock() + tool = A2AAgentTool(config, a2a_agents_config={ + 'my_agent': { + 'url': 'http://localhost:9999', + 'description': 'Test agent', + } + }) + tools = await tool.get_tools() + assert 'a2a_my_agent' in tools + assert 
len(tools['a2a_my_agent']) == 1 + assert tools['a2a_my_agent'][0]['tool_name'] == 'my_agent' + + @pytest.mark.asyncio + async def test_call_tool_missing_query(self): + from ms_agent.tools.a2a_agent_tool import A2AAgentTool + config = MagicMock() + tool = A2AAgentTool(config, a2a_agents_config={ + 'my_agent': { + 'url': 'http://localhost:9999', + 'description': 'Test agent', + } + }) + result = await tool.call_tool( + 'a2a_my_agent', tool_name='my_agent', tool_args={}) + assert 'Error' in result + assert 'query' in result diff --git a/tests/test_acp/__init__.py b/tests/test_acp/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/tests/test_acp/proxy_opencode.yaml b/tests/test_acp/proxy_opencode.yaml new file mode 100644 index 000000000..830581019 --- /dev/null +++ b/tests/test_acp/proxy_opencode.yaml @@ -0,0 +1,10 @@ +proxy: + max_sessions: 4 + session_timeout: 3600 + default_backend: opencode + +backends: + opencode: + command: opencode + args: [acp] + description: "OpenCode coding agent" diff --git a/tests/test_acp/test_acp_e2e.py b/tests/test_acp/test_acp_e2e.py new file mode 100644 index 000000000..97e44275e --- /dev/null +++ b/tests/test_acp/test_acp_e2e.py @@ -0,0 +1,74 @@ +"""End-to-end ACP test using ``spawn_agent_process``. + +This validates the full ACP lifecycle (initialize -> new_session -> prompt) +without depending on an external client like Zed. + +**Requires** a valid agent config and LLM API key in the environment to +actually run prompts. When those are unavailable the test is skipped +gracefully so CI stays green. +""" + +import asyncio +import os +import sys +from typing import Any + +import pytest + +_SKIP_REASON = None +try: + from acp import spawn_agent_process, text_block + from acp.interfaces import Client +except ImportError: + _SKIP_REASON = 'agent-client-protocol not installed' + +# Best-effort: find a usable agent config for the test. 
+_REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_DEFAULT_CONFIG = os.path.join(_REPO_ROOT, 'ms_agent', 'agent', 'agent.yaml') +_ACP_TEST_CONFIG = os.environ.get('ACP_TEST_CONFIG', _DEFAULT_CONFIG) + + +def _have_config() -> bool: + return os.path.isfile(_ACP_TEST_CONFIG) + + +class _TestClient(Client): + """Minimal ACP client that records every session update.""" + + def __init__(self): + self.updates: list = [] + + async def session_update(self, session_id, update, **kwargs): + self.updates.append(update) + + async def request_permission(self, options, session_id, tool_call, + **kwargs): + allow = next((o for o in options if 'allow' in (o.kind or '')), None) + if allow: + return {'outcome': {'outcome': 'selected', 'id': allow.option_id}} + return {'outcome': {'outcome': 'cancelled'}} + + +@pytest.mark.skipif(not _have_config(), + reason='No agent config found for ACP E2E test') +@pytest.mark.skipif(_SKIP_REASON is not None, reason=_SKIP_REASON or '') +@pytest.mark.asyncio +async def test_acp_initialize_and_new_session(): + """Verify the server boots, negotiates protocol, and creates a session.""" + client = _TestClient() + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _ACP_TEST_CONFIG, + ) as (conn, _proc): + resp = await conn.initialize(protocol_version=1) + assert resp.protocol_version == 1 + assert resp.agent_info is not None + assert resp.agent_info.name == 'ms-agent' + + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + assert session.session_id + assert session.session_id.startswith('ses_') diff --git a/tests/test_acp/test_acp_e2e_real.py b/tests/test_acp/test_acp_e2e_real.py new file mode 100644 index 000000000..fbb575d34 --- /dev/null +++ b/tests/test_acp/test_acp_e2e_real.py @@ -0,0 +1,205 @@ +import asyncio +import os +import sys + +import pytest + +_REPO_ROOT = os.path.dirname( + 
os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_TEST_CONFIG = os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'test_openai_config.yaml') + +sys.path.insert(0, _REPO_ROOT) + +_HAS_API_KEY = bool(os.environ.get('OPENAI_API_KEY')) + +try: + from acp import spawn_agent_process, text_block + from acp.interfaces import Client + _HAS_ACP = True +except ImportError: + _HAS_ACP = False + + +class _StreamingTestClient(Client): + """Records all session updates from the ACP server.""" + + def __init__(self): + self.updates = [] + self.text_chunks = [] + self.thought_chunks = [] + self.tool_calls = [] + + async def session_update(self, session_id, update, **kwargs): + self.updates.append(update) + update_type = getattr(update, 'session_update', None) + if update_type == 'agent_message_chunk': + content = getattr(update, 'content', None) + if content: + text = getattr(content, 'text', None) or str(content) + self.text_chunks.append(text) + elif update_type == 'agent_thought_chunk': + content = getattr(update, 'content', None) + if content: + text = getattr(content, 'text', None) or str(content) + self.thought_chunks.append(text) + elif update_type == 'tool_call_start': + self.tool_calls.append(update) + + async def request_permission(self, options, session_id, tool_call, **kwargs): + allow = next( + (o for o in options if 'allow' in (getattr(o, 'kind', '') or '')), + None, + ) + if allow: + return {'outcome': {'outcome': 'selected', 'id': getattr(allow, 'option_id', 'allow_once')}} + return {'outcome': {'outcome': 'cancelled'}} + + @property + def full_text(self): + return ''.join(self.text_chunks) + + +@pytest.mark.skipif(not _HAS_ACP, reason='agent-client-protocol not installed') +@pytest.mark.skipif(not _HAS_API_KEY, reason='OPENAI_API_KEY not set') +@pytest.mark.skipif(not os.path.isfile(_TEST_CONFIG), reason='Test config not found') +class TestACPE2ERealLLM: + + @pytest.mark.asyncio + async def test_full_prompt_with_real_llm(self): + """Complete 
flow: initialize -> new_session -> prompt with real LLM.""" + client = _StreamingTestClient() + + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _TEST_CONFIG, + ) as (conn, proc): + init_resp = await conn.initialize(protocol_version=1) + assert init_resp.protocol_version == 1 + assert init_resp.agent_info.name == 'ms-agent' + + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + sid = session.session_id + assert sid.startswith('ses_') + + prompt_resp = await conn.prompt( + session_id=sid, + prompt=[text_block('What is 2 + 3? Answer with just the number.')], + ) + + assert prompt_resp.stop_reason in ('end_turn', 'max_turn_requests') + + full_text = client.full_text + print(f'\n[LLM Response] "{full_text}"') + assert len(full_text) > 0, 'Expected non-empty response' + assert '5' in full_text, f'Expected "5" in response, got: {full_text}' + + @pytest.mark.asyncio + async def test_streaming_produces_multiple_chunks(self): + """Verify that the streaming produces incremental updates.""" + client = _StreamingTestClient() + + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _TEST_CONFIG, + ) as (conn, proc): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + + await conn.prompt( + session_id=sid if (sid := session.session_id) else '', + prompt=[text_block( + 'List the first 5 prime numbers, one per line.' 
+ )], + ) + + assert len(client.updates) > 0, 'Expected at least one update' + print(f'\n[Streaming] Got {len(client.updates)} updates, ' + f'{len(client.text_chunks)} text chunks') + assert '2' in client.full_text + assert '3' in client.full_text + assert '5' in client.full_text + + @pytest.mark.asyncio + async def test_multi_turn_conversation(self): + """Verify multi-turn works: send two prompts in the same session.""" + client = _StreamingTestClient() + + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _TEST_CONFIG, + ) as (conn, proc): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + sid = session.session_id + + await conn.prompt( + session_id=sid, + prompt=[text_block('Remember the number 42.')], + ) + turn1_text = client.full_text + print(f'\n[Turn 1] "{turn1_text}"') + + client.text_chunks.clear() + client.updates.clear() + + await conn.prompt( + session_id=sid, + prompt=[text_block('What number did I just ask you to remember?')], + ) + turn2_text = client.full_text + print(f'[Turn 2] "{turn2_text}"') + assert '42' in turn2_text, f'Expected "42" in turn 2 response: {turn2_text}' + + @pytest.mark.asyncio + async def test_config_options_returned(self): + """Verify new_session returns config options with model selector.""" + client = _StreamingTestClient() + + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _TEST_CONFIG, + ) as (conn, proc): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + + assert session.config_options is not None + model_opt = next( + (o for o in session.config_options if o.id == 'model'), None) + assert model_opt is not None, 'Expected model config option' + assert model_opt.current_value == 'qwen-plus' + print(f'\n[Config] Model option: {model_opt.current_value}') + + @pytest.mark.asyncio + async 
def test_session_modes_returned(self): + """Verify new_session returns session modes.""" + client = _StreamingTestClient() + + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _TEST_CONFIG, + ) as (conn, proc): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + + assert session.modes is not None + assert session.modes.current_mode_id == 'agent' + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short', '-s']) diff --git a/tests/test_acp/test_acp_protocol.py b/tests/test_acp/test_acp_protocol.py new file mode 100644 index 000000000..ecbf8507c --- /dev/null +++ b/tests/test_acp/test_acp_protocol.py @@ -0,0 +1,655 @@ +"""Comprehensive ACP protocol correctness and rendering tests. + +Tests cover: + 1. Translator rendering: edit/read/execute tool calls produce correct ACP + content types (diff, locations, specialized starts) + 2. Multi-message delta tracking: all messages (including tool results) + are captured, not just the last one + 3. Plan update extraction from todo_write results + 4. Permission schema correctness + 5. Full message flow simulation (assistant -> tool_call -> tool_result) + 6. 
Error detection in tool results +""" + +import asyncio +import json +import uuid +from unittest.mock import AsyncMock, MagicMock + +import pytest + +from ms_agent.acp.translator import ( + ACPTranslator, + _build_title, + _extract_locations, + _infer_kind_from_args, + _parse_tool_args, + _try_parse_diff_from_result, +) +from ms_agent.llm.utils import Message + + +class TestToolArgParsing: + + def test_parse_valid_json(self): + args = _parse_tool_args('{"path": "/src/main.py", "content": "hello"}') + assert args['path'] == '/src/main.py' + assert args['content'] == 'hello' + + def test_parse_empty(self): + assert _parse_tool_args('') == {} + assert _parse_tool_args(None) == {} + + def test_parse_invalid_json(self): + assert _parse_tool_args('not json') == {} + + def test_parse_non_dict(self): + assert _parse_tool_args('"just a string"') == {} + + +class TestKindInference: + + def test_known_tools(self): + assert _infer_kind_from_args('code_executor', {}) == 'execute' + assert _infer_kind_from_args('web_search', {}) == 'search' + assert _infer_kind_from_args('file_write', {}) == 'edit' + assert _infer_kind_from_args('todo', {}) == 'think' + + def test_filesystem_method_dispatch(self): + assert _infer_kind_from_args( + 'filesystem', {'method': 'write_file'}) == 'edit' + assert _infer_kind_from_args( + 'filesystem', {'method': 'read_file'}) == 'read' + assert _infer_kind_from_args( + 'filesystem', {'method': 'list_files'}) == 'read' + assert _infer_kind_from_args( + 'filesystem', {'method': 'replace_file_contents'}) == 'edit' + + def test_name_heuristic(self): + assert _infer_kind_from_args('my_custom_write_tool', {}) == 'edit' + assert _infer_kind_from_args('grep_search_tool', {}) == 'search' + assert _infer_kind_from_args('run_shell_command', {}) == 'execute' + assert _infer_kind_from_args('delete_file', {}) == 'delete' + assert _infer_kind_from_args('rename_file', {}) == 'move' + assert _infer_kind_from_args('fetch_url', {}) == 'fetch' + assert 
_infer_kind_from_args('plan_steps', {}) == 'think' + + def test_unknown_tool(self): + assert _infer_kind_from_args('totally_unknown_xyz', {}) == 'other' + + +class TestTitleBuilding: + + def test_edit_with_path(self): + title = _build_title('file_write', {'path': '/src/main.py'}, 'edit') + assert title == 'Edit /src/main.py' + + def test_read_with_path(self): + title = _build_title('file_read', {'path': '/src/config.json'}, 'read') + assert title == 'Read /src/config.json' + + def test_execute_with_code(self): + title = _build_title('code_executor', + {'code': 'print("hello world")\nprint("done")'}, 'execute') + assert 'print("hello world")' in title + assert title.startswith('Run:') + + def test_search_with_query(self): + title = _build_title('web_search', {'query': 'quantum computing'}, 'search') + assert 'quantum computing' in title + + def test_fallback_to_tool_name(self): + title = _build_title('my_custom_tool', {}, 'other') + assert title == 'my_custom_tool' + + +class TestLocationExtraction: + + def test_extract_from_path(self): + locs = _extract_locations({'path': '/src/main.py'}) + assert locs is not None + assert len(locs) == 1 + assert locs[0].path == '/src/main.py' + assert locs[0].line is None + + def test_extract_with_line(self): + locs = _extract_locations({'path': '/src/main.py', 'line': 42}) + assert locs[0].line == 42 + + def test_extract_from_file_path(self): + locs = _extract_locations({'file_path': '/src/utils.py'}) + assert locs[0].path == '/src/utils.py' + + def test_no_path(self): + assert _extract_locations({}) is None + assert _extract_locations({'query': 'test'}) is None + + +class TestDiffExtraction: + + def test_write_file(self): + diff = _try_parse_diff_from_result( + 'file_write', + {'path': '/src/main.py', 'content': 'def hello(): pass'}, + 'Save file successfully', + ) + assert diff is not None + assert diff['path'] == '/src/main.py' + assert diff['new_text'] == 'def hello(): pass' + assert diff['old_text'] is None + + def 
test_replace_file_contents(self): + diff = _try_parse_diff_from_result( + 'replace_file_contents', + {'path': '/src/main.py', 'source': 'old code', 'target': 'new code'}, + 'Replaced successfully', + ) + assert diff is not None + assert diff['old_text'] == 'old code' + assert diff['new_text'] == 'new code' + + def test_file_operation_write(self): + diff = _try_parse_diff_from_result( + 'filesystem', + {'path': '/data.txt', 'method': 'write', 'content': 'data'}, + 'OK', + ) + assert diff is not None + assert diff['new_text'] == 'data' + + def test_no_path_returns_none(self): + diff = _try_parse_diff_from_result( + 'file_write', {}, 'OK') + assert diff is None + + def test_no_matching_method_returns_none(self): + diff = _try_parse_diff_from_result( + 'file_read', {'path': '/src/main.py', 'method': 'read'}, 'content') + assert diff is None + + +class TestTranslatorEdits: + """Test that file edit tool calls produce proper ACP start_edit_tool_call.""" + + def test_file_write_produces_edit_start(self): + t = ACPTranslator() + tc = { + 'id': 'tc_write_1', + 'tool_name': 'file_write', + 'arguments': json.dumps({ + 'path': '/src/main.py', + 'content': 'def hello(): pass', + }), + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + updates = t.messages_to_updates(msgs) + + tool_starts = [u for u in updates if hasattr(u, 'tool_call_id') + and hasattr(u, 'session_update') + and u.session_update == 'tool_call'] + assert len(tool_starts) == 1 + start = tool_starts[0] + assert start.tool_call_id == 'tc_write_1' + assert start.kind == 'edit' + assert 'Edit' in start.title + + def test_file_read_produces_read_start(self): + t = ACPTranslator() + tc = { + 'id': 'tc_read_1', + 'tool_name': 'file_read', + 'arguments': json.dumps({'path': '/src/config.json'}), + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + updates = t.messages_to_updates(msgs) + + tool_starts = [u for u in updates if hasattr(u, 'tool_call_id') + and getattr(u, 
'session_update', '') == 'tool_call'] + assert len(tool_starts) == 1 + start = tool_starts[0] + assert start.kind == 'read' + assert 'Read' in start.title + + def test_code_execute_produces_execute_start(self): + t = ACPTranslator() + tc = { + 'id': 'tc_exec_1', + 'tool_name': 'code_executor', + 'arguments': json.dumps({'code': 'print("hello")'}), + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + updates = t.messages_to_updates(msgs) + + tool_starts = [u for u in updates + if getattr(u, 'session_update', '') == 'tool_call'] + assert len(tool_starts) == 1 + start = tool_starts[0] + assert start.kind == 'execute' + assert 'Run:' in start.title or 'Execute' in start.title + + def test_search_tool_produces_search_start(self): + t = ACPTranslator() + tc = { + 'id': 'tc_search_1', + 'tool_name': 'web_search', + 'arguments': json.dumps({'query': 'quantum computing'}), + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + updates = t.messages_to_updates(msgs) + + tool_starts = [u for u in updates + if getattr(u, 'session_update', '') == 'tool_call'] + assert len(tool_starts) == 1 + start = tool_starts[0] + assert start.kind == 'search' + assert 'quantum computing' in start.title + + +class TestTranslatorToolResults: + """Test that tool results produce correct content types.""" + + def test_file_write_result_with_diff(self): + t = ACPTranslator() + tc = { + 'id': 'tc_w1', + 'tool_name': 'file_write', + 'arguments': json.dumps({ + 'path': '/src/main.py', 'content': 'def hello(): pass' + }), + } + msgs_assistant = [Message(role='assistant', content='', tool_calls=[tc])] + t.messages_to_updates(msgs_assistant) + + msgs_tool = [ + Message(role='assistant', content='', tool_calls=[tc]), + Message(role='tool', content='Save file successfully', + tool_call_id='tc_w1', name='file_write'), + ] + updates = t.messages_to_updates(msgs_tool) + + tool_updates = [u for u in updates + if getattr(u, 'session_update', '') == 'tool_call_update'] + 
assert len(tool_updates) == 1 + tu = tool_updates[0] + assert tu.status == 'completed' + assert tu.content is not None + has_diff = any( + getattr(c, 'type', '') == 'diff' for c in tu.content) + assert has_diff, 'File write should produce a diff content item' + + def test_error_result_sets_failed_status(self): + t = ACPTranslator() + tc = { + 'id': 'tc_err', + 'tool_name': 'code_executor', + 'arguments': '{"code": "raise Exception()"}', + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + t.messages_to_updates(msgs) + + msgs2 = [ + Message(role='assistant', content='', tool_calls=[tc]), + Message(role='tool', + content='Error: Traceback (most recent call last)...', + tool_call_id='tc_err', name='code_executor'), + ] + updates = t.messages_to_updates(msgs2) + tool_updates = [u for u in updates + if getattr(u, 'session_update', '') == 'tool_call_update'] + assert len(tool_updates) == 1 + assert tool_updates[0].status == 'failed' + + def test_success_result_sets_completed_status(self): + t = ACPTranslator() + tc = { + 'id': 'tc_ok', + 'tool_name': 'web_search', + 'arguments': '{"query": "test"}', + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + t.messages_to_updates(msgs) + + msgs2 = [ + Message(role='assistant', content='', tool_calls=[tc]), + Message(role='tool', + content='Found 5 results about test...', + tool_call_id='tc_ok', name='web_search'), + ] + updates = t.messages_to_updates(msgs2) + tool_updates = [u for u in updates + if getattr(u, 'session_update', '') == 'tool_call_update'] + assert len(tool_updates) == 1 + assert tool_updates[0].status == 'completed' + + +class TestMultiMessageTracking: + """Ensure the translator processes ALL new messages, not just the last.""" + + def test_processes_multiple_tool_results(self): + t = ACPTranslator() + tc1 = {'id': 'tc_a', 'tool_name': 'web_search', + 'arguments': '{"query":"a"}'} + tc2 = {'id': 'tc_b', 'tool_name': 'web_search', + 'arguments': '{"query":"b"}'} + + msgs = [ 
+ Message(role='assistant', content='Searching...', + tool_calls=[tc1, tc2]), + ] + u1 = t.messages_to_updates(msgs) + assert len([u for u in u1 + if getattr(u, 'session_update', '') == 'tool_call']) == 2 + + msgs.append(Message(role='tool', content='Result A', + tool_call_id='tc_a', name='web_search')) + msgs.append(Message(role='tool', content='Result B', + tool_call_id='tc_b', name='web_search')) + u2 = t.messages_to_updates(msgs) + + tool_updates = [u for u in u2 + if getattr(u, 'session_update', '') == 'tool_call_update'] + assert len(tool_updates) == 2, ( + f'Expected 2 tool_call_updates, got {len(tool_updates)}') + + def test_assistant_deltas_tracked_across_chunks(self): + t = ACPTranslator() + msgs = [Message(role='assistant', content='He')] + u1 = t.messages_to_updates(msgs) + text_updates_1 = [u for u in u1 + if getattr(u, 'session_update', '') == 'agent_message_chunk'] + assert len(text_updates_1) == 1 + + msgs[0] = Message(role='assistant', content='Hello world') + u2 = t.messages_to_updates(msgs) + text_updates_2 = [u for u in u2 + if getattr(u, 'session_update', '') == 'agent_message_chunk'] + assert len(text_updates_2) == 1 + + def test_no_duplicate_tool_call_starts(self): + t = ACPTranslator() + tc = {'id': 'tc_dup', 'tool_name': 'web_search', + 'arguments': '{"query": "test"}'} + + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + u1 = t.messages_to_updates(msgs) + starts_1 = [u for u in u1 + if getattr(u, 'session_update', '') == 'tool_call'] + assert len(starts_1) == 1 + + u2 = t.messages_to_updates(msgs) + starts_2 = [u for u in u2 + if getattr(u, 'session_update', '') == 'tool_call'] + assert len(starts_2) == 0 + + def test_no_duplicate_tool_completions(self): + t = ACPTranslator() + tc = {'id': 'tc_nodupe', 'tool_name': 'web_search', + 'arguments': '{}'} + msgs = [ + Message(role='assistant', content='', tool_calls=[tc]), + Message(role='tool', content='result', + tool_call_id='tc_nodupe', name='web_search'), + ] + u1 = 
t.messages_to_updates(msgs) + completions_1 = [u for u in u1 + if getattr(u, 'session_update', '') == 'tool_call_update'] + assert len(completions_1) == 1 + + u2 = t.messages_to_updates(msgs) + completions_2 = [u for u in u2 + if getattr(u, 'session_update', '') == 'tool_call_update'] + assert len(completions_2) == 0 + + +class TestPlanUpdates: + """Test plan extraction from todo tool results.""" + + def test_extract_from_todo_write_output(self): + from ms_agent.acp.server import MSAgentACPServer + + todo_result = json.dumps({ + 'status': 'ok', + 'plan_path': 'plan.json', + 'todos': [ + {'id': 'T1', 'content': 'Search papers', 'status': 'in_progress', + 'priority': 'high'}, + {'id': 'T2', 'content': 'Analyze results', 'status': 'pending', + 'priority': 'medium'}, + {'id': 'T3', 'content': 'Write report', 'status': 'pending', + 'priority': 'medium'}, + ], + }) + + session = MagicMock() + session.agent.runtime = None + session.messages = [ + Message(role='user', content='Research quantum computing'), + Message(role='assistant', content='', tool_calls=[ + {'id': 'tc_todo', 'tool_name': 'todo_write', 'arguments': '{}'} + ]), + Message(role='tool', content=todo_result, + tool_call_id='tc_todo', name='todo_write'), + ] + + translator = ACPTranslator() + plans = MSAgentACPServer._extract_plan_updates(session, translator) + assert len(plans) == 1 + plan = plans[0] + assert hasattr(plan, 'entries') + assert len(plan.entries) == 3 + assert plan.entries[0].content == 'Search papers' + + def test_no_plan_when_no_todos(self): + from ms_agent.acp.server import MSAgentACPServer + + session = MagicMock() + session.agent.runtime = None + session.messages = [ + Message(role='user', content='Hello'), + Message(role='assistant', content='Hi there!'), + ] + + translator = ACPTranslator() + plans = MSAgentACPServer._extract_plan_updates(session, translator) + assert plans == [] + + +class TestPermissionSchema: + """Test permission response format matches ACP schema.""" + + def 
test_auto_approve_returns_valid_response(self): + from ms_agent.acp.client import _CollectorClient + from acp.schema import ( + PermissionOption, RequestPermissionResponse, + AllowedOutcome, DeniedOutcome, + ) + + client = _CollectorClient(permission_policy='auto_approve') + options = [ + PermissionOption(option_id='allow_once', name='Allow', + kind='allow_once'), + PermissionOption(option_id='deny_once', name='Deny', + kind='reject_once'), + ] + tool_call = MagicMock() + + result = asyncio.get_event_loop().run_until_complete( + client.request_permission(options, 'ses_1', tool_call)) + + assert isinstance(result, RequestPermissionResponse) + assert isinstance(result.outcome, AllowedOutcome) + assert result.outcome.outcome == 'selected' + assert result.outcome.option_id == 'allow_once' + + def test_deny_returns_cancelled(self): + from ms_agent.acp.client import _CollectorClient + from acp.schema import ( + PermissionOption, RequestPermissionResponse, DeniedOutcome, + ) + + client = _CollectorClient(permission_policy='deny_all') + options = [ + PermissionOption(option_id='allow_once', name='Allow', + kind='allow_once'), + ] + tool_call = MagicMock() + + result = asyncio.get_event_loop().run_until_complete( + client.request_permission( + [o for o in options if 'deny' in o.kind], + 'ses_1', tool_call)) + + assert isinstance(result, RequestPermissionResponse) + assert isinstance(result.outcome, DeniedOutcome) + + +class TestFullMessageFlow: + """Simulate a complete message flow and verify all updates are correct.""" + + def test_full_agent_turn_with_tool_call(self): + t = ACPTranslator() + + all_updates = [] + + msgs = [Message(role='assistant', content='Let me ')] + all_updates.extend(t.messages_to_updates(msgs)) + + msgs[0] = Message(role='assistant', content='Let me search for that.') + all_updates.extend(t.messages_to_updates(msgs)) + + tc = {'id': 'tc_search', 'tool_name': 'web_search', + 'arguments': '{"query": "test topic"}'} + msgs[0] = Message( + 
role='assistant', + content='Let me search for that.', + tool_calls=[tc], + ) + all_updates.extend(t.messages_to_updates(msgs)) + + msgs.append(Message( + role='tool', content='Found 3 results...', + tool_call_id='tc_search', name='web_search', + )) + all_updates.extend(t.messages_to_updates(msgs)) + + msg_chunks = [u for u in all_updates + if getattr(u, 'session_update', '') == 'agent_message_chunk'] + tool_starts = [u for u in all_updates + if getattr(u, 'session_update', '') == 'tool_call'] + tool_updates = [u for u in all_updates + if getattr(u, 'session_update', '') == 'tool_call_update'] + + assert len(msg_chunks) >= 2, 'Should have streamed text incrementally' + assert len(tool_starts) == 1, 'Should have exactly one tool_call start' + assert len(tool_updates) == 1, 'Should have exactly one tool_call_update' + assert tool_starts[0].kind == 'search' + assert tool_updates[0].status == 'completed' + + def test_full_file_edit_flow(self): + t = ACPTranslator() + + tc = {'id': 'tc_edit', 'tool_name': 'file_write', + 'arguments': json.dumps({ + 'path': '/project/src/app.py', + 'content': 'def main():\n print("hello")\n', + })} + + msgs = [ + Message(role='assistant', + content='I will create the file for you.', + tool_calls=[tc]), + ] + u1 = t.messages_to_updates(msgs) + + msgs.append(Message( + role='tool', content='Save file successfully.', + tool_call_id='tc_edit', name='file_write', + )) + u2 = t.messages_to_updates(msgs) + + tool_starts = [u for u in u1 + if getattr(u, 'session_update', '') == 'tool_call'] + assert len(tool_starts) == 1 + assert tool_starts[0].kind == 'edit' + assert 'Edit' in tool_starts[0].title + + tool_updates = [u for u in u2 + if getattr(u, 'session_update', '') == 'tool_call_update'] + assert len(tool_updates) == 1 + assert tool_updates[0].status == 'completed' + + has_diff = any( + getattr(c, 'type', '') == 'diff' + for c in (tool_updates[0].content or [])) + assert has_diff, 'File write should produce diff content' + + +class 
TestErrorDetection: + + def test_error_markers(self): + assert ACPTranslator._looks_like_error( + 'Error: file not found') is True + assert ACPTranslator._looks_like_error( + 'Traceback (most recent call last)...') is True + assert ACPTranslator._looks_like_error( + '{"success": false, "error": "timeout"}') is True + assert ACPTranslator._looks_like_error( + 'Operation failed due to timeout') is True + + def test_success_messages(self): + assert ACPTranslator._looks_like_error( + 'Save file successfully.') is False + assert ACPTranslator._looks_like_error( + 'Found 5 results about quantum computing') is False + assert ACPTranslator._looks_like_error('') is False + assert ACPTranslator._looks_like_error( + '{"success": true, "output": "hello"}') is False + + +class TestResetTurn: + + def test_reset_clears_all_state(self): + t = ACPTranslator() + t._last_content_len = 100 + t._last_reasoning_len = 50 + t._emitted_tool_ids.add('tc_1') + t._completed_tool_ids.add('tc_1') + t._tool_args_cache['tc_1'] = {'path': '/test'} + t._tool_name_cache['tc_1'] = 'file_write' + t._last_seen_msg_count = 5 + + t.reset_turn() + + assert t._last_content_len == 0 + assert t._last_reasoning_len == 0 + assert len(t._emitted_tool_ids) == 0 + assert len(t._completed_tool_ids) == 0 + assert len(t._tool_args_cache) == 0 + assert len(t._tool_name_cache) == 0 + assert t._last_seen_msg_count == 0 + + +class TestPromptToMessages: + + def test_resource_block(self): + block = MagicMock() + block.type = 'resource' + block.resource = MagicMock() + block.resource.text = 'file contents here' + block.resource.uri = 'file:///src/main.py' + msgs = ACPTranslator.prompt_to_messages([block]) + assert 'file:///src/main.py' in msgs[0].content + assert 'file contents here' in msgs[0].content + + def test_resource_link_block(self): + block = MagicMock() + block.type = 'resource_link' + block.uri = 'file:///README.md' + msgs = ACPTranslator.prompt_to_messages([block]) + assert 'file:///README.md' in 
msgs[0].content + + def test_image_block(self): + block = MagicMock() + block.type = 'image' + msgs = ACPTranslator.prompt_to_messages([block]) + assert 'Image' in msgs[0].content diff --git a/tests/test_acp/test_acp_proxy.py b/tests/test_acp/test_acp_proxy.py new file mode 100644 index 000000000..1081dd7a9 --- /dev/null +++ b/tests/test_acp/test_acp_proxy.py @@ -0,0 +1,255 @@ +"""Unit tests for the ACP proxy module.""" + +import asyncio +import os +import tempfile + +import pytest +from unittest.mock import AsyncMock, MagicMock, patch + +from ms_agent.acp.proxy import ( + BackendConfig, + MSAgentACPProxy, + ProxyConfig, + _RelayClient, +) +from ms_agent.acp.proxy_session import ProxySessionEntry, ProxySessionStore +from ms_agent.acp.errors import ( + ConfigError, + MaxSessionsError, + SessionNotFoundError, +) + + +class TestProxyConfig: + + def test_from_yaml_basic(self, tmp_path): + cfg_file = tmp_path / 'proxy.yaml' + cfg_file.write_text(""" +proxy: + max_sessions: 4 + session_timeout: 1800 + default_backend: agent-a + +backends: + agent-a: + command: echo + args: [hello] + description: "Test agent A" + agent-b: + command: cat + description: "Test agent B" +""") + config = ProxyConfig.from_yaml(str(cfg_file)) + assert config.max_sessions == 4 + assert config.session_timeout == 1800 + assert config.default_backend == 'agent-a' + assert len(config.backends) == 2 + assert config.backends['agent-a'].command == 'echo' + assert config.backends['agent-a'].args == ['hello'] + assert config.backends['agent-b'].args == [] + + def test_from_yaml_default_backend_auto(self, tmp_path): + cfg_file = tmp_path / 'proxy.yaml' + cfg_file.write_text(""" +backends: + only-one: + command: true +""") + config = ProxyConfig.from_yaml(str(cfg_file)) + assert config.default_backend == 'only-one' + assert config.max_sessions == 8 + + def test_from_yaml_missing_command_skipped(self, tmp_path): + cfg_file = tmp_path / 'proxy.yaml' + cfg_file.write_text(""" +backends: + good: + command: echo 
+ bad: + description: "no command field" +""") + config = ProxyConfig.from_yaml(str(cfg_file)) + assert 'good' in config.backends + assert 'bad' not in config.backends + + def test_from_yaml_not_found(self): + with pytest.raises(ConfigError): + ProxyConfig.from_yaml('/nonexistent/proxy.yaml') + + def test_from_yaml_empty_file(self, tmp_path): + cfg_file = tmp_path / 'empty.yaml' + cfg_file.write_text('') + with pytest.raises(ConfigError, match='Invalid'): + ProxyConfig.from_yaml(str(cfg_file)) + + def test_from_yaml_with_env(self, tmp_path): + cfg_file = tmp_path / 'proxy.yaml' + cfg_file.write_text(""" +backends: + myagent: + command: agent + env: + MY_KEY: my_value +""") + config = ProxyConfig.from_yaml(str(cfg_file)) + assert config.backends['myagent'].env == {'MY_KEY': 'my_value'} + + +class TestProxySessionStore: + + def test_register_and_get(self): + store = ProxySessionStore(max_sessions=4) + entry = store.register( + backend_name='test', + backend_sid='bk_123', + backend_conn=MagicMock(), + backend_proc=MagicMock(), + ctx_manager=None, + cwd='/tmp', + ) + assert entry.id.startswith('pxy_') + assert entry.backend_name == 'test' + assert entry.backend_sid == 'bk_123' + + fetched = store.get(entry.id) + assert fetched is entry + + def test_get_not_found(self): + store = ProxySessionStore() + with pytest.raises(SessionNotFoundError): + store.get('nonexistent') + + def test_max_sessions_eviction(self): + store = ProxySessionStore(max_sessions=2) + e1 = store.register( + 'a', 'sid1', MagicMock(), MagicMock(), None, '/tmp') + e2 = store.register( + 'b', 'sid2', MagicMock(), MagicMock(), None, '/tmp') + e3 = store.register( + 'c', 'sid3', MagicMock(), MagicMock(), None, '/tmp') + assert len(store._sessions) == 2 + assert e3.id in store._sessions + + def test_max_sessions_all_running(self): + store = ProxySessionStore(max_sessions=1) + e1 = store.register( + 'a', 'sid1', MagicMock(), MagicMock(), None, '/tmp') + e1.is_running = True + with 
pytest.raises(MaxSessionsError): + store.register( + 'b', 'sid2', MagicMock(), MagicMock(), None, '/tmp') + + def test_list_sessions(self): + store = ProxySessionStore() + store.register('a', 'sid1', MagicMock(), MagicMock(), None, '/tmp') + store.register('b', 'sid2', MagicMock(), MagicMock(), None, '/work') + result = store.list_sessions() + assert len(result) == 2 + backends = {e['backend'] for e in result} + assert backends == {'a', 'b'} + + @pytest.mark.asyncio + async def test_remove(self): + store = ProxySessionStore() + entry = store.register( + 'a', 'sid1', MagicMock(), MagicMock(), None, '/tmp') + await store.remove(entry.id) + assert entry.id not in store._sessions + + @pytest.mark.asyncio + async def test_close_all(self): + store = ProxySessionStore() + store.register('a', 'sid1', MagicMock(), MagicMock(), None, '/tmp') + store.register('b', 'sid2', MagicMock(), MagicMock(), None, '/tmp') + await store.close_all() + assert len(store._sessions) == 0 + + def test_cancel(self): + store = ProxySessionStore() + entry = store.register( + 'a', 'sid1', MagicMock(), MagicMock(), None, '/tmp') + assert not entry.cancelled + entry.request_cancel() + assert entry.cancelled + + +class TestRelayClient: + + @pytest.mark.asyncio + async def test_session_update_relay(self): + mock_conn = AsyncMock() + relay = _RelayClient(mock_conn, 'pxy_abc') + update = MagicMock() + + await relay.session_update('backend_sid', update) + mock_conn.session_update.assert_awaited_once_with('pxy_abc', update) + + @pytest.mark.asyncio + async def test_request_permission_relay(self): + mock_conn = AsyncMock() + relay = _RelayClient(mock_conn, 'pxy_abc') + options = [MagicMock()] + tool_call = MagicMock() + + await relay.request_permission(options, 'backend_sid', tool_call) + mock_conn.request_permission.assert_awaited_once_with( + session_id='pxy_abc', + tool_call=tool_call, + options=options, + ) + + +class TestMSAgentACPProxy: + + def _make_config(self): + return ProxyConfig( + 
max_sessions=4, + session_timeout=3600, + default_backend='agent-a', + backends={ + 'agent-a': BackendConfig( + name='agent-a', + command='echo', + args=['acp'], + description='Test A', + ), + 'agent-b': BackendConfig( + name='agent-b', + command='cat', + description='Test B', + ), + }, + ) + + @pytest.mark.asyncio + async def test_initialize(self): + proxy = MSAgentACPProxy(self._make_config()) + resp = await proxy.initialize(protocol_version=1) + assert resp.agent_info.name == 'ms-agent-proxy' + assert resp.protocol_version >= 1 + + def test_build_config_options_multiple_backends(self): + proxy = MSAgentACPProxy(self._make_config()) + opts = proxy._build_config_options('agent-a') + assert opts is not None + assert len(opts) == 1 + assert opts[0].id == 'backend' + assert opts[0].current_value == 'agent-a' + + def test_build_config_options_single_backend(self): + config = ProxyConfig( + backends={ + 'only': BackendConfig(name='only', command='echo'), + }, + default_backend='only', + ) + proxy = MSAgentACPProxy(config) + opts = proxy._build_config_options('only') + assert opts is None + + @pytest.mark.asyncio + async def test_list_sessions_empty(self): + proxy = MSAgentACPProxy(self._make_config()) + resp = await proxy.list_sessions() + assert resp.sessions == [] diff --git a/tests/test_acp/test_acp_proxy_e2e.py b/tests/test_acp/test_acp_proxy_e2e.py new file mode 100644 index 000000000..74918c0e4 --- /dev/null +++ b/tests/test_acp/test_acp_proxy_e2e.py @@ -0,0 +1,291 @@ +"""End-to-end test for the ACP proxy with a real opencode backend. + +Spawns the ms-agent ACP proxy process over stdio, which in turn spawns +opencode as a backend ACP agent. Validates the full lifecycle: + initialize -> new_session -> prompt -> response streaming + +Requires ``opencode`` to be installed and available in PATH. 
+""" + +import asyncio +import os +import shutil +import sys +import tempfile +from typing import Any + +import pytest + +_SKIP_REASON = None +try: + from acp import spawn_agent_process, text_block + from acp.interfaces import Client + from acp.schema import ( + AllowedOutcome, + DeniedOutcome, + RequestPermissionResponse, + ) +except ImportError: + _SKIP_REASON = 'agent-client-protocol not installed' + +_REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_PROXY_CONFIG = os.path.join( + os.path.dirname(os.path.abspath(__file__)), 'proxy_opencode.yaml') + + +def _have_opencode() -> bool: + return shutil.which('opencode') is not None + + +class _TestClient(Client): + """ACP client that collects all streamed updates for assertion.""" + + def __init__(self): + self.updates: list = [] + self.text_chunks: list[str] = [] + self.thought_chunks: list[str] = [] + self.tool_calls: list = [] + self.tool_results: list = [] + self.permission_requests: list = [] + self.update_types: list[str] = [] + + async def session_update(self, session_id: str, update: Any, + **kwargs: Any) -> None: + self.updates.append(update) + update_type = getattr(update, 'session_update', None) + if update_type: + self.update_types.append(update_type) + + if update_type == 'agent_message_chunk': + content = getattr(update, 'content', None) + if content is not None: + text = getattr(content, 'text', None) or str(content) + self.text_chunks.append(text) + elif update_type == 'agent_thought_chunk': + content = getattr(update, 'content', None) + if content is not None: + text = getattr(content, 'text', None) or str(content) + self.thought_chunks.append(text) + elif update_type == 'tool_call_start': + self.tool_calls.append(update) + elif update_type == 'tool_call_update': + self.tool_results.append(update) + + async def request_permission(self, options: list, session_id: str, + tool_call: Any, **kwargs: Any) -> Any: + self.permission_requests.append({ + 
'session_id': session_id, + 'tool_call': tool_call, + }) + allow = next( + (o for o in options + if 'allow' in (getattr(o, 'kind', '') or '')), + None, + ) + if allow: + return RequestPermissionResponse( + outcome=AllowedOutcome( + outcome='selected', + option_id=getattr(allow, 'option_id', 'allow_once'), + ) + ) + return RequestPermissionResponse( + outcome=DeniedOutcome(outcome='cancelled') + ) + + @property + def collected_text(self) -> str: + return ''.join(self.text_chunks) + + @property + def collected_thought(self) -> str: + return ''.join(self.thought_chunks) + + +async def _spawn_proxy(client): + """Spawn the proxy process, yielding (conn, proc). + + Wraps ``spawn_agent_process`` so that the SDK's queue-closed race + during teardown does not fail the test. + """ + ctx = spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp-proxy', + '--config', _PROXY_CONFIG, + ) + conn, proc = await ctx.__aenter__() + try: + yield conn, proc + finally: + try: + await ctx.__aexit__(None, None, None) + except RuntimeError: + pass + + +@pytest.mark.skipif( + _SKIP_REASON is not None, reason=_SKIP_REASON or '') +@pytest.mark.skipif( + not _have_opencode(), reason='opencode not installed') +@pytest.mark.asyncio +async def test_proxy_initialize_and_new_session(): + """Proxy boots, negotiates protocol, and creates a session via opencode.""" + client = _TestClient() + async for conn, _proc in _spawn_proxy(client): + resp = await conn.initialize(protocol_version=1) + assert resp.protocol_version >= 1 + assert resp.agent_info is not None + assert resp.agent_info.name == 'ms-agent-proxy' + + session = await conn.new_session(cwd=os.getcwd(), mcp_servers=[]) + assert session.session_id + assert session.session_id.startswith('pxy_') + break + + +@pytest.mark.skipif( + _SKIP_REASON is not None, reason=_SKIP_REASON or '') +@pytest.mark.skipif( + not _have_opencode(), reason='opencode not installed') +@pytest.mark.asyncio +async def 
test_proxy_prompt_streaming(): + """Send a trivial prompt through the proxy to opencode and verify + that streamed text is relayed back.""" + client = _TestClient() + async for conn, _proc in _spawn_proxy(client): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd=os.getcwd(), mcp_servers=[]) + sid = session.session_id + + prompt_resp = await asyncio.wait_for( + conn.prompt( + session_id=sid, + prompt=[text_block( + 'Reply with exactly: PROXY_TEST_OK. ' + 'Do not include any other text.' + )], + ), + timeout=120, + ) + + assert prompt_resp is not None + assert len(client.updates) > 0, 'No session updates received' + assert len(client.text_chunks) > 0, 'No text chunks relayed' + assert len(client.collected_text) > 0, 'Collected text is empty' + print(f'\n--- Proxy E2E (simple) ---') + print(f'Updates received: {len(client.updates)}') + print(f'Text chunks: {len(client.text_chunks)}') + print(f'Collected text: {client.collected_text[:200]}') + print(f'Stop reason: {prompt_resp.stop_reason}') + break + + +@pytest.mark.skipif( + _SKIP_REASON is not None, reason=_SKIP_REASON or '') +@pytest.mark.skipif( + not _have_opencode(), reason='opencode not installed') +@pytest.mark.asyncio +async def test_proxy_real_task_with_tools(): + """Give opencode a real task that requires tool use (reading a file + and analyzing its content), then verify that tool_call events are + properly relayed through the proxy.""" + + with tempfile.TemporaryDirectory() as tmpdir: + target_file = os.path.join(tmpdir, 'data.csv') + with open(target_file, 'w') as f: + f.write('name,score\n') + f.write('Alice,92\n') + f.write('Bob,85\n') + f.write('Charlie,78\n') + f.write('Diana,95\n') + f.write('Eve,88\n') + + client = _TestClient() + async for conn, _proc in _spawn_proxy(client): + await conn.initialize(protocol_version=1) + session = await conn.new_session( + cwd=tmpdir, mcp_servers=[]) + sid = session.session_id + + prompt_resp = await asyncio.wait_for( + 
conn.prompt( + session_id=sid, + prompt=[text_block( + f'Read the file {target_file} and tell me: ' + f'who has the highest score? ' + f'Reply with just the name and score.' + )], + ), + timeout=120, + ) + + assert prompt_resp is not None + + unique_types = set(client.update_types) + print(f'\n--- Proxy E2E (real task) ---') + print(f'Total updates: {len(client.updates)}') + print(f'Update types seen: {sorted(unique_types)}') + print(f'Tool calls started: {len(client.tool_calls)}') + print(f'Tool results: {len(client.tool_results)}') + print(f'Permission requests: {len(client.permission_requests)}') + print(f'Thought chunks: {len(client.thought_chunks)}') + print(f'Text chunks: {len(client.text_chunks)}') + print(f'Response text: {client.collected_text[:300]}') + print(f'Stop reason: {prompt_resp.stop_reason}') + + print(f'\n--- Raw updates dump ---') + for i, u in enumerate(client.updates): + utype = getattr(u, 'session_update', '?') + content = getattr(u, 'content', None) + text = None + if content is not None: + text = getattr(content, 'text', None) + if text is None: + text = str(content)[:200] + print(f' [{i}] type={utype}') + print(f' content_text={text[:200] if text else None}') + for attr in ('tool_call_id', 'call_id', 'name', + 'tool_name', 'arguments'): + val = getattr(u, attr, None) + if val is not None: + print(f' {attr}={str(val)[:100]}') + + assert len(client.updates) > 0, 'No updates received' + + all_text = [] + for u in client.updates: + utype = getattr(u, 'session_update', None) + if utype in ('agent_message_chunk', 'user_message_chunk'): + content = getattr(u, 'content', None) + if content is not None: + t = getattr(content, 'text', None) or str(content) + all_text.append(t) + full_response = ''.join(all_text) + print(f'\nFull response (agent+user chunks): ' + f'{full_response[:300]}') + + assert len(full_response) > 0, ( + f'No response text in any chunk type. 
' + f'Types: {sorted(unique_types)}' + ) + + response_lower = full_response.lower() + assert 'diana' in response_lower or '95' in response_lower, ( + f'Expected Diana/95 in response, got: ' + f'{full_response[:200]}' + ) + + has_tool_activity = ( + len(client.tool_calls) > 0 + or len(client.tool_results) > 0 + or 'tool_call' in unique_types + or 'tool_call_update' in unique_types + ) + assert has_tool_activity, ( + f'Expected tool use for file reading, but got none. ' + f'Update types: {sorted(unique_types)}' + ) + break diff --git a/tests/test_acp/test_acp_real.py b/tests/test_acp/test_acp_real.py new file mode 100644 index 000000000..27325af19 --- /dev/null +++ b/tests/test_acp/test_acp_real.py @@ -0,0 +1,821 @@ +"""Real integration tests for ACP modules — NO mocks. + +Validates every ACP component using real objects, real SDK types, +real agent configs, and real subprocess spawning. +""" + +import asyncio +import json +import os +import sys +import tempfile +import time + +import pytest + +_REPO_ROOT = os.path.dirname( + os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) +_DEFAULT_CONFIG = os.path.join(_REPO_ROOT, 'ms_agent', 'agent', 'agent.yaml') +_ACP_TEST_CONFIG = os.environ.get('ACP_TEST_CONFIG', _DEFAULT_CONFIG) + +sys.path.insert(0, _REPO_ROOT) + + +# ====================================================================== +# 1. 
Error mapping — real exception objects, real RequestError +# ====================================================================== + +class TestErrorMappingReal: + + def test_all_custom_errors_have_correct_codes(self): + from ms_agent.acp.errors import ( + SessionNotFoundError, ResourceNotFoundError, LLMError, + RateLimitError, ConfigError, MaxSessionsError, + ) + cases = [ + (SessionNotFoundError('ses_x'), -32001), + (ResourceNotFoundError('/tmp/x'), -32002), + (LLMError('timeout'), -32003), + (RateLimitError('too fast'), -32004), + (ConfigError('bad yaml'), -32005), + (MaxSessionsError(4), -32006), + ] + for err, expected_code in cases: + assert err.code == expected_code, f'{type(err).__name__} code mismatch' + assert err.message, f'{type(err).__name__} should have a message' + assert isinstance(err.data, dict), f'{type(err).__name__} data should be dict' + + def test_wrap_agent_error_produces_real_request_error(self): + from acp import RequestError + from ms_agent.acp.errors import wrap_agent_error, LLMError + + rpc = wrap_agent_error(LLMError('model down')) + assert isinstance(rpc, RequestError) + assert rpc.code == -32003 + + def test_wrap_python_stdlib_exceptions(self): + from acp import RequestError + from ms_agent.acp.errors import wrap_agent_error + + for exc, expected_code in [ + (FileNotFoundError('/no'), -32002), + (PermissionError('denied'), -32000), + (TimeoutError('slow'), -32004), + (ValueError('bad'), -32602), + ]: + rpc = wrap_agent_error(exc) + assert isinstance(rpc, RequestError) + assert rpc.code == expected_code, f'{type(exc).__name__} mapped to wrong code' + + def test_wrap_unknown_exception_fallback(self): + from acp import RequestError + from ms_agent.acp.errors import wrap_agent_error + + rpc = wrap_agent_error(RuntimeError('surprise')) + assert isinstance(rpc, RequestError) + assert rpc.code == -32603 + + def test_wrap_already_request_error_passthrough(self): + from acp import RequestError + from ms_agent.acp.errors import 
wrap_agent_error + + original = RequestError(-32999, 'custom') + result = wrap_agent_error(original) + assert result is original + + +# ====================================================================== +# 2. Translator — real ACP SDK types, real Message objects +# ====================================================================== + +class TestTranslatorReal: + + def test_prompt_to_messages_with_real_text_block(self): + from acp import text_block + from ms_agent.acp.translator import ACPTranslator + from ms_agent.llm.utils import Message + + block = text_block('What is 2+2?') + msgs = ACPTranslator.prompt_to_messages([block]) + assert len(msgs) == 1 + assert msgs[0].role == 'user' + assert '2+2' in msgs[0].content + + def test_prompt_to_messages_appends_to_history(self): + from acp import text_block + from ms_agent.acp.translator import ACPTranslator + from ms_agent.llm.utils import Message + + history = [Message(role='system', content='You are helpful.')] + block = text_block('Hello') + result = ACPTranslator.prompt_to_messages([block], history) + assert len(result) == 2 + assert result[0].role == 'system' + assert result[1].role == 'user' + assert result is history + + def test_delta_tracking_content(self): + from ms_agent.acp.translator import ACPTranslator + from ms_agent.llm.utils import Message + + t = ACPTranslator() + + msgs1 = [Message(role='assistant', content='Hel')] + u1 = t.messages_to_updates(msgs1) + assert len(u1) == 1 + assert t._last_content_len == 3 + + msgs2 = [Message(role='assistant', content='Hello world')] + u2 = t.messages_to_updates(msgs2) + assert len(u2) == 1 + assert t._last_content_len == 11 + + def test_delta_tracking_reasoning(self): + from ms_agent.acp.translator import ACPTranslator + from ms_agent.llm.utils import Message + + t = ACPTranslator() + msgs = [Message(role='assistant', content='', reasoning_content='thinking step 1')] + u = t.messages_to_updates(msgs) + assert len(u) >= 1 + assert t._last_reasoning_len == 
len('thinking step 1') + + def test_tool_call_emitted_once(self): + from ms_agent.acp.translator import ACPTranslator + from ms_agent.llm.utils import Message + + t = ACPTranslator() + tc = {'id': 'tc_abc', 'type': 'function', 'tool_name': 'web_search', + 'arguments': '{"q": "test"}'} + msg = Message(role='assistant', content='', tool_calls=[tc]) + + u1 = t.messages_to_updates([msg]) + assert any(getattr(u, 'tool_call_id', None) == 'tc_abc' for u in u1) + + u2 = t.messages_to_updates([msg]) + assert not any(getattr(u, 'tool_call_id', None) == 'tc_abc' for u in u2) + + def test_tool_result_translation(self): + from ms_agent.acp.translator import ACPTranslator + from ms_agent.llm.utils import Message + + t = ACPTranslator() + t._emitted_tool_ids.add('tc_123') + msg = Message(role='tool', content='result data', tool_call_id='tc_123') + updates = t.messages_to_updates([msg]) + assert len(updates) == 1 + + def test_reset_turn_clears_state(self): + from ms_agent.acp.translator import ACPTranslator + + t = ACPTranslator() + t._last_content_len = 50 + t._last_reasoning_len = 30 + t._emitted_tool_ids.add('tc_1') + t._completed_tool_ids.add('tc_1') + t.reset_turn() + assert t._last_content_len == 0 + assert t._last_reasoning_len == 0 + assert len(t._emitted_tool_ids) == 0 + assert len(t._completed_tool_ids) == 0 + + def test_build_plan_update_produces_real_acp_type(self): + from ms_agent.acp.translator import ACPTranslator + from acp.schema import AgentPlanUpdate + + steps = [ + {'description': 'Step 1', 'status': 'completed', 'priority': 'high'}, + {'description': 'Step 2', 'status': 'in_progress'}, + {'description': 'Step 3', 'status': 'pending'}, + ] + update = ACPTranslator.build_plan_update(steps) + assert isinstance(update, AgentPlanUpdate) + assert len(update.entries) == 3 + + def test_map_stop_reason_end_turn(self): + from ms_agent.acp.translator import ACPTranslator + from types import SimpleNamespace + + session = SimpleNamespace( + cancelled=False, + 
agent=SimpleNamespace( + runtime=SimpleNamespace(round=3), + max_chat_round=20, + ), + ) + assert ACPTranslator.map_stop_reason(session) == 'end_turn' + + def test_map_stop_reason_max_rounds(self): + from ms_agent.acp.translator import ACPTranslator + from types import SimpleNamespace + + session = SimpleNamespace( + cancelled=False, + agent=SimpleNamespace( + runtime=SimpleNamespace(round=21), + max_chat_round=20, + ), + ) + assert ACPTranslator.map_stop_reason(session) == 'max_turn_requests' + + def test_map_stop_reason_cancelled(self): + from ms_agent.acp.translator import ACPTranslator + from types import SimpleNamespace + + session = SimpleNamespace( + cancelled=True, + agent=SimpleNamespace(runtime=SimpleNamespace(round=1), max_chat_round=20), + ) + assert ACPTranslator.map_stop_reason(session) == 'cancelled' + + def test_tool_kind_mapping(self): + from ms_agent.acp.translator import _TOOL_KIND_MAP + assert _TOOL_KIND_MAP['code_executor'] == 'execute' + assert _TOOL_KIND_MAP['web_search'] == 'search' + assert _TOOL_KIND_MAP['file_read'] == 'read' + assert _TOOL_KIND_MAP['file_write'] == 'edit' + assert _TOOL_KIND_MAP['todo'] == 'think' + + +# ====================================================================== +# 3. 
Config module — real OmegaConf objects +# ====================================================================== + +class TestConfigReal: + + def test_build_config_options_with_model(self): + from omegaconf import OmegaConf + from ms_agent.acp.config import build_config_options + from acp.schema import SessionConfigOptionSelect + + cfg = OmegaConf.create({'llm': {'model': 'qwen-max'}}) + opts = build_config_options(cfg) + assert opts is not None + assert len(opts) == 1 + assert isinstance(opts[0], SessionConfigOptionSelect) + assert opts[0].id == 'model' + assert opts[0].current_value == 'qwen-max' + + def test_build_config_options_with_available_models(self): + from omegaconf import OmegaConf + from ms_agent.acp.config import build_config_options + + cfg = OmegaConf.create({'llm': {'model': 'gpt-4o'}}) + opts = build_config_options(cfg, available_models=['gpt-4o', 'gpt-4o-mini', 'o1']) + assert opts is not None + assert len(opts[0].options) == 3 + + def test_build_config_options_no_model(self): + from omegaconf import OmegaConf + from ms_agent.acp.config import build_config_options + + cfg = OmegaConf.create({'other': 'value'}) + assert build_config_options(cfg) is None + + def test_apply_config_option_model(self): + from omegaconf import OmegaConf + from ms_agent.acp.config import apply_config_option + + cfg = OmegaConf.create({'llm': {'model': 'qwen-max'}}) + assert apply_config_option(cfg, 'model', 'gpt-4o') is True + assert cfg.llm.model == 'gpt-4o' + + def test_apply_config_option_unknown_id(self): + from omegaconf import OmegaConf + from ms_agent.acp.config import apply_config_option + + cfg = OmegaConf.create({'llm': {'model': 'qwen-max'}}) + assert apply_config_option(cfg, 'temperature', '0.5') is False + + def test_build_session_modes(self): + from ms_agent.acp.config import build_session_modes + from acp.schema import SessionModeState + + modes = build_session_modes() + assert isinstance(modes, SessionModeState) + assert modes.current_mode_id == 'agent' 
+ assert len(modes.available_modes) >= 1 + + +# ====================================================================== +# 4. Permissions — real PermissionPolicy objects +# ====================================================================== + +class TestPermissionsReal: + + def test_auto_approve_flow(self): + from ms_agent.acp.permissions import PermissionPolicy + + p = PermissionPolicy('auto_approve') + assert p.should_ask('code_executor') is False + assert p.auto_decision('code_executor') == 'allow_once' + p.record_choice('code_executor', True) + assert p.auto_decision('code_executor') == 'allow_once' + + def test_always_ask_flow(self): + from ms_agent.acp.permissions import PermissionPolicy + + p = PermissionPolicy('always_ask') + assert p.should_ask('web_search') is True + assert p.auto_decision('web_search') is None + p.record_choice('web_search', True) + assert p.should_ask('web_search') is True + + def test_remember_choice_allow(self): + from ms_agent.acp.permissions import PermissionPolicy + + p = PermissionPolicy('remember_choice') + assert p.should_ask('read_file') is True + p.record_choice('read_file', True) + assert p.should_ask('read_file') is False + dec = p.auto_decision('read_file') + assert dec is not None and 'allow' in dec + + def test_remember_choice_deny(self): + from ms_agent.acp.permissions import PermissionPolicy + + p = PermissionPolicy('remember_choice') + p.record_choice('dangerous', False) + assert p.should_ask('dangerous') is False + dec = p.auto_decision('dangerous') + assert dec is not None and 'deny' in dec + + def test_reset_clears_remembered(self): + from ms_agent.acp.permissions import PermissionPolicy + + p = PermissionPolicy('remember_choice') + p.record_choice('tool_a', True) + p.record_choice('tool_b', False) + p.reset() + assert p.should_ask('tool_a') is True + assert p.should_ask('tool_b') is True + + @pytest.mark.asyncio + async def test_request_tool_permission_auto(self): + from ms_agent.acp.permissions import 
PermissionPolicy, request_tool_permission + + policy = PermissionPolicy('auto_approve') + result = await request_tool_permission( + connection=None, + session_id='ses_test', + tool_call_id='tc_1', + tool_name='web_search', + policy=policy, + ) + assert result is True + + +# ====================================================================== +# 5. Registry — real manifest generation + file write +# ====================================================================== + +class TestRegistryReal: + + def test_manifest_structure(self): + from ms_agent.acp.registry import generate_agent_manifest + + m = generate_agent_manifest( + config_path='/path/to/researcher.yaml', + output_path=None, + ) + assert m['name'] == 'ms-agent' + assert m['protocol'] == 'acp' + assert m['protocolVersion'] == 1 + assert m['transport']['type'] == 'stdio' + assert m['transport']['command'] == 'ms-agent' + assert '--config' in m['transport']['args'] + assert '/path/to/researcher.yaml' in m['transport']['args'] + assert 'capabilities' in m + + def test_manifest_without_config(self): + from ms_agent.acp.registry import generate_agent_manifest + + m = generate_agent_manifest(output_path=None) + assert m['transport']['args'] == ['acp'] + + def test_manifest_write_to_file(self): + from ms_agent.acp.registry import generate_agent_manifest + + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + path = f.name + + try: + m = generate_agent_manifest( + config_path='/test/config.yaml', + output_path=path, + title='Test Agent', + ) + with open(path) as f: + written = json.load(f) + assert written['title'] == 'Test Agent' + assert written['name'] == 'ms-agent' + finally: + os.unlink(path) + + def test_manifest_custom_fields(self): + from ms_agent.acp.registry import generate_agent_manifest + + m = generate_agent_manifest( + output_path=None, + version='1.2.3', + title='Custom Agent', + description='Custom description', + ) + assert m['version'] == '1.2.3' + assert m['title'] == 
'Custom Agent' + assert m['description'] == 'Custom description' + + +# ====================================================================== +# 6. SessionStore — real store object (no agent creation, tests structure) +# ====================================================================== + +class TestSessionStoreReal: + + def test_empty_store(self): + from ms_agent.acp.session_store import ACPSessionStore + + store = ACPSessionStore(max_sessions=4) + assert store.list_sessions() == [] + assert store.max_sessions == 4 + + def test_get_nonexistent(self): + from ms_agent.acp.session_store import ACPSessionStore + from ms_agent.acp.errors import SessionNotFoundError + + store = ACPSessionStore() + with pytest.raises(SessionNotFoundError) as exc_info: + store.get('ses_does_not_exist') + assert exc_info.value.code == -32001 + + @pytest.mark.asyncio + async def test_create_with_invalid_config(self): + from ms_agent.acp.session_store import ACPSessionStore + from ms_agent.acp.errors import ConfigError + + store = ACPSessionStore() + with pytest.raises(ConfigError): + await store.create(config_path='/nonexistent/agent.yaml', cwd='/tmp') + + @pytest.mark.asyncio + async def test_create_real_session(self, monkeypatch): + """Create a real session with the default agent config.""" + if not os.path.isfile(_ACP_TEST_CONFIG): + pytest.skip('No agent config available') + + monkeypatch.setattr(sys, 'argv', ['test']) + + from ms_agent.acp.session_store import ACPSessionStore, ACPSessionEntry + + store = ACPSessionStore() + try: + entry = await store.create(config_path=_ACP_TEST_CONFIG, cwd='/tmp') + assert isinstance(entry, ACPSessionEntry) + assert entry.id.startswith('ses_') + assert entry.agent is not None + assert entry.cwd == '/tmp' + assert len(store.list_sessions()) == 1 + + retrieved = store.get(entry.id) + assert retrieved.id == entry.id + finally: + await store.close_all() + + @pytest.mark.asyncio + async def test_lru_eviction(self, monkeypatch): + """Test that LRU 
eviction works when max_sessions is reached.""" + if not os.path.isfile(_ACP_TEST_CONFIG): + pytest.skip('No agent config available') + + monkeypatch.setattr(sys, 'argv', ['test']) + + from ms_agent.acp.session_store import ACPSessionStore + + store = ACPSessionStore(max_sessions=2) + try: + s1 = await store.create(config_path=_ACP_TEST_CONFIG, cwd='/tmp') + s2 = await store.create(config_path=_ACP_TEST_CONFIG, cwd='/tmp') + assert len(store.list_sessions()) == 2 + + s3 = await store.create(config_path=_ACP_TEST_CONFIG, cwd='/tmp') + assert len(store.list_sessions()) == 2 + assert s3.id in [s['session_id'] for s in store.list_sessions()] + finally: + await store.close_all() + + +# ====================================================================== +# 7. ACPAgentTool — real tool object creation +# ====================================================================== + +class TestACPAgentToolReal: + + def test_from_config_returns_none_without_acp_agents(self): + from omegaconf import OmegaConf + from ms_agent.tools.acp_agent_tool import ACPAgentTool + + cfg = OmegaConf.create({'llm': {'model': 'test'}}) + assert ACPAgentTool.from_config(cfg) is None + + def test_from_config_creates_tool(self): + from omegaconf import OmegaConf + from ms_agent.tools.acp_agent_tool import ACPAgentTool + + cfg = OmegaConf.create({ + 'llm': {'model': 'test'}, + 'acp_agents': { + 'codex': { + 'command': 'codex', + 'args': ['mcp-server'], + 'description': 'Codex coding agent', + }, + }, + }) + tool = ACPAgentTool.from_config(cfg) + assert tool is not None + assert 'codex' in tool._client_manager.list_agents() + + @pytest.mark.asyncio + async def test_get_tools_structure(self): + from omegaconf import OmegaConf + from ms_agent.tools.acp_agent_tool import ACPAgentTool + + cfg = OmegaConf.create({ + 'llm': {'model': 'test'}, + 'acp_agents': { + 'agent_a': { + 'command': 'echo', + 'args': [], + 'description': 'Agent A desc', + }, + 'agent_b': { + 'command': 'echo', + 'args': [], + 
'description': 'Agent B desc', + }, + }, + }) + tool = ACPAgentTool(cfg, acp_agents_config=OmegaConf.to_container(cfg.acp_agents, resolve=True)) + tools = await tool._get_tools_inner() + assert 'acp_agent_a' in tools + assert 'acp_agent_b' in tools + assert tools['acp_agent_a'][0]['parameters']['required'] == ['query'] + + +# ====================================================================== +# 8. ACP Client Manager — real object, config-based +# ====================================================================== + +class TestACPClientManagerReal: + + def test_empty_manager(self): + from ms_agent.acp.client import ACPClientManager + + mgr = ACPClientManager() + assert mgr.list_agents() == [] + + def test_configured_manager(self): + from ms_agent.acp.client import ACPClientManager + + cfg = { + 'codex': { + 'command': 'codex', + 'args': ['mcp-server'], + 'description': 'Codex', + 'permission_policy': 'auto_approve', + }, + } + mgr = ACPClientManager(cfg) + assert 'codex' in mgr.list_agents() + + @pytest.mark.asyncio + async def test_call_unconfigured_returns_error(self): + from ms_agent.acp.client import ACPClientManager + + mgr = ACPClientManager() + result = await mgr.call_agent('nonexistent', 'hello') + assert 'not configured' in result + + +# ====================================================================== +# 9. 
CLI — acp-registry command real execution +# ====================================================================== + +class TestCLIReal: + + def test_acp_registry_generates_json(self): + """Run 'ms-agent acp-registry' as a subprocess and verify output.""" + import subprocess + + with tempfile.NamedTemporaryFile(suffix='.json', delete=False) as f: + outpath = f.name + + try: + result = subprocess.run( + [sys.executable, '-m', 'ms_agent.cli.cli', + 'acp-registry', '--output', outpath, '--title', 'TestAgent'], + capture_output=True, text=True, timeout=30, + cwd=_REPO_ROOT, + ) + assert result.returncode == 0, f'stderr: {result.stderr}' + stdout_json = json.loads(result.stdout) + assert stdout_json['name'] == 'ms-agent' + assert stdout_json['title'] == 'TestAgent' + + with open(outpath) as f: + file_json = json.load(f) + assert file_json['protocol'] == 'acp' + finally: + if os.path.exists(outpath): + os.unlink(outpath) + + def test_acp_registry_with_config(self): + import subprocess + + result = subprocess.run( + [sys.executable, '-m', 'ms_agent.cli.cli', + 'acp-registry', '--config', _ACP_TEST_CONFIG, '--output', ''], + capture_output=True, text=True, timeout=30, + cwd=_REPO_ROOT, + ) + stdout_json = json.loads(result.stdout) + assert '--config' in stdout_json['transport']['args'] + + +# ====================================================================== +# 10. 
ACP Server — real spawn_agent_process (initialize + new_session) +# ====================================================================== + +@pytest.mark.skipif( + not os.path.isfile(_ACP_TEST_CONFIG), + reason='No agent config found', +) +class TestACPServerReal: + + @pytest.mark.asyncio + async def test_initialize_and_new_session(self): + """Spawn a real ACP server subprocess and test initialize + new_session.""" + from acp import spawn_agent_process, text_block + from acp.interfaces import Client + + class TestClient(Client): + def __init__(self): + self.updates = [] + + async def session_update(self, session_id, update, **kwargs): + self.updates.append(update) + + async def request_permission(self, options, session_id, tool_call, **kwargs): + allow = next( + (o for o in options if 'allow' in (getattr(o, 'kind', '') or '')), + None, + ) + if allow: + return {'outcome': {'outcome': 'selected', 'id': getattr(allow, 'option_id', 'allow_once')}} + return {'outcome': {'outcome': 'cancelled'}} + + client = TestClient() + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _ACP_TEST_CONFIG, + ) as (conn, proc): + init_resp = await conn.initialize(protocol_version=1) + assert init_resp.protocol_version == 1 + assert init_resp.agent_info is not None + assert init_resp.agent_info.name == 'ms-agent' + assert init_resp.agent_info.version == '0.1.0' + + caps = init_resp.agent_capabilities + assert caps is not None + + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + assert session.session_id + assert session.session_id.startswith('ses_') + + @pytest.mark.asyncio + async def test_list_sessions_after_create(self): + """After creating a session, list_sessions should return it.""" + from acp import spawn_agent_process + from acp.interfaces import Client + + class TestClient(Client): + async def session_update(self, session_id, update, **kwargs): + pass + async def request_permission(self, options, 
session_id, tool_call, **kwargs): + return {'outcome': {'outcome': 'cancelled'}} + + client = TestClient() + async with spawn_agent_process( + client, + sys.executable, + '-m', 'ms_agent.cli.cli', + 'acp', + '--config', _ACP_TEST_CONFIG, + ) as (conn, proc): + await conn.initialize(protocol_version=1) + session = await conn.new_session(cwd='/tmp', mcp_servers=[]) + + sessions = await conn.list_sessions() + session_ids = [s.session_id for s in sessions.sessions] + assert session.session_id in session_ids + + +# ====================================================================== +# 11. CollectorClient — real _CollectorClient object behavior +# ====================================================================== + +class TestCollectorClientReal: + + @pytest.mark.asyncio + async def test_collect_text_updates(self): + from ms_agent.acp.client import _CollectorClient + from types import SimpleNamespace + + client = _CollectorClient() + sid = 'ses_test' + + update1 = SimpleNamespace(session_update='agent_message_chunk', + content=SimpleNamespace(text='Hello ')) + update2 = SimpleNamespace(session_update='agent_message_chunk', + content=SimpleNamespace(text='World')) + + await client.session_update(sid, update1) + await client.session_update(sid, update2) + + assert client.get_output(sid) == 'Hello World' + + @pytest.mark.asyncio + async def test_ignores_non_text_updates(self): + from ms_agent.acp.client import _CollectorClient + from types import SimpleNamespace + + client = _CollectorClient() + sid = 'ses_test' + + update = SimpleNamespace(session_update='tool_call_start', content=None) + await client.session_update(sid, update) + assert client.get_output(sid) == '' + + @pytest.mark.asyncio + async def test_auto_approve_permission(self): + from ms_agent.acp.client import _CollectorClient + from types import SimpleNamespace + + client = _CollectorClient(permission_policy='auto_approve') + options = [ + SimpleNamespace(kind='allow_once', option_id='allow_once'), + 
SimpleNamespace(kind='deny_once', option_id='deny_once'), + ] + result = await client.request_permission(options, 'ses_x', None) + assert result.outcome.outcome == 'selected' + assert result.outcome.option_id == 'allow_once' + + def test_clear(self): + from ms_agent.acp.client import _CollectorClient + + client = _CollectorClient() + client.collected['ses_1'] = ['data'] + client.clear('ses_1') + assert client.get_output('ses_1') == '' + + +# ====================================================================== +# 12. HTTP Adapter — DummyConn queue behavior (no server start needed) +# ====================================================================== + +class TestHTTPAdapterComponents: + + @pytest.mark.asyncio + async def test_dummy_conn_queue(self): + from ms_agent.acp.http_adapter import _DummyConn + from types import SimpleNamespace + + conn = _DummyConn() + q = conn.get_queue('ses_1') + assert q.empty() + + update = SimpleNamespace(session_update='agent_message_chunk') + update.model_dump = lambda by_alias=False: {'session_update': 'agent_message_chunk'} + await conn.session_update('ses_1', update) + assert not q.empty() + data = await q.get() + assert data['session_update'] == 'agent_message_chunk' + + @pytest.mark.asyncio + async def test_dummy_conn_auto_approve(self): + from ms_agent.acp.http_adapter import _DummyConn + from types import SimpleNamespace + + conn = _DummyConn() + options = [ + SimpleNamespace(kind='allow_once', option_id='allow_once'), + ] + result = await conn.request_permission('ses_1', None, options) + assert result.outcome['outcome'] == 'selected' + + +if __name__ == '__main__': + pytest.main([__file__, '-v', '--tb=short', '-x']) diff --git a/tests/test_acp/test_acp_unit.py b/tests/test_acp/test_acp_unit.py new file mode 100644 index 000000000..f4838bc48 --- /dev/null +++ b/tests/test_acp/test_acp_unit.py @@ -0,0 +1,384 @@ +"""Unit tests for ACP components (no external processes needed).""" + +import asyncio +import pytest +from 
unittest.mock import AsyncMock, MagicMock, patch + +from ms_agent.acp.errors import ( + ACPError, + ConfigError, + LLMError, + MaxSessionsError, + RateLimitError, + SessionNotFoundError, + wrap_agent_error, +) +from ms_agent.acp.translator import ACPTranslator +from ms_agent.acp.session_store import ACPSessionStore +from ms_agent.llm.utils import Message + + +# ====================================================================== +# Error mapping tests +# ====================================================================== + +class TestErrorMapping: + + def test_session_not_found(self): + err = SessionNotFoundError('ses_abc') + assert err.code == -32001 + assert 'ses_abc' in str(err.data) + + def test_wrap_acp_error(self): + err = LLMError('timeout') + rpc_err = wrap_agent_error(err) + assert rpc_err.code == -32003 + + def test_wrap_file_not_found(self): + rpc_err = wrap_agent_error(FileNotFoundError('/path')) + assert rpc_err.code == -32002 + + def test_wrap_value_error(self): + rpc_err = wrap_agent_error(ValueError('bad param')) + assert rpc_err.code == -32602 + + def test_wrap_unknown_error(self): + rpc_err = wrap_agent_error(RuntimeError('unexpected')) + assert rpc_err.code == -32603 + + def test_max_sessions_error(self): + err = MaxSessionsError(8) + assert err.code == -32006 + assert err.data['max'] == 8 + + +# ====================================================================== +# Translator tests +# ====================================================================== + +class TestTranslator: + + def test_prompt_to_messages_text(self): + block = MagicMock() + block.type = 'text' + block.text = 'Hello world' + msgs = ACPTranslator.prompt_to_messages([block]) + assert len(msgs) == 1 + assert msgs[0].role == 'user' + assert msgs[0].content == 'Hello world' + + def test_prompt_to_messages_multiple_blocks(self): + b1 = MagicMock(type='text', text='Part A') + b2 = MagicMock(type='text', text='Part B') + msgs = ACPTranslator.prompt_to_messages([b1, b2]) + 
assert len(msgs) == 1 + assert 'Part A' in msgs[0].content + assert 'Part B' in msgs[0].content + + def test_prompt_appends_to_existing(self): + existing = [Message(role='system', content='You are helpful')] + block = MagicMock(type='text', text='Query') + result = ACPTranslator.prompt_to_messages([block], existing) + assert len(result) == 2 + assert result is existing + + def test_messages_to_updates_assistant_content(self): + t = ACPTranslator() + msgs = [ + Message(role='assistant', content='Hello'), + ] + updates = t.messages_to_updates(msgs) + assert len(updates) >= 1 + + def test_messages_to_updates_delta_tracking(self): + t = ACPTranslator() + msgs1 = [Message(role='assistant', content='He')] + u1 = t.messages_to_updates(msgs1) + assert len(u1) == 1 + + msgs2 = [Message(role='assistant', content='Hello')] + u2 = t.messages_to_updates(msgs2) + assert len(u2) == 1 + + def test_messages_to_updates_reasoning(self): + t = ACPTranslator() + msgs = [ + Message(role='assistant', content='', reasoning_content='thinking'), + ] + updates = t.messages_to_updates(msgs) + assert len(updates) >= 1 + + def test_messages_to_updates_tool_call(self): + t = ACPTranslator() + tc = { + 'id': 'tc_1', + 'type': 'function', + 'tool_name': 'web_search', + 'arguments': '{"query": "test"}', + } + msgs = [Message(role='assistant', content='', tool_calls=[tc])] + updates = t.messages_to_updates(msgs) + assert any(hasattr(u, 'tool_call_id') for u in updates) + + def test_messages_to_updates_tool_result(self): + t = ACPTranslator() + t._emitted_tool_ids.add('tc_1') + msgs = [ + Message(role='tool', content='search done', tool_call_id='tc_1'), + ] + updates = t.messages_to_updates(msgs) + assert len(updates) >= 1 + + def test_map_stop_reason_normal(self): + session = MagicMock() + session.cancelled = False + session.agent.runtime.round = 5 + session.agent.max_chat_round = 20 + reason = ACPTranslator.map_stop_reason(session) + assert reason == 'end_turn' + + def 
test_map_stop_reason_max_rounds(self): + session = MagicMock() + session.cancelled = False + session.agent.runtime.round = 21 + session.agent.max_chat_round = 20 + reason = ACPTranslator.map_stop_reason(session) + assert reason == 'max_turn_requests' + + def test_map_stop_reason_cancelled(self): + session = MagicMock() + session.cancelled = True + reason = ACPTranslator.map_stop_reason(session, cancelled=True) + assert reason == 'cancelled' + + def test_reset_turn(self): + t = ACPTranslator() + t._last_content_len = 100 + t._emitted_tool_ids.add('tc_1') + t.reset_turn() + assert t._last_content_len == 0 + assert len(t._emitted_tool_ids) == 0 + + def test_build_plan_update(self): + steps = [ + {'description': 'Search papers', 'status': 'in_progress', 'priority': 'high'}, + {'description': 'Analyze', 'status': 'pending'}, + ] + update = ACPTranslator.build_plan_update(steps) + assert hasattr(update, 'entries') + assert len(update.entries) == 2 + + +# ====================================================================== +# Session store tests +# ====================================================================== + +class TestSessionStore: + + def test_get_nonexistent_raises(self): + store = ACPSessionStore() + with pytest.raises(SessionNotFoundError): + store.get('nonexistent') + + @pytest.mark.asyncio + async def test_max_sessions_with_no_idle(self): + store = ACPSessionStore(max_sessions=1) + entry = MagicMock() + entry.is_running = True + entry.last_activity = 0 + store._sessions['ses_1'] = entry + with pytest.raises(MaxSessionsError): + await store.create( + config_path='/nonexistent/agent.yaml', + cwd='/tmp', + ) + + def test_list_sessions_empty(self): + store = ACPSessionStore() + assert store.list_sessions() == [] + + +# ====================================================================== +# ACP Client Manager tests +# ====================================================================== + +class TestACPClientManager: + + def 
test_list_agents_empty(self): + from ms_agent.acp.client import ACPClientManager + mgr = ACPClientManager() + assert mgr.list_agents() == [] + + def test_list_agents_from_config(self): + from ms_agent.acp.client import ACPClientManager + cfg = { + 'openclaw': {'command': 'openclaw', 'args': ['acp'], 'description': 'test'}, + 'claude': {'command': 'claude', 'args': [], 'description': 'test2'}, + } + mgr = ACPClientManager(cfg) + assert set(mgr.list_agents()) == {'openclaw', 'claude'} + + @pytest.mark.asyncio + async def test_call_unconfigured_agent(self): + from ms_agent.acp.client import ACPClientManager + mgr = ACPClientManager() + result = await mgr.call_agent('nonexistent', 'hello') + assert 'not configured' in result + + +# ====================================================================== +# ACP Agent Tool tests +# ====================================================================== + +class TestACPAgentTool: + + def test_from_config_none(self): + from ms_agent.tools.acp_agent_tool import ACPAgentTool + from omegaconf import OmegaConf + config = OmegaConf.create({'llm': {'model': 'test'}}) + assert ACPAgentTool.from_config(config) is None + + def test_from_config_with_agents(self): + from ms_agent.tools.acp_agent_tool import ACPAgentTool + from omegaconf import OmegaConf + config = OmegaConf.create({ + 'llm': {'model': 'test'}, + 'acp_agents': { + 'openclaw': { + 'command': 'openclaw', + 'args': ['acp'], + 'description': 'OpenClaw agent', + }, + }, + }) + tool = ACPAgentTool.from_config(config) + assert tool is not None + + @pytest.mark.asyncio + async def test_get_tools(self): + from ms_agent.tools.acp_agent_tool import ACPAgentTool + from omegaconf import OmegaConf + config = OmegaConf.create({ + 'llm': {'model': 'test'}, + 'acp_agents': { + 'openclaw': { + 'command': 'openclaw', + 'args': ['acp'], + 'description': 'OpenClaw coding agent', + }, + }, + }) + tool = ACPAgentTool(config, acp_agents_config={ + 'openclaw': { + 'command': 'openclaw', + 
'args': ['acp'], + 'description': 'OpenClaw coding agent', + }, + }) + tools = await tool._get_tools_inner() + assert 'acp_openclaw' in tools + assert len(tools['acp_openclaw']) == 1 + assert tools['acp_openclaw'][0]['description'] == 'OpenClaw coding agent' + + +# ====================================================================== +# Config options tests +# ====================================================================== + +class TestConfigOptions: + + def test_build_config_options_with_model(self): + from ms_agent.acp.config import build_config_options + from omegaconf import OmegaConf + cfg = OmegaConf.create({'llm': {'model': 'qwen-max'}}) + opts = build_config_options(cfg) + assert opts is not None + assert len(opts) == 1 + assert opts[0].id == 'model' + + def test_build_config_options_without_model(self): + from ms_agent.acp.config import build_config_options + from omegaconf import OmegaConf + cfg = OmegaConf.create({}) + opts = build_config_options(cfg) + assert opts is None + + def test_apply_config_option_model(self): + from ms_agent.acp.config import apply_config_option + from omegaconf import OmegaConf + cfg = OmegaConf.create({'llm': {'model': 'qwen-max'}}) + result = apply_config_option(cfg, 'model', 'gpt-4o') + assert result is True + assert cfg.llm.model == 'gpt-4o' + + def test_apply_config_option_unknown(self): + from ms_agent.acp.config import apply_config_option + from omegaconf import OmegaConf + cfg = OmegaConf.create({'llm': {'model': 'qwen-max'}}) + result = apply_config_option(cfg, 'unknown_option', 'value') + assert result is False + + +# ====================================================================== +# Permission policy tests +# ====================================================================== + +class TestPermissionPolicy: + + def test_auto_approve_never_asks(self): + from ms_agent.acp.permissions import PermissionPolicy + p = PermissionPolicy('auto_approve') + assert p.should_ask('any_tool') is False + assert 
p.auto_decision('any_tool') == 'allow_once' + + def test_always_ask(self): + from ms_agent.acp.permissions import PermissionPolicy + p = PermissionPolicy('always_ask') + assert p.should_ask('web_search') is True + assert p.auto_decision('web_search') is None + + def test_remember_choice(self): + from ms_agent.acp.permissions import PermissionPolicy + p = PermissionPolicy('remember_choice') + assert p.should_ask('web_search') is True + p.record_choice('web_search', True) + assert p.should_ask('web_search') is False + assert 'allow' in p.auto_decision('web_search') + + def test_remember_deny(self): + from ms_agent.acp.permissions import PermissionPolicy + p = PermissionPolicy('remember_choice') + p.record_choice('dangerous_tool', False) + assert 'deny' in p.auto_decision('dangerous_tool') + + def test_reset(self): + from ms_agent.acp.permissions import PermissionPolicy + p = PermissionPolicy('remember_choice') + p.record_choice('tool_a', True) + p.reset() + assert p.should_ask('tool_a') is True + + +# ====================================================================== +# Registry tests +# ====================================================================== + +class TestRegistry: + + def test_generate_manifest(self): + from ms_agent.acp.registry import generate_agent_manifest + manifest = generate_agent_manifest( + config_path='/path/to/config.yaml', + output_path=None, + ) + assert manifest['name'] == 'ms-agent' + assert manifest['protocol'] == 'acp' + assert manifest['protocolVersion'] == 1 + assert manifest['transport']['type'] == 'stdio' + assert '--config' in manifest['transport']['args'] + assert '/path/to/config.yaml' in manifest['transport']['args'] + + def test_generate_manifest_without_config(self): + from ms_agent.acp.registry import generate_agent_manifest + manifest = generate_agent_manifest(output_path=None) + assert manifest['transport']['args'] == ['acp'] diff --git a/tests/test_acp/test_openai_config.yaml b/tests/test_acp/test_openai_config.yaml 
new file mode 100644 index 000000000..0f0f21342 --- /dev/null +++ b/tests/test_acp/test_openai_config.yaml @@ -0,0 +1,18 @@ +llm: + service: openai + model: qwen-plus + openai_base_url: https://dashscope.aliyuncs.com/compatible-mode/v1 + +generation_config: + temperature: 0.1 + stream: true + +prompt: + system: | + You are a concise assistant. Answer questions briefly and directly. + +max_chat_round: 3 + +callbacks: [] + +tools: